Compare commits

...

81 Commits

Author SHA1 Message Date
comfyanonymous
312d511630 Style fix. (#8390) 2025-06-02 07:22:02 -04:00
Jesse Gonyou
4f4f1c642a Update fix for potential XSS on /view (#8384)
* Update fix for potential XSS on /view

This commit uses mimetypes to add more restricted filetypes to the set prevented from being served, since MIME types are what browsers use to determine how to handle files.

* Fix typo

Fixed a typo that prevented the program from running
2025-06-02 06:52:44 -04:00
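
The commit above uses MIME types to decide which files /view should refuse to serve inline. A minimal sketch of that idea (hypothetical helper, not the actual handler code):

import mimetypes

# Content types a browser would render or execute inline; serving
# user-supplied files with these types could enable XSS, so force a download.
BLOCKED_MIME_PREFIXES = ("text/html", "application/xhtml+xml", "text/xml", "image/svg+xml")

def safe_disposition(filename: str) -> str:
    # Guess the type from the extension, the same signal browsers rely on.
    content_type, _ = mimetypes.guess_type(filename)
    if content_type and content_type.startswith(BLOCKED_MIME_PREFIXES):
        return "attachment"  # never render inline
    return "inline"
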
filtered
010954d277 [BugFix] Update frontend to 1.21.6 (#8383) 2025-06-02 14:57:44 +10:00
filtered
6d46bb4b4c [BugFix] Update frontend to 1.21.5 (#8382) 2025-06-01 16:47:14 -04:00
Christian Byrne
67f57c5bcc [feat] add custom node testing requirement to issue templates (#8374)
Adds mandatory checkbox to bug report and user support templates requiring users to confirm they've tested with custom nodes disabled before submitting issues.
2025-06-01 15:47:07 -04:00
filtered
fd943c928f [BugFix] Update frontend to 1.21.4 (#8377) 2025-06-01 13:57:53 -04:00
ComfyUI Wiki
d3bd983b91 Bump template to 0.1.25 (#8372) 2025-06-01 05:41:17 -04:00
comfyanonymous
fb4754624d Make the casting in lists the same as regular inputs. (#8373) 2025-06-01 05:39:54 -04:00
Benjamin Lu
180db6753f Add Help Menu in NodeLibrarySidebarTab (#8179) 2025-06-01 04:32:32 -04:00
Christian Byrne
d062fcc5c0 [feat] Add ImageStitch node for concatenating images (#8369)
* [feat] Add ImageStitch node for concatenating images with borders

Add ImageStitch node that concatenates images in four directions with optional borders and intelligent size handling. Features include optional second image input, configurable borders with color selection, automatic batch size matching, and dimension alignment via padding or resizing.

Upstreamed from https://github.com/kijai/ComfyUI-KJNodes with enhancements for better error handling and comprehensive test coverage.

* [fix] Fix CI issues with CUDA dependencies and linting

- Mock CUDA-dependent modules in tests to avoid CI failures on CPU-only runners
- Fix ruff linting issues for code style compliance

* [fix] Improve CI compatibility by mocking nodes module import

Prevent CUDA initialization chain by mocking the nodes module at import time,
which is cleaner than deep mocking of CUDA-specific functions.

* [refactor] Clean up ImageStitch tests

- Remove unnecessary sys.path manipulation (pythonpath set in pytest.ini)
- Remove metadata tests that test framework internals rather than functionality
- Rename complex scenario test to be more descriptive of what it tests

* [refactor] Rename 'border' to 'spacing' for semantic accuracy

- Change border_width/border_color to spacing_width/spacing_color in API
- Update all tests to use spacing terminology
- Update comments and variable names throughout
- More accurately describes the gap/separator between images
2025-06-01 04:28:52 -04:00
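
As a rough illustration of the stitching described above (a sketch with assumed parameter names, not the ImageStitch node's actual code): concatenate two ComfyUI-style [B, H, W, C] image tensors along width, padding heights to match and inserting a colored spacing strip.

import torch

def stitch_right(img1, img2, spacing_width=16, spacing_color=1.0):
    # img1, img2: [B, H, W, C] float tensors in 0..1, as ComfyUI IMAGE uses.
    h = max(img1.shape[1], img2.shape[1])
    def pad_h(img):
        # Pad the shorter image at the bottom so heights match.
        pad = h - img.shape[1]
        return torch.nn.functional.pad(img, (0, 0, 0, 0, 0, pad)) if pad else img
    img1, img2 = pad_h(img1), pad_h(img2)
    spacer = torch.full((img1.shape[0], h, spacing_width, img1.shape[-1]),
                        spacing_color, dtype=img1.dtype, device=img1.device)
    return torch.cat([img1, spacer, img2], dim=2)  # concatenate along width
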
filtered
456abad834 Update frontend to 1.21 (#8366) 2025-06-01 01:10:04 -04:00
comfyanonymous
19e45e9b0e Make it easier to pass lists of tensors to models. (#8358) 2025-05-31 20:00:20 -04:00
ComfyUI Wiki
97f23b81f3 Bump template to 0.1.23 (#8353)
Correct some error settings in VACE
2025-05-30 23:05:42 -07:00
drhead
08b7cc7506 use fused multiply-add pointwise ops in chroma (#8279) 2025-05-30 18:09:54 -04:00
BennyKok
6c319cbb4e fix: custom comfy-api-base works with subpath (#8332) 2025-05-30 17:51:28 -04:00
Chenlei Hu
df1aebe52e Remove huchenlei from CODEOWNERS (#8350) 2025-05-30 17:27:52 -04:00
comfyanonymous
704fc78854 Put ROCm version in tuple to make it easier to enable stuff based on it. (#8348) 2025-05-30 15:41:02 -04:00
JettHu
1d9fee79fd Add node for regex replace(sub) operation (#8340)
* Add node for regex replace(sub) operation

* Apply suggestions from code review

add tooltips

Co-authored-by: Christian Byrne <abolkonsky.rem@gmail.com>

* Fix indentation

---------

Co-authored-by: Christian Byrne <abolkonsky.rem@gmail.com>
2025-05-30 15:08:59 -04:00
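
A minimal sketch of what such a regex-substitution node boils down to (illustrative only; class and parameter names are assumptions, not the merged node):

import re

class RegexReplace:
    # Illustrative node shape: take a string, a pattern, and a replacement,
    # and return re.sub(pattern, replacement, string).
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "string": ("STRING", {"multiline": True, "tooltip": "Text to search"}),
            "regex_pattern": ("STRING", {"tooltip": "Python regular expression"}),
            "replace": ("STRING", {"tooltip": "Replacement text"}),
        }}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"
    CATEGORY = "utils/string"

    def execute(self, string, regex_pattern, replace):
        return (re.sub(regex_pattern, replace, string),)
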
Jedrzej Kosinski
aeba0b3a26 Reduce code duplication for [pro] and [max], rename Pro and Max to [pro] and [max] to be consistent with other BFL nodes, make default seed for Kontext nodes be 1234, since 0 is interpreted by the API as 'choose random seed' (#8337) 2025-05-29 17:14:27 -04:00
comfyanonymous
094306b626 ComfyUI version 0.3.39 2025-05-29 14:26:39 -04:00
filtered
31260f0275 Update templates 0.1.22 (#8334) 2025-05-30 03:52:27 +10:00
Robin Huang
f1c9ca816a Add BFL Kontext API Nodes. (#8333)
* Added initial Flux.1 Kontext Pro Image node - recreated branch to save myself sanity from rebase crap after master got rebased

* Add safety filter to Kontext.

* Make safety = 2 and input image is optional.

* Add BFL kontext API nodes.

---------

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
2025-05-29 13:27:40 -04:00
comfyanonymous
f2289a1f59 Delete useless file. (#8327) 2025-05-29 08:29:37 -04:00
Robin Huang
fb83eda287 Revert "Add support for Veo3 API node." (#8322)
This reverts commit 592d056100.
2025-05-29 03:03:11 -04:00
comfyanonymous
5e5e46d40c Not really tested WAN Phantom Support. (#8321) 2025-05-28 23:46:15 -04:00
Yoland Yan
4eba3161cf Refactor Pika API node imports and fix unique_id issue. (#8319)
Added unique_id to hidden parameters and corrected description formatting in PikAdditionsNode.
2025-05-28 23:42:25 -04:00
Robin Huang
592d056100 Add support for Veo3 API node. (#8320) 2025-05-28 23:42:02 -04:00
comfyanonymous
1c1687ab1c Support HiDream SimpleTuner loras. (#8318) 2025-05-28 18:47:15 -04:00
comfyanonymous
e6609dacde ComfyUI version 0.3.38 2025-05-28 02:15:11 -04:00
Christian Byrne
ba37e67964 update frontend patch 1.20.7 (#8312) 2025-05-28 01:42:18 -04:00
comfyanonymous
06c661004e Memory estimation code can now take into account conds. (#8307) 2025-05-27 15:09:05 -04:00
comfyanonymous
c9e1821a7b ComfyUI version 0.3.37 2025-05-27 07:07:44 -04:00
Robin Huang
f58f0f5696 More API nodes: Gemini/Open AI Chat, Tripo, Rodin, Runway Image (#8295)
* Add Ideogram generate node.

* Add staging api.

* Add API_NODE and common error for missing auth token (#5)

* Add Minimax Video Generation + Async Task queue polling example (#6)

* [Minimax] Show video preview and embed workflow in output (#7)

* Remove uv.lock

* Remove polling operations.

* Revert "Remove polling operations."

This reverts commit 8415404ce8fbc0262b7de54fc700c5c8854a34fc.

* Update stubs.

* Added Ideogram and Minimax back in.

* Added initial BFL Flux 1.1 [pro] Ultra node (#11)

* Manually add BFL polling status response schema (#15)

* Add function for uploading files. (#18)

* Add Luma nodes (#16)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Refactor util functions (#20)

* Add rest of Luma node functionality (#19)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Fix image_luma_ref not working (#28)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* [Bug] Remove duplicated option T2V-01 in MinimaxTextToVideoNode (#31)

* add veo2, bump av req (#32)

* Add Recraft nodes (#29)

* Add Kling Nodes (#12)

* Add Camera Concepts (luma_concepts) to Luma Video nodes (#33)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Add Runway nodes (#17)

* Convert Minimax node to use VIDEO output type (#34)

* Standard `CATEGORY` system for api nodes (#35)

* Set `Content-Type` header when uploading files (#36)

* add better error propagation to veo2 (#37)

* Add Realistic Image and Logo Raster styles for Recraft v3 (#38)

* Fix runway image upload and progress polling (#39)

* Fix image upload for Luma: only include `Content-Type` header field if it's set explicitly (#40)

* Moved Luma nodes to nodes_luma.py (#47)

* Moved Recraft nodes to nodes_recraft.py (#48)

* Move and fix BFL nodes to node_bfl.py (#49)

* Move and edit Minimax node to nodes_minimax.py (#50)

* Add Recraft Text to Vector node, add Save SVG node to handle its output (#53)

* Added pixverse_template support to Pixverse Text to Video node (#54)

* Added Recraft Controls + Recraft Color RGB nodes (#57)

* split remaining nodes out of nodes_api, make utility lib, refactor ideogram (#61)

* Set request type explicitly (#66)

* Add `control_after_generate` to all seed inputs (#69)

* Fix bug: deleting `Content-Type` when property does not exist (#73)

* Add Pixverse and updated Kling types (#75)

* Added Recraft Style - Infinite Style Library node (#82)

* add ideogram v3 (#83)

* [Kling] Split Camera Control config to its own node (#81)

* Add Pika i2v and t2v nodes (#52)

* Remove Runway nodes (#88)

* Fix: Prompt text can't be validated in Kling nodes when using primitive nodes (#90)

* Update Pika Duration and Resolution options (#94)

* Removed Infinite Style Library until later (#99)

* fix multi image return (#101)

close #96

* Serve SVG files directly (#107)

* Add a bunch of nodes, 3 ready to use, the rest waiting for endpoint support (#108)

* Revert "Serve SVG files directly" (#111)

* Expose 4 remaining Recraft nodes (#112)

* [Kling] Add `Duration` and `Video ID` outputs (#105)

* Add Kling nodes: camera control, start-end frame, lip-sync, video extend (#115)

* Fix error for Recraft ImageToImage error for nonexistent random_seed param (#118)

* Add remaining Pika nodes (#119)

* Make controls input work for Recraft Image to Image node (#120)

* Fix: Nested `AnyUrl` in request model cannot be serialized (Kling, Runway) (#129)

* Show errors and API output URLs to the user (change log levels) (#131)

* Apply small fixes and most prompt validation (if needed to avoid API error) (#135)

* Node name/category modifications (#140)

* Add back Recraft Style - Infinite Style Library node (#141)

* [Kling] Fix: Correct/verify supported subset of input combos in Kling nodes (#149)

* Remove pixverse_template from PixVerse Transition Video node (#155)

* Use 3.9 compat syntax (#164)

* Handle Comfy API key based authorization (#167)

Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>

* [BFL] Print download URL of successful task result directly on nodes (#175)

* Show output URL and progress text on Pika nodes (#168)

* [Ideogram] Print download URL of successful task result directly on nodes (#176)

* [Kling] Print download URL of successful task result directly on nodes (#181)

* Merge upstream may 14 25 (#186)

Co-authored-by: comfyanonymous <comfyanonymous@protonmail.com>
Co-authored-by: AustinMroz <austinmroz@utexas.edu>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Benjamin Lu <benceruleanlu@proton.me>
Co-authored-by: Andrew Kvochko <kvochko@users.noreply.github.com>
Co-authored-by: Pam <42671363+pamparamm@users.noreply.github.com>
Co-authored-by: chaObserv <154517000+chaObserv@users.noreply.github.com>
Co-authored-by: Yoland Yan <4950057+yoland68@users.noreply.github.com>
Co-authored-by: guill <guill@users.noreply.github.com>
Co-authored-by: Chenlei Hu <hcl@comfy.org>
Co-authored-by: Terry Jia <terryjia88@gmail.com>
Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com>
Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>
Co-authored-by: liesen <liesen.dev@gmail.com>
Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Robin Huang <robin.j.huang@gmail.com>
Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Co-authored-by: blepping <157360029+blepping@users.noreply.github.com>

* Update instructions on how to develop API Nodes. (#171)

* Add Runway FLF and I2V nodes (#187)

* Add OpenAI chat node (#188)

* Update README.

* Add Google Gemini API node (#191)

* Add Runway Gen 4 Text to Image Node (#193)

* [Runway, Gemini] Update node display names and attributes (#194)

* Update path from "image-to-video" to "image_to_video" (#197)

* [Runway] Split I2V nodes into separate gen3 and gen4 nodes (#198)

* Update runway i2v ratio enum (#201)

* Rodin3D: implement Rodin3D API Nodes (#190)

Co-authored-by: WhiteGiven <c15838568211@163.com>
Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Add Tripo Nodes. (#189)

Co-authored-by: Robin Huang <robin.j.huang@gmail.com>

* Change casing of categories "3D"  => "3d" (#208)

* [tripo] fix negtive_prompt and mv2model (#212)

* [tripo] set default param to None (#215)

* Add description and tooltip to Tripo Refine model. (#218)

* Update.

* Fix rebase errors.

* Fix rebase errors.

* Update templates.

* Bump frontend.

* Add file type info for file inputs.

---------

Co-authored-by: Christian Byrne <cbyrne@comfy.org>
Co-authored-by: Jedrzej Kosinski <kosinkadink1@gmail.com>
Co-authored-by: Chenlei Hu <hcl@comfy.org>
Co-authored-by: thot experiment <94414189+thot-experiment@users.noreply.github.com>
Co-authored-by: comfyanonymous <comfyanonymous@protonmail.com>
Co-authored-by: AustinMroz <austinmroz@utexas.edu>
Co-authored-by: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Co-authored-by: Benjamin Lu <benceruleanlu@proton.me>
Co-authored-by: Andrew Kvochko <kvochko@users.noreply.github.com>
Co-authored-by: Pam <42671363+pamparamm@users.noreply.github.com>
Co-authored-by: chaObserv <154517000+chaObserv@users.noreply.github.com>
Co-authored-by: Yoland Yan <4950057+yoland68@users.noreply.github.com>
Co-authored-by: guill <guill@users.noreply.github.com>
Co-authored-by: Terry Jia <terryjia88@gmail.com>
Co-authored-by: Silver <65376327+silveroxides@users.noreply.github.com>
Co-authored-by: catboxanon <122327233+catboxanon@users.noreply.github.com>
Co-authored-by: liesen <liesen.dev@gmail.com>
Co-authored-by: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Co-authored-by: blepping <157360029+blepping@users.noreply.github.com>
Co-authored-by: Changrz <51637999+WhiteGiven@users.noreply.github.com>
Co-authored-by: WhiteGiven <c15838568211@163.com>
Co-authored-by: seed93 <liangding1990@163.com>
2025-05-27 03:00:58 -04:00
filtered
3a10b9641c [BugFix] Update frontend to 1.20.6 (#8296) 2025-05-27 02:47:06 -04:00
comfyanonymous
89a84e32d2 Disable initial GPU load when novram is used. (#8294) 2025-05-26 16:39:27 -04:00
comfyanonymous
e5799c4899 Enable pytorch attention by default on AMD gfx1151 (#8282) 2025-05-26 04:29:25 -04:00
comfyanonymous
a0651359d7 Return proper error if diffusion model not detected properly. (#8272) 2025-05-25 05:28:11 -04:00
comfyanonymous
ad3bd8aa49 ComfyUI version 0.3.36 2025-05-24 17:30:37 -04:00
comfyanonymous
5a87757ef9 Better error if sageattention is installed but a dependency is missing. (#8264) 2025-05-24 06:43:12 -04:00
Christian Byrne
464aece92b update frontend package to v1.20.5 (#8260) 2025-05-23 21:53:49 -07:00
comfyanonymous
0b50d4c0db Add argument to explicitly enable fp8 compute support. (#8257)
This can be used to test if your current GPU/pytorch version supports fp8 matrix mult in combination with --fast or the fp8_e4m3fn_fast dtype.
2025-05-23 17:43:50 -04:00
drhead
30b2eb8a93 create arange on-device (#8255) 2025-05-23 16:15:06 -04:00
comfyanonymous
f85c08df06 Make VACE conditionings stackable. (#8240) 2025-05-22 19:22:26 -04:00
comfyanonymous
4202e956a0 Add append feature to conditioning_set_values (#8239)
Refactor unclipconditioning node.
2025-05-22 08:11:13 -04:00
Terry Jia
b838c36720 remove mtl from 3d model file list (#8192) 2025-05-22 08:08:36 -04:00
Chenlei Hu
fc39184ea9 Update frontend to 1.20 (#8232) 2025-05-22 02:24:36 -04:00
ComfyUI Wiki
ded60c33a0 Update templates to 0.1.18 (#8224) 2025-05-21 11:40:08 -07:00
Michael Abrahams
8bb858e4d3 Improve performance with large number of queued prompts (#8176)
* get_current_queue_volatile

* restore get_current_queue method

* remove extra import
2025-05-21 05:14:17 -04:00
编程界的小学生
57893c843f Code Optimization and Issue Fixes in ComfyUI server (#8196)
* Update server.py

* Update server.py
2025-05-21 04:59:42 -04:00
Jedrzej Kosinski
65da29aaa9 Make torch.compile LoRA/key-compatible (#8213)
* Make torch compile node use wrapper instead of object_patch for the entire diffusion_models object, allowing key associations on diffusion_models to not break (loras, getting attributes, etc.)

* Moved torch compile code into comfy_api so it can be used by custom nodes with a degree of confidence

* Refactor set_torch_compile_wrapper to support a list of keys instead of just diffusion_model, as well as additional torch.compile args

* remove unused import

* Moved torch compile kwargs to be stored in model_options instead of attachments; attachments are more intended for things to be 'persisted', AKA not deepcopied

* Add some comments

* Remove random line of code, not sure how it got there
2025-05-21 04:56:56 -04:00
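
A rough sketch of the approach described above (hypothetical helper, not the comfy_api function): compile only selected sub-modules' forward calls so the surrounding module keeps its original attributes and state-dict keys, which is what LoRA key matching relies on.

import torch

def compile_submodules(model, keys=("diffusion_model",), **compile_kwargs):
    # Wrap each named sub-module's forward with torch.compile instead of
    # replacing the module object, so named_parameters()/state_dict() keys
    # (and therefore LoRA key associations) stay unchanged.
    for key in keys:
        module = model.get_submodule(key)
        module.forward = torch.compile(module.forward, **compile_kwargs)
    return model
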
comfyanonymous
10024a38ea ComfyUI version v0.3.35 2025-05-21 04:50:37 -04:00
comfyanonymous
87f9130778 Revert "This doesn't seem to be needed on chroma. (#8209)" (#8210)
This reverts commit 7e84bf5373.
2025-05-20 05:39:55 -04:00
comfyanonymous
7e84bf5373 This doesn't seem to be needed on chroma. (#8209) 2025-05-20 05:29:23 -04:00
filtered
4f3b50ba51 Update README ROCm text to match link (#8199)
- Follow-up on #8198
2025-05-19 16:40:55 -04:00
comfyanonymous
e930a387d6 Update AMD instructions in README. (#8198) 2025-05-19 04:58:41 -04:00
comfyanonymous
d8e5662822 Remove default delimiter. (#8183) 2025-05-18 04:12:12 -04:00
LaVie024
3d44a09812 Update nodes_string.py (#8173) 2025-05-18 04:11:11 -04:00
comfyanonymous
62690eddec Node to add pixel space noise to an image. (#8182) 2025-05-18 04:09:56 -04:00
Christian Byrne
05eb10b43a Validate video inputs (#8133)
* validate kling lip sync input video

* add tooltips

* update duration estimates

* decrease epsilon

* fix rebase error
2025-05-18 04:08:47 -04:00
Silver
f5e4e976f4 Add missing category for T5TokenizerOption (#8177)
Change it if you need to but it should at least have a category.
2025-05-18 02:59:06 -04:00
comfyanonymous
aee2908d03 Remove useless log. (#8166) 2025-05-17 06:27:34 -04:00
comfyanonymous
dc46db7aa4 Make ImagePadForOutpaint return a 3 channel mask. (#8157) 2025-05-16 15:15:55 -04:00
filtered
7046983d95 Remove Desktop versioning claim from README (#8155) 2025-05-16 10:45:36 -07:00
comfyanonymous
1c2d45d2b5 Fix typo in last PR. (#8144)
More robust model detection for future proofing.
2025-05-15 19:02:19 -04:00
George0726
c820ef950d Add Wan-FUN Camera Control models and Add WanCameraImageToVideo node (#8013)
* support wan camera models

* fix by ruff check

* change camera_condition type; make camera_condition optional

* support camera trajectory nodes

* fix camera direction

---------

Co-authored-by: Qirui Sun <sunqr0667@126.com>
2025-05-15 19:00:43 -04:00
comfyanonymous
6a2e4bb9e0 Remove old hack used to fix windows pytorch 2.4 on the portable. (#8139)
Not necessary anymore.
2025-05-15 08:21:47 -04:00
Christian Byrne
f1f9763b4c Add get_duration method to Comfy VIDEO type (#8122)
* get duration from VIDEO type

* video get_duration unit test

* fix Windows unit test: can't delete opened temp file
2025-05-15 00:11:41 -04:00
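
For reference, a hedged sketch of how a video duration can be computed with PyAV (assumed backend; not necessarily the VIDEO type's exact code):

import av

def get_duration(path: str) -> float:
    # Prefer frame count / average frame rate; fall back to the stream's
    # duration expressed in its own time_base units.
    with av.open(path) as container:
        stream = container.streams.video[0]
        if stream.frames and stream.average_rate:
            return float(stream.frames / stream.average_rate)
        return float(stream.duration * stream.time_base)
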
comfyanonymous
08368f8e00 Update comment on ROCm pytorch attention in README. (#8123) 2025-05-14 17:54:50 -04:00
Christian Byrne
f3ff5c40db don't retry if API returns task failure (#8111) 2025-05-14 01:28:30 -04:00
Christian Byrne
98ff01e148 Display progress and result URL directly on API nodes (#8102)
* [Luma] Print download URL of successful task result directly on nodes (#177)

[Veo] Print download URL of successful task result directly on nodes (#184)

[Recraft] Print download URL of successful task result directly on nodes (#183)

[Pixverse] Print download URL of successful task result directly on nodes (#182)

[Kling] Print download URL of successful task result directly on nodes (#181)

[MiniMax] Print progress text and download URL of successful task result directly on nodes (#179)

[Docs] Link to docs in `API_NODE` class property type annotation comment (#178)

[Ideogram] Print download URL of successful task result directly on nodes (#176)

[BFL] Print download URL of successful task result directly on nodes (#175)

[OpenAI] Print download URL of successful task result directly on nodes (#174)

Show output URL and progress text on Pika nodes (#168)

* fix ruff errors

* fix 3.10 syntax error
2025-05-14 00:33:18 -04:00
thot experiment
bab836d88d rework client.py to be more robust, add logging of api requests (#7988)
* rework how errors are handled on the client side

* add logging to /temp

* fix ruff

* fix rebase, stupid vscode gui
2025-05-13 20:42:29 -04:00
comfyanonymous
4a9014e201 Hunyuan Custom initial untested implementation. (#8101) 2025-05-13 15:53:47 -04:00
thot experiment
8a7c894d54 fix negative momentum (#8100) 2025-05-13 10:50:32 -07:00
comfyanonymous
a814f2e8cc Fix issue with old pytorch RMSNorm. (#8095) 2025-05-13 07:54:28 -04:00
comfyanonymous
481732a0ed Support official ACE Step loras. (#8094) 2025-05-13 07:32:16 -04:00
Christian Byrne
2156ce9453 add comment about using api key in headless (#8082) 2025-05-12 23:06:44 -04:00
thot experiment
4136502b7a implement APG guidance (#8081)
* first pass at implementing APG

* rename, cleanup code

* fix ruff

* fix modified cond to match ref impl better, support different cond arity
2025-05-12 21:10:24 -04:00
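
The commit names APG (adaptive projected guidance) without detail; a rough sketch of the general technique (an assumption about the method, not this node's exact code):

import torch

def apg_guidance(cond, uncond, scale, eta=0.0):
    # Hedged sketch of the general APG idea: split the guidance delta into
    # components parallel and orthogonal to the conditional prediction and
    # down-weight the parallel part via eta; eta=1 recovers plain CFG.
    flat_cond = cond.flatten(1)
    flat_diff = (cond - uncond).flatten(1)
    unit = flat_cond / flat_cond.norm(dim=1, keepdim=True).clamp_min(1e-8)
    parallel = (flat_diff * unit).sum(dim=1, keepdim=True) * unit
    orthogonal = flat_diff - parallel
    guided = flat_cond + (scale - 1.0) * (orthogonal + eta * parallel)
    return guided.view_as(cond)
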
Terry Jia
9ad287ff20 add support to record video as output for 3d node (#7927)
* add support to record video as output for 3d node

* source format

* add support to record video for load3d animation node
2025-05-12 16:47:14 -04:00
Chenlei Hu
f5cacaeb14 Update frontend to v1.19 (#8076)
* Update frontend to v1.19

* Update requirements.txt
2025-05-12 16:47:02 -04:00
Terry Jia
b7ed5f57bd string node (#7952) 2025-05-12 16:29:32 -04:00
thot experiment
b4abca828e add opus and mp3 to audio output node (#8019)
* first pass at opus and mp3 as well as migrating flac to pyav

* minor mp3 encoding fix

* fix ruff

* delete dead code

* split out save audio to separate nodes per filetype

* fix ruff
2025-05-12 16:00:01 -04:00
74 changed files with 9482 additions and 3582 deletions

View File

@@ -15,6 +15,14 @@ body:
steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Expected Behavior

View File

@@ -11,6 +11,14 @@ body:
**2:** You have made an effort to find public answers to your question before asking here. In other words, you googled it first, and scrolled through recent help topics.
If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first.
- type: checkboxes
id: custom-nodes-test
attributes:
label: Custom Node Testing
description: Please confirm you have tried to reproduce the issue with all custom nodes disabled.
options:
- label: I have tried disabling custom nodes and the issue persists (see [how to disable custom nodes](https://docs.comfy.org/troubleshooting/custom-node-issues#step-1%3A-test-with-all-custom-nodes-disabled) if you need help)
required: true
- type: textarea
attributes:
label: Your question

View File

@@ -5,20 +5,20 @@
# Inlined the team members for now.
# Maintainers
*.md @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
*.md @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/tests-unit/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/notebooks/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/script_examples/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/.github/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/requirements.txt @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
/pyproject.toml @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @Kosinkadink @christian-byrne
# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @christian-byrne
/api_server/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/app/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
/utils/ @yoland68 @robinjhuang @webfiltered @pythongosssss @ltdrdata @christian-byrne
# Node developers
/comfy_extras/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @huchenlei @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy_extras/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne
/comfy/comfy_types/ @yoland68 @robinjhuang @pythongosssss @ltdrdata @Kosinkadink @webfiltered @christian-byrne

View File

@@ -110,7 +110,6 @@ ComfyUI follows a weekly release cycle every Friday, with three interconnected r
2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
- Builds a new release using the latest stable core version
- Version numbers match the core release (e.g., Desktop v1.7.0 uses Core v1.7.0)
3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
- Weekly frontend updates are merged into the core repository
@@ -198,11 +197,11 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2.4```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
This is the command to install the nightly with ROCm 6.3 which might have some performance improvements:
This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.3```
```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
### Intel GPUs (Windows and Linux)
@@ -302,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
### AMD ROCm Tips
You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```

View File

@@ -205,6 +205,19 @@ comfyui-workflow-templates is not installed.
""".strip()
)
@classmethod
def embedded_docs_path(cls) -> str:
"""Get the path to embedded documentation"""
try:
import comfyui_embedded_docs
return str(
importlib.resources.files(comfyui_embedded_docs) / "docs"
)
except ImportError:
logging.info("comfyui-embedded-docs package not found")
return None
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""

View File

@@ -88,6 +88,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"

View File

@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node."""
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod
@abstractmethod

View File

@@ -24,6 +24,10 @@ class CONDRegular:
conds.append(x.cond)
return torch.cat(conds)
def size(self):
return list(self.cond.size())
class CONDNoiseShape(CONDRegular):
def process_cond(self, batch_size, device, area, **kwargs):
data = self.cond
@@ -64,6 +68,7 @@ class CONDCrossAttn(CONDRegular):
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
@@ -78,3 +83,48 @@ class CONDConstant(CONDRegular):
def concat(self, others):
return self.cond
def size(self):
return [1]
class CONDList(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
return self._copy_with(out)
def can_concat(self, other):
if len(self.cond) != len(other.cond):
return False
for i in range(len(self.cond)):
if self.cond[i].shape != other.cond[i].shape:
return False
return True
def concat(self, others):
out = []
for i in range(len(self.cond)):
o = [self.cond[i]]
for x in others:
o.append(x.cond[i])
out.append(torch.cat(o))
return out
def size(self): # hackish implementation to make the mem estimation work
o = 0
c = 1
for c in self.cond:
size = c.size()
o += math.prod(size)
if len(size) > 1:
c = size[1]
return [1, c, o // c]
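
For intuition, the hackish size() above collapses a list of conds into one pseudo-shape with the same total element count; a quick check with two illustrative text-cond shapes:

import math

shapes = [(1, 77, 768), (1, 154, 768)]
total = sum(math.prod(s) for s in shapes)  # 177408 elements across the list
c = shapes[-1][1]                          # 154, second dim of the last entry
print([1, c, total // c])                  # [1, 154, 1152], fed to mem estimation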

View File

@@ -80,15 +80,13 @@ class DoubleStreamBlock(nn.Module):
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img))
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt))
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -102,12 +100,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn))
img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img))))
# calculate the txt bloks
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn))
txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt))))
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -152,7 +150,7 @@ class SingleStreamBlock(nn.Module):
def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
mod = vec
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x))
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -162,7 +160,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
x.addcmul_(mod.gate, output)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@@ -178,6 +176,6 @@ class LastLayer(nn.Module):
shift, scale = vec
shift = shift.squeeze(1)
scale = scale.squeeze(1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = torch.addcmul(shift[:, None, :], 1 + scale[:, None, :], self.norm_final(x))
x = self.linear(x)
return x
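
The hunk above fuses the shift/scale modulation into torch.addcmul; a quick equivalence check with illustrative shapes:

import torch

x = torch.randn(2, 4, 8)
shift = torch.randn(2, 1, 8)
scale = torch.randn(2, 1, 8)

unfused = (1 + scale) * x + shift
fused = torch.addcmul(shift, 1 + scale, x)   # shift + (1 + scale) * x
assert torch.allclose(unfused, fused, atol=1e-6)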

View File

@@ -163,7 +163,7 @@ class Chroma(nn.Module):
distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
# get all modulation index
modulation_index = timestep_embedding(torch.arange(mod_index_length), 32).to(img.device, img.dtype)
modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
# we need to broadcast the modulation index here so each batch has all of the index
modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
# and we need to broadcast timestep and guidance along too

View File

@@ -228,6 +228,7 @@ class HunyuanVideo(nn.Module):
y: Tensor,
guidance: Tensor = None,
guiding_frame_index=None,
ref_latent=None,
control=None,
transformer_options={},
) -> Tensor:
@@ -238,6 +239,14 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
img = torch.cat([ref_latent, img], dim=-2)
ref_latent_ids[..., 0] = -1
ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
if guiding_frame_index is not None:
token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
@@ -313,6 +322,8 @@ class HunyuanVideo(nn.Module):
img[:, : img_len] += add
img = img[:, : img_len]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
@@ -324,7 +335,7 @@ class HunyuanVideo(nn.Module):
img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, control=None, transformer_options={}, **kwargs):
def img_ids(self, x):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -334,7 +345,11 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, control, transformer_options)
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out

View File

@@ -20,8 +20,11 @@ if model_management.xformers_enabled():
if model_management.sage_attention_enabled():
try:
from sageattention import sageattn
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
except ModuleNotFoundError as e:
if e.name == "sageattention":
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
else:
raise e
exit(-1)
if model_management.flash_attention_enabled():

View File

@@ -247,6 +247,60 @@ class VaceWanAttentionBlock(WanAttentionBlock):
return c_skip, c
class WanCamAdapter(nn.Module):
def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
super(WanCamAdapter, self).__init__()
# Pixel Unshuffle: reduce spatial dimensions by a factor of 8
self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
# Convolution: reduce spatial dimensions by a factor
# of 2 (without overlap)
self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
# Residual blocks for feature extraction
self.residual_blocks = nn.Sequential(
*[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
)
def forward(self, x):
# Reshape to merge the frame dimension into batch
bs, c, f, h, w = x.size()
x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
# Pixel Unshuffle operation
x_unshuffled = self.pixel_unshuffle(x)
# Convolution operation
x_conv = self.conv(x_unshuffled)
# Feature extraction with residual blocks
out = self.residual_blocks(x_conv)
# Reshape to restore original bf dimension
out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
# Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
out = out.permute(0, 2, 1, 3, 4)
return out
class WanCamResidualBlock(nn.Module):
def __init__(self, dim, operation_settings={}):
super(WanCamResidualBlock, self).__init__()
self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
self.relu = nn.ReLU(inplace=True)
self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
def forward(self, x):
residual = x
out = self.relu(self.conv1(x))
out = self.conv2(out)
out += residual
return out
class Head(nn.Module):
def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
@@ -485,13 +539,20 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
@@ -581,7 +642,7 @@ class VaceWanModel(WanModel):
t,
context,
vace_context,
vace_strength=1.0,
vace_strength,
clip_fea=None,
freqs=None,
transformer_options={},
@@ -607,8 +668,11 @@ class VaceWanModel(WanModel):
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
orig_shape = list(vace_context.shape)
vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
c = c.flatten(2).transpose(1, 2)
c = list(c.split(orig_shape[0], dim=0))
# arguments
x_orig = x
@@ -628,8 +692,9 @@ class VaceWanModel(WanModel):
ii = self.vace_layers_mapping.get(i, None)
if ii is not None:
c_skip, c = self.vace_blocks[ii](c, x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength
for iii in range(len(c)):
c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
x += c_skip * vace_strength[iii]
del c_skip
# head
x = self.head(x, e)
@@ -637,3 +702,92 @@ class VaceWanModel(WanModel):
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x
class CameraWanModel(WanModel):
r"""
Wan diffusion backbone supporting both text-to-video and image-to-video.
"""
def __init__(self,
model_type='camera',
patch_size=(1, 2, 2),
text_len=512,
in_dim=16,
dim=2048,
ffn_dim=8192,
freq_dim=256,
text_dim=4096,
out_dim=16,
num_heads=16,
num_layers=32,
window_size=(-1, -1),
qk_norm=True,
cross_attn_norm=True,
eps=1e-6,
flf_pos_embed_token_number=None,
image_model=None,
in_dim_control_adapter=24,
device=None,
dtype=None,
operations=None,
):
super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
def forward_orig(
self,
x,
t,
context,
clip_fea=None,
freqs=None,
camera_conditions = None,
transformer_options={},
**kwargs,
):
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if self.control_adapter is not None and camera_conditions is not None:
x_camera = self.control_adapter(camera_conditions).to(x.dtype)
x = x + x_camera
grid_sizes = x.shape[2:]
x = x.flatten(2).transpose(1, 2)
# time embeddings
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
e0 = self.time_projection(e).unflatten(1, (6, self.dim))
# context
context = self.text_embedding(context)
context_img_len = None
if clip_fea is not None:
if self.img_emb is not None:
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
context_img_len = clip_fea.shape[-2]
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
# head
x = self.head(x, e)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return x

View File

@@ -283,8 +283,15 @@ def model_lora_keys_unet(model, key_map={}):
for k in sdk:
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
key_map["transformer.{}".format(key_lora)] = k #SimpleTuner regular format
if isinstance(model, comfy.model_base.ACEStep):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map
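
A quick illustration of the key mapping above, using a hypothetical key name:

k = "diffusion_model.double_blocks.0.img_attn.qkv.weight"
key_lora = k[len("diffusion_model."):-len(".weight")]
print("lycoris_" + key_lora.replace(".", "_"))  # SimpleTuner lycoris format
print("transformer." + key_lora)                # SimpleTuner regular format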

View File

@@ -102,6 +102,13 @@ def model_sampling(model_config, model_type):
return ModelSampling(model_config)
def convert_tensor(extra, dtype):
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
return extra
class BaseModel(torch.nn.Module):
def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_model=UNetModel):
super().__init__()
@@ -135,6 +142,7 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.memory_usage_factor_conds = ()
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -164,9 +172,14 @@ class BaseModel(torch.nn.Module):
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
if hasattr(extra, "dtype"):
if extra.dtype != torch.int and extra.dtype != torch.long:
extra = extra.to(dtype)
extra = convert_tensor(extra, dtype)
elif isinstance(extra, list):
ex = []
for ext in extra:
ex.append(convert_tensor(ext, dtype))
extra = ex
extra_conds[o] = extra
t = self.process_timestep(t, x=x, **extra_conds)
@@ -325,19 +338,28 @@ class BaseModel(torch.nn.Module):
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
def memory_required(self, input_shape):
def memory_required(self, input_shape, cond_shapes={}):
input_shapes = [input_shape]
for c in self.memory_usage_factor_conds:
shape = cond_shapes.get(c, None)
if shape is not None and len(shape) > 0:
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:])
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:])
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)
def extra_conds_shapes(self, **kwargs):
return {}
def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
adm_inputs = []
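
As a rough illustration of the summed-area estimate above (editorial example, not repository code):

import math

def rough_memory_required(input_shapes, dtype_size=2, memory_usage_factor=1.0):
    # Restatement of the flash/xformers branch: sum batch * prod(dims after
    # the channel dim) over the noise shape plus any cond shapes.
    area = sum(s[0] * math.prod(s[2:]) for s in input_shapes)
    return area * dtype_size * 0.01 * memory_usage_factor * (1024 * 1024)

# e.g. a [2, 16, 64, 64] latent plus a [2, 77, 768] text cond at fp16:
print(rough_memory_required([(2, 16, 64, 64), (2, 77, 768)]) / (1024 ** 3))  # ~0.19 GiB
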
@@ -924,6 +946,10 @@ class HunyuanVideo(BaseModel):
if guiding_frame_index is not None:
out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
ref_latent = kwargs.get("ref_latent", None)
if ref_latent is not None:
out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
return out
def scale_latent_inpaint(self, latent_image, **kwargs):
@@ -1043,6 +1069,11 @@ class WAN21(BaseModel):
clip_vision_output = kwargs.get("clip_vision_output", None)
if clip_vision_output is not None:
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
time_dim_concat = kwargs.get("time_dim_concat", None)
if time_dim_concat is not None:
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
return out
@@ -1058,23 +1089,39 @@ class WAN21_Vace(WAN21):
vace_frames = kwargs.get("vace_frames", None)
if vace_frames is None:
noise_shape[1] = 32
vace_frames = torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)
for i in range(0, vace_frames.shape[1], 16):
vace_frames = vace_frames.clone()
vace_frames[:, i:i + 16] = self.process_latent_in(vace_frames[:, i:i + 16])
vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
mask = kwargs.get("vace_mask", None)
if mask is None:
noise_shape[1] = 64
mask = torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)
mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
out['vace_context'] = comfy.conds.CONDRegular(torch.cat([vace_frames.to(noise), mask.to(noise)], dim=1))
vace_frames_out = []
for j in range(len(vace_frames)):
vf = vace_frames[j].clone()
for i in range(0, vf.shape[1], 16):
vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
vf = torch.cat([vf, mask[j]], dim=1)
vace_frames_out.append(vf)
vace_strength = kwargs.get("vace_strength", 1.0)
vace_frames = torch.stack(vace_frames_out, dim=1)
out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out
class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
camera_conditions = kwargs.get("camera_conditions", None)
if camera_conditions is not None:
out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
return out
class Hunyuan3Dv2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):

View File

@@ -361,6 +361,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "vace"
dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "camera"
else:
if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "i2v"
@@ -618,6 +620,9 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
if "conv_in.weight" not in state_dict:
return None
match = {}
transformer_depth = []

View File

@@ -297,11 +297,16 @@ except:
try:
if is_amd():
try:
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
except:
rocm_version = (6, -1)
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
logging.info("AMD arch: {}".format(arch))
logging.info("ROCm version: {}".format(rocm_version))
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx1100", "gfx1101"]): # TODO: more arches
if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
ENABLE_PYTORCH_ATTENTION = True
except:
pass
@@ -695,7 +700,7 @@ def unet_inital_load_device(parameters, dtype):
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY:
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
model_size = dtype_size(dtype) * parameters
@@ -1257,6 +1262,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
if args.supports_fp8_compute:
return True
if not is_nvidia():
return False

View File

@@ -30,7 +30,7 @@ if RMSNorm is None:
def __init__(
self,
normalized_shape,
eps=None,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import uuid
import math
import collections
import comfy.model_management
import comfy.conds
import comfy.utils
@@ -104,6 +106,21 @@ def cleanup_additional_models(models):
if hasattr(m, 'cleanup'):
m.cleanup()
def estimate_memory(model, noise_shape, conds):
cond_shapes = collections.defaultdict(list)
cond_shapes_min = {}
for _, cs in conds.items():
for cond in cs:
for k, v in model.model.extra_conds_shapes(**cond).items():
cond_shapes[k].append(v)
if cond_shapes_min.get(k, None) is None:
cond_shapes_min[k] = [v]
elif math.prod(v) > math.prod(cond_shapes_min[k][0]):
cond_shapes_min[k] = [v]
memory_required = model.model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]), cond_shapes=cond_shapes)
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
@@ -117,9 +134,8 @@ def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=Non
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
memory_required = model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) + inference_memory
minimum_memory_required = model.memory_required([noise_shape[0]] + list(noise_shape[1:])) + inference_memory
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required)
memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory)
real_model = model.model
return real_model, conds, models


@@ -256,7 +256,13 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
for i in range(1, len(to_batch_temp) + 1):
batch_amount = to_batch_temp[:len(to_batch_temp)//i]
input_shape = [len(batch_amount) * first_shape[0]] + list(first_shape)[1:]
if model.memory_required(input_shape) * 1.5 < free_memory:
cond_shapes = collections.defaultdict(list)
for tt in batch_amount:
cond = {k: v.size() for k, v in to_run[tt][0].conditioning.items()}
for k, v in to_run[tt][0].conditioning.items():
cond_shapes[k].append(v.size())
if model.memory_required(input_shape, cond_shapes=cond_shapes) * 1.5 < free_memory:
to_batch = batch_amount
break
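For readers skimming the hunk above: `to_batch_temp[:len(to_batch_temp)//i]` tries progressively smaller candidate batches (all items, half, a third, ...) until the memory estimate, which now also accounts for per-key conditioning shapes, fits into free memory with a 1.5x safety margin. A standalone sketch of that shrink-until-it-fits pattern; the cost numbers are invented and the single-item fallback is an assumption about surrounding code not shown here:

```python
# Sketch: shrink the candidate batch until the estimated cost fits in memory.
def pick_batch(candidates: list[int], free_memory: float, per_item_cost: float) -> list[int]:
    chosen = candidates[:1]  # assumed fallback: run at least one item
    for i in range(1, len(candidates) + 1):
        batch = candidates[: len(candidates) // i]
        if per_item_cost * len(batch) * 1.5 < free_memory:
            chosen = batch
            break
    return chosen

# 8 items at cost 20 each with 100 units free -> only the first 2 fit (2 * 20 * 1.5 = 60 < 100)
print(pick_batch(list(range(8)), free_memory=100.0, per_item_cost=20.0))
```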


@@ -992,6 +992,16 @@ class WAN21_FunControl2V(WAN21_T2V):
out = model_base.WAN21(self, image_to_video=False, device=device)
return out
class WAN21_Camera(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "camera",
"in_dim": 32,
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
return out
class WAN21_Vace(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@@ -1129,6 +1139,6 @@ class ACEStep(supported_models_base.BASE):
def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]


@@ -1,25 +0,0 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}


@@ -78,8 +78,6 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
if "global_step" in pl_sd:
logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:


@@ -43,3 +43,13 @@ class VideoInput(ABC):
components = self.get_components()
return components.images.shape[2], components.images.shape[1]
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
components = self.get_components()
frame_count = components.images.shape[0]
return float(frame_count / components.frame_rate)


@@ -80,6 +80,38 @@ class VideoFromFile(VideoInput):
return stream.width, stream.height
raise ValueError(f"No video stream found in file '{self.__file}'")
def get_duration(self) -> float:
"""
Returns the duration of the video in seconds.
Returns:
Duration in seconds
"""
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
if container.duration is not None:
return float(container.duration / av.time_base)
# Fallback: calculate from frame count and frame rate
video_stream = next(
(s for s in container.streams if s.type == "video"), None
)
if video_stream and video_stream.frames and video_stream.average_rate:
return float(video_stream.frames / video_stream.average_rate)
# Last resort: decode frames to count them
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
for packet in container.demux(video_stream):
for _ in packet.decode():
frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
raise ValueError(f"Could not determine duration for file '{self.__file}'")
def get_components_internal(self, container: InputContainer) -> VideoComponents:
# Get video frames
frames = []
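The `get_duration` added above prefers container metadata (`container.duration / av.time_base`), falls back to `frames / average_rate` from the stream header, and only decodes packets as a last resort. The same metadata-first strategy, condensed into a standalone PyAV sketch; the file path is a placeholder:

```python
# Sketch: probe a video's duration with PyAV, metadata first, stream header second.
import av

def probe_duration(path: str) -> float:
    with av.open(path, mode="r") as container:
        if container.duration is not None:  # expressed in av.time_base units
            return float(container.duration / av.time_base)
        stream = next((s for s in container.streams if s.type == "video"), None)
        if stream and stream.frames and stream.average_rate:
            return float(stream.frames / stream.average_rate)
        raise ValueError(f"Could not determine duration for '{path}'")

print(probe_duration("clip.mp4"))  # placeholder path
```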


@@ -0,0 +1,5 @@
from .torch_compile import set_torch_compile_wrapper
__all__ = [
"set_torch_compile_wrapper",
]


@@ -0,0 +1,69 @@
from __future__ import annotations
import torch
import comfy.utils
from comfy.patcher_extension import WrappersMP
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.patcher_extension import WrapperExecutor
COMPILE_KEY = "torch.compile"
TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
'''
Create a wrapper that will refer to the compiled_diffusion_model.
'''
def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
try:
orig_modules = {}
for key, value in compiled_module_dict.items():
orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
comfy.utils.set_attr(executor.class_obj, key, value)
return executor(*args, **kwargs)
finally:
for key, value in orig_modules.items():
comfy.utils.set_attr(executor.class_obj, key, value)
return apply_torch_compile_wrapper
def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
keys: list[str]=["diffusion_model"], *args, **kwargs):
'''
Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
When a list of keys is provided, it will perform torch.compile on only the selected modules.
'''
# clear out any other torch.compile wrappers
model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
# if no keys, default to 'diffusion_model'
if not keys:
keys = ["diffusion_model"]
# create kwargs dict that can be referenced later
compile_kwargs = {
"backend": backend,
"options": options,
"mode": mode,
"fullgraph": fullgraph,
"dynamic": dynamic,
}
# get a dict of compiled keys
compiled_modules = {}
for key in keys:
compiled_modules[key] = torch.compile(
model=model.get_model_object(key),
**compile_kwargs,
)
# add torch.compile wrapper
wrapper_func = apply_torch_compile_factory(
compiled_module_dict=compiled_modules,
)
# store wrapper to run on BaseModel's apply_model function
model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
# keep compile kwargs for reference
model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
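A minimal usage sketch for the new helper, assuming `patcher` is an existing `ModelPatcher` (for example, the model output of a loader node). The import path is an assumption based on the `__init__.py` shown above, since file paths are not visible in this view; the backend/mode values are ordinary `torch.compile` arguments:

```python
# Sketch: compile only the diffusion_model of an already-loaded ModelPatcher.
from comfy_api.torch_helpers import set_torch_compile_wrapper  # assumed import path

def compile_diffusion_model(patcher):
    patched = patcher.clone()  # keep the original patcher's wrappers untouched
    set_torch_compile_wrapper(
        model=patched,
        backend="inductor",        # any torch.compile backend
        mode="default",
        fullgraph=False,
        dynamic=None,
        keys=["diffusion_model"],  # compile the whole diffusion_model
    )
    return patched
```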


@@ -18,6 +18,8 @@ Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to
python run main.py --comfy-api-base https://stagingapi.comfy.org
```
To authenticate to staging, please log in and then ask one of the Comfy Org team to whitelist you for access to staging.
API stubs are generated through automatic codegen tools from OpenAPI definitions. Since the Comfy Org OpenAPI definition contains many things from the Comfy Registry as well, we use redocly/cli to filter out only the paths relevant for API nodes.
### Redocly Instructions
@@ -28,7 +30,7 @@ When developing locally, use the `redocly-dev.yaml` file to generate pydantic mo
Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
```bash
# Download the OpenAPI file from prod server.
# Download the OpenAPI file from staging server.
curl -o openapi.yaml https://stagingapi.comfy.org/openapi
# Filter out unneeded API definitions.
@@ -39,3 +41,25 @@ redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_no
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```
# Merging to Master
Before merging to comfyanonymous/ComfyUI master, follow these steps:
1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
1. Make sure the ComfyUI API is deployed to prod with your changes.
1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
```bash
# Download the OpenAPI file from prod server.
curl -o openapi.yaml https://api.comfy.org/openapi
# Filter out unneeded API definitions.
npm install -g @redocly/cli
redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
# Generate the pydantic datamodels for validation.
datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
```


@@ -1,7 +1,8 @@
from __future__ import annotations
import io
import logging
from typing import Optional
import mimetypes
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
@@ -15,6 +16,7 @@ from comfy_api_nodes.apis.client import (
UploadRequest,
UploadResponse,
)
from server import PromptServer
import numpy as np
@@ -60,7 +62,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
return s
def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Args:
@@ -94,6 +98,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
img = Image.open(io.BytesIO(img_data))
elif image_url:
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {image_url}", node_id
)
img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200:
raise ValueError("Failed to download the image")
@@ -207,6 +215,7 @@ def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
image_bytesio = download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)
def process_image_response(response: requests.Response) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response.content))
@@ -311,11 +320,27 @@ def tensor_to_data_uri(
return f"data:{mime_type};base64,{base64_string}"
def text_filepath_to_base64_string(filepath: str) -> str:
"""Converts a text file to a base64 string."""
with open(filepath, "rb") as f:
file_content = f.read()
return base64.b64encode(file_content).decode("utf-8")
def text_filepath_to_data_uri(filepath: str) -> str:
"""Converts a text file to a data URI."""
base64_string = text_filepath_to_base64_string(filepath)
mime_type, _ = mimetypes.guess_type(filepath)
if mime_type is None:
mime_type = "application/octet-stream"
return f"data:{mime_type};base64,{base64_string}"
def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Uploads a single file to ComfyUI API and returns its download URL.
@@ -350,9 +375,33 @@ def upload_file_to_comfyapi(
return response.download_url
def video_to_base64_string(
video: VideoInput,
container_format: VideoContainer = None,
codec: VideoCodec = None
) -> str:
"""
Converts a video input to a base64 string.
Args:
video: The video input to convert
container_format: Optional container format to use (defaults to video.container if available)
codec: Optional codec to use (defaults to video.codec if available)
"""
video_bytes_io = io.BytesIO()
# Use provided format/codec if specified, otherwise use video's own if available
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
video_bytes_io.seek(0)
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
codec: VideoCodec = VideoCodec.H264,
max_duration: Optional[int] = None,
@@ -454,7 +503,7 @@ def audio_ndarray_to_bytesio(
def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
codec_name: str = "aac",
mime_type: str = "audio/mp4",
@@ -481,8 +530,25 @@ def upload_audio_to_comfyapi(
return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
def audio_to_base64_string(
audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
) -> str:
"""Converts an audio input to a base64 string."""
sample_rate: int = audio["sample_rate"]
waveform: torch.Tensor = audio["waveform"]
audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
audio_bytes_io = audio_ndarray_to_bytesio(
audio_data_np, sample_rate, container_format, codec_name
)
audio_bytes = audio_bytes_io.getvalue()
return base64.b64encode(audio_bytes).decode("utf-8")
def upload_images_to_comfyapi(
image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str,str]] = None, mime_type: Optional[str] = None
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
mime_type: Optional[str] = None,
) -> list[str]:
"""
Uploads images to ComfyUI API and returns download URLs.
@@ -547,17 +613,24 @@ def upload_images_to_comfyapi(
return download_urls
def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
upscale_method="nearest-exact", crop="disabled",
allow_gradient=True, add_channel_dim=False):
def resize_mask_to_image(
mask: torch.Tensor,
image: torch.Tensor,
upscale_method="nearest-exact",
crop="disabled",
allow_gradient=True,
add_channel_dim=False,
):
"""
Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
"""
_, H, W, _ = image.shape
mask = mask.unsqueeze(-1)
mask = mask.movedim(-1,1)
mask = common_upscale(mask, width=W, height=H, upscale_method=upscale_method, crop=crop)
mask = mask.movedim(1,-1)
mask = mask.movedim(-1, 1)
mask = common_upscale(
mask, width=W, height=H, upscale_method=upscale_method, crop=crop
)
mask = mask.movedim(1, -1)
if not add_channel_dim:
mask = mask.squeeze(-1)
if not allow_gradient:
@@ -565,12 +638,41 @@ def resize_mask_to_image(mask: torch.Tensor, image: torch.Tensor,
return mask
def validate_string(string: str, strip_whitespace=True, field_name="prompt", min_length=None, max_length=None):
def validate_string(
string: str,
strip_whitespace=True,
field_name="prompt",
min_length=None,
max_length=None,
):
if string is None:
raise Exception(f"Field '{field_name}' cannot be empty.")
if strip_whitespace:
string = string.strip()
if min_length and len(string) < min_length:
raise Exception(f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long.")
raise Exception(
f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
)
if max_length and len(string) > max_length:
raise Exception(f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long.")
if not string:
raise Exception(f"Field '{field_name}' cannot be empty.")
raise Exception(
f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
)
def image_tensor_pair_to_batch(
image1: torch.Tensor, image2: torch.Tensor
) -> torch.Tensor:
"""
Converts a pair of image tensors to a batch tensor.
If the images are not the same size, the smaller image is resized to
match the larger image.
"""
if image1.shape[1:] != image2.shape[1:]:
image2 = common_upscale(
image2.movedim(-1, 1),
image1.shape[2],
image1.shape[1],
"bilinear",
"center",
).movedim(1, -1)
return torch.cat((image1, image2), dim=0)
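ComfyUI image tensors are laid out as `[batch, height, width, channels]` (hence the `_, H, W, _ = image.shape` unpack above), which is why these helpers move the channel dimension before upscaling and move it back afterwards. A self-contained sketch of the same resize-then-concatenate idea using plain PyTorch interpolation in place of `common_upscale`; the shapes are made up:

```python
# Sketch: batch two [B, H, W, C] images, resizing the second to match the first.
import torch
import torch.nn.functional as F

def pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor:
    if image1.shape[1:] != image2.shape[1:]:
        image2 = F.interpolate(
            image2.movedim(-1, 1),                    # [B, C, H, W] for interpolate
            size=(image1.shape[1], image1.shape[2]),  # target (H, W)
            mode="bilinear",
            align_corners=False,
        ).movedim(1, -1)                              # back to [B, H, W, C]
    return torch.cat((image1, image2), dim=0)

print(pair_to_batch(torch.rand(1, 64, 64, 3), torch.rand(1, 32, 48, 3)).shape)  # [2, 64, 64, 3]
```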

File diff suppressed because it is too large.


@@ -108,6 +108,24 @@ class BFLFluxProGenerateRequest(BaseModel):
# )
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you want to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 2 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(


@@ -94,15 +94,19 @@ from __future__ import annotations
import logging
import time
import io
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable
import socket
from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
from enum import Enum
import json
import requests
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs
from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
T = TypeVar("T", bound=BaseModel)
R = TypeVar("R", bound=BaseModel)
@@ -111,6 +115,21 @@ P = TypeVar("P", bound=BaseModel) # For poll response
PROGRESS_BAR_MAX = 100
class NetworkError(Exception):
"""Base exception for network-related errors with diagnostic information."""
pass
class LocalNetworkError(NetworkError):
"""Exception raised when local network connectivity issues are detected."""
pass
class ApiServerError(NetworkError):
"""Exception raised when the API server is unreachable but internet is working."""
pass
class EmptyRequest(BaseModel):
"""Base class for empty request bodies.
For GET requests, fields will be sent as query parameters."""
@@ -120,7 +139,7 @@ class EmptyRequest(BaseModel):
class UploadRequest(BaseModel):
file_name: str = Field(..., description="Filename to upload")
content_type: str | None = Field(
content_type: Optional[str] = Field(
None,
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
)
@@ -141,7 +160,7 @@ class HttpMethod(str, Enum):
class ApiClient:
"""
Client for making HTTP requests to an API with authentication and error handling.
Client for making HTTP requests to an API with authentication, error handling, and retry logic.
"""
def __init__(
@@ -151,12 +170,26 @@ class ApiClient:
comfy_api_key: Optional[str] = None,
timeout: float = 3600.0,
verify_ssl: bool = True,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
retry_status_codes: Optional[Tuple[int, ...]] = None,
):
self.base_url = base_url
self.auth_token = auth_token
self.comfy_api_key = comfy_api_key
self.timeout = timeout
self.verify_ssl = verify_ssl
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
# Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
# 500, 502, 503, 504 (Server Errors)
self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
def _generate_operation_id(self, path: str) -> str:
"""Generates a unique operation ID for logging."""
return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
def _create_json_payload_args(
self,
@@ -211,6 +244,56 @@ class ApiClient:
return headers
def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
"""
Check connectivity to determine if network issues are local or server-related.
Args:
target_url: URL to check connectivity to
Returns:
Dictionary with connectivity status details
"""
results = {
"internet_accessible": False,
"api_accessible": False,
"is_local_issue": False,
"is_api_issue": False
}
# First check basic internet connectivity using a reliable external site
try:
# Use a reliable external domain for checking basic connectivity
check_response = requests.get("https://www.google.com",
timeout=5.0,
verify=self.verify_ssl)
if check_response.status_code < 500:
results["internet_accessible"] = True
except (requests.RequestException, socket.error):
results["internet_accessible"] = False
results["is_local_issue"] = True
return results
# Now check API server connectivity
try:
# Extract domain from the target URL to do a simpler health check
parsed_url = urlparse(target_url)
api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"
# Try to reach the API domain
api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl)
if api_response.status_code < 500:
results["api_accessible"] = True
else:
results["api_accessible"] = False
results["is_api_issue"] = True
except requests.RequestException:
results["api_accessible"] = False
# If we can reach the internet but not the API, it's an API issue
results["is_api_issue"] = True
return results
def request(
self,
method: str,
@@ -221,9 +304,10 @@ class ApiClient:
headers: Optional[Dict[str, str]] = None,
content_type: str = "application/json",
multipart_parser: Callable = None,
retry_count: int = 0, # Used internally for tracking retries
) -> Dict[str, Any]:
"""
Make an HTTP request to the API
Make an HTTP request to the API with automatic retries for transient errors.
Args:
method: HTTP method (GET, POST, etc.)
@@ -233,14 +317,19 @@ class ApiClient:
files: Files to upload
headers: Additional headers
content_type: Content type of the request. Defaults to application/json.
retry_count: Internal parameter for tracking retries, do not set manually
Returns:
Parsed JSON response
Raises:
requests.RequestException: If the request fails
LocalNetworkError: If local network connectivity issues are detected
ApiServerError: If the API server is unreachable but internet is working
Exception: For other request failures
"""
url = urljoin(self.base_url, path)
# Use urljoin but ensure path is relative to avoid absolute path behavior
relative_path = path.lstrip('/')
url = urljoin(self.base_url, relative_path)
self.check_auth(self.auth_token, self.comfy_api_key)
# Combine default headers with any provided headers
request_headers = self.get_headers()
@@ -265,6 +354,16 @@ class ApiClient:
else:
payload_args = self._create_json_payload_args(data, request_headers)
operation_id = self._generate_operation_id(path)
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
request_headers=request_headers,
request_params=params,
request_data=data if content_type == "application/json" else "[form-data or other]"
)
try:
response = requests.request(
method=method,
@@ -275,50 +374,228 @@ class ApiClient:
**payload_args,
)
# Check if we should retry based on status code
if (response.status_code in self.retry_status_codes and
retry_count < self.max_retries):
# Calculate delay with exponential backoff
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Request failed with status {response.status_code}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# Raise exception for error status codes
response.raise_for_status()
except requests.ConnectionError:
raise Exception(
f"Unable to connect to the API server at {self.base_url}. Please check your internet connection or verify the service is available."
# Log successful response
response_content_to_log = response.content
try:
# Attempt to parse JSON for prettier logging, fallback to raw content
response_content_to_log = response.json()
except json.JSONDecodeError:
pass # Keep as bytes/str if not JSON
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, # Pass request details again for context in log
request_url=url,
response_status_code=response.status_code,
response_headers=dict(response.headers),
response_content=response_content_to_log
)
except requests.Timeout:
raise Exception(
f"Request timed out after {self.timeout} seconds. The server might be experiencing high load or the operation is taking longer than expected."
except requests.ConnectionError as e:
error_message = f"ConnectionError: {str(e)}"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
request_url=url,
error_message=error_message
)
# Only perform connectivity check if we've exhausted all retries
if retry_count >= self.max_retries:
# Check connectivity to determine if it's a local or API issue
connectivity = self._check_connectivity(self.base_url)
if connectivity["is_local_issue"]:
raise LocalNetworkError(
"Unable to connect to the API server due to local network issues. "
"Please check your internet connection and try again."
) from e
elif connectivity["is_api_issue"]:
raise ApiServerError(
f"The API server at {self.base_url} is currently unreachable. "
f"The service may be experiencing issues. Please try again later."
) from e
# If we haven't exhausted retries yet, retry the request
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Connection error: {str(e)}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
# If we've exhausted retries and didn't identify the specific issue,
# raise a generic exception
final_error_message = (
f"Unable to connect to the API server after {self.max_retries} attempts. "
f"Please check your internet connection or try again later."
)
request_logger.log_request_response( # Log final failure
operation_id=operation_id,
request_method=method, request_url=url,
error_message=final_error_message
)
raise Exception(final_error_message) from e
except requests.Timeout as e:
error_message = f"Timeout: {str(e)}"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, request_url=url,
error_message=error_message
)
# Retry timeouts if we haven't exhausted retries
if retry_count < self.max_retries:
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"Request timed out. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
final_error_message = (
f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
f"The server might be experiencing high load or the operation is taking longer than expected."
)
request_logger.log_request_response( # Log final failure
operation_id=operation_id,
request_method=method, request_url=url,
error_message=final_error_message
)
raise Exception(final_error_message) from e
except requests.HTTPError as e:
status_code = e.response.status_code if hasattr(e, "response") else None
error_message = f"HTTP Error: {str(e)}"
original_error_message = f"HTTP Error: {str(e)}"
error_content_for_log = None
if hasattr(e, "response") and e.response is not None:
error_content_for_log = e.response.content
try:
error_content_for_log = e.response.json()
except json.JSONDecodeError:
pass
# Try to extract detailed error message from JSON response for user display
# but log the full error content.
user_display_error_message = original_error_message
# Try to extract detailed error message from JSON response
try:
if hasattr(e, "response") and e.response.content:
if hasattr(e, "response") and e.response is not None and e.response.content:
error_json = e.response.json()
if "error" in error_json and "message" in error_json["error"]:
error_message = f"API Error: {error_json['error']['message']}"
user_display_error_message = f"API Error: {error_json['error']['message']}"
if "type" in error_json["error"]:
error_message += f" (Type: {error_json['error']['type']})"
user_display_error_message += f" (Type: {error_json['error']['type']})"
elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict
user_display_error_message = f"API Error: {json.dumps(error_json)}"
else: # Non-dict JSON error
user_display_error_message = f"API Error: {str(error_json)}"
except json.JSONDecodeError:
# If not JSON, use the raw content if it's not too long, or a summary
if hasattr(e, "response") and e.response is not None and e.response.content:
raw_content = e.response.content.decode(errors='ignore')
if len(raw_content) < 200: # Arbitrary limit for display
user_display_error_message = f"API Error (raw): {raw_content}"
else:
error_message = f"API Error: {error_json}"
except Exception as json_error:
# If we can't parse the JSON, fall back to the original error message
logging.debug(
f"[DEBUG] Failed to parse error response: {str(json_error)}"
user_display_error_message = f"API Error (raw, status {status_code})"
request_logger.log_request_response(
operation_id=operation_id,
request_method=method, request_url=url,
response_status_code=status_code,
response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None,
response_content=error_content_for_log,
error_message=original_error_message # Log the original exception string as error
)
logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response is not None and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}")
# Retry if the status code is in our retry list and we haven't exhausted retries
if (status_code in self.retry_status_codes and
retry_count < self.max_retries):
delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
logging.warning(
f"HTTP error {status_code}. "
f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
)
time.sleep(delay)
return self.request(
method=method,
path=path,
params=params,
data=data,
files=files,
headers=headers,
content_type=content_type,
multipart_parser=multipart_parser,
retry_count=retry_count + 1,
)
logging.debug(f"[DEBUG] API Error: {error_message} (Status: {status_code})")
if hasattr(e, "response") and e.response.content:
logging.debug(f"[DEBUG] Response content: {e.response.content}")
# Specific error messages for common status codes for user display
if status_code == 401:
error_message = "Unauthorized: Please login first to use this node."
if status_code == 402:
error_message = "Payment Required: Please add credits to your account to use this node."
if status_code == 409:
error_message = "There is a problem with your account. Please contact support@comfy.org. "
if status_code == 429:
error_message = "Rate Limit Exceeded: Please try again later."
raise Exception(error_message)
user_display_error_message = "Unauthorized: Please login first to use this node."
elif status_code == 402:
user_display_error_message = "Payment Required: Please add credits to your account to use this node."
elif status_code == 409:
user_display_error_message = "There is a problem with your account. Please contact support@comfy.org."
elif status_code == 429:
user_display_error_message = "Rate Limit Exceeded: Please try again later."
# else, user_display_error_message remains as parsed from response or original HTTPError string
raise Exception(user_display_error_message) # Raise with the user-friendly message
# Parse and return JSON response
if response.content:
@@ -336,26 +613,126 @@ class ApiClient:
upload_url: str,
file: io.BytesIO | str,
content_type: str | None = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
):
"""Upload a file to the API. Make sure the file has a filename equal to what the url expects.
"""Upload a file to the API with retry logic.
Args:
upload_url: The URL to upload to
file: Either a file path string, BytesIO object, or tuple of (file_path, filename)
mime_type: Optional mime type to set for the upload
content_type: Optional mime type to set for the upload
max_retries: Maximum number of retry attempts
retry_delay: Initial delay between retries in seconds
retry_backoff_factor: Multiplier for the delay after each retry
"""
headers = {}
if content_type:
headers["Content-Type"] = content_type
# Prepare the file data
if isinstance(file, io.BytesIO):
file.seek(0) # Ensure we're at the start of the file
data = file.read()
return requests.put(upload_url, data=data, headers=headers)
elif isinstance(file, str):
with open(file, "rb") as f:
data = f.read()
return requests.put(upload_url, data=data, headers=headers)
else:
raise ValueError("File must be either a BytesIO object or a file path string")
# Try the upload with retries
last_exception = None
operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads
# Log initial attempt (without full file data for brevity)
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT",
request_url=upload_url,
request_headers=headers,
request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]"
)
for retry_attempt in range(max_retries + 1):
try:
response = requests.put(upload_url, data=data, headers=headers)
response.raise_for_status()
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT", request_url=upload_url, # For context
response_status_code=response.status_code,
response_headers=dict(response.headers),
response_content="File uploaded successfully." # Or response.text if available
)
return response
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
last_exception = e
error_message_for_log = f"{type(e).__name__}: {str(e)}"
response_content_for_log = None
status_code_for_log = None
headers_for_log = None
if hasattr(e, 'response') and e.response is not None:
status_code_for_log = e.response.status_code
headers_for_log = dict(e.response.headers)
try:
response_content_for_log = e.response.json()
except json.JSONDecodeError:
response_content_for_log = e.response.content
request_logger.log_request_response(
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
response_status_code=status_code_for_log,
response_headers=headers_for_log,
response_content=response_content_for_log,
error_message=error_message_for_log
)
if retry_attempt < max_retries:
delay = retry_delay * (retry_backoff_factor ** retry_attempt)
logging.warning(
f"File upload failed: {str(e)}. "
f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
)
time.sleep(delay)
else:
break # Max retries reached
# If we've exhausted all retries, determine the final error type and raise
final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
try:
# Check basic internet connectivity
check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired
if check_response.status_code >= 500: # Google itself has an issue (rare)
final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed "
f"(status {check_response.status_code}). Original error: {str(last_exception)}")
# Not raising LocalNetworkError here as Google itself might be down.
# If Google is reachable, the issue is likely with the upload server or a more specific local problem
# not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall).
# The original last_exception is probably most relevant.
except (requests.RequestException, socket.error) as conn_check_exc:
# Could not reach Google, likely a local network issue
final_error_message = (f"Failed to upload file due to network connectivity issues "
f"(cannot reach Google: {str(conn_check_exc)}). "
f"Original upload error: {str(last_exception)}")
request_logger.log_request_response( # Log final failure reason
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
error_message=final_error_message
)
raise LocalNetworkError(final_error_message) from last_exception
request_logger.log_request_response( # Log final failure reason if not LocalNetworkError
operation_id=operation_id,
request_method="PUT", request_url=upload_url,
error_message=final_error_message
)
raise Exception(final_error_message) from last_exception
class ApiEndpoint(Generic[T, R]):
@@ -403,6 +780,9 @@ class SynchronousOperation(Generic[T, R]):
verify_ssl: bool = True,
content_type: str = "application/json",
multipart_parser: Callable = None,
max_retries: int = 3,
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
):
self.endpoint = endpoint
self.request = request
@@ -419,8 +799,12 @@ class SynchronousOperation(Generic[T, R]):
self.files = files
self.content_type = content_type
self.multipart_parser = multipart_parser
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
def execute(self, client: Optional[ApiClient] = None) -> R:
"""Execute the API operation using the provided client or create one"""
"""Execute the API operation using the provided client or create one with retry support"""
try:
# Create client if not provided
if client is None:
@@ -430,6 +814,9 @@ class SynchronousOperation(Generic[T, R]):
comfy_api_key=self.comfy_api_key,
timeout=self.timeout,
verify_ssl=self.verify_ssl,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
# Convert request model to dict, but use None for EmptyRequest
@@ -443,11 +830,6 @@ class SynchronousOperation(Generic[T, R]):
if isinstance(value, Enum):
request_dict[key] = value.value
if request_dict:
for key, value in request_dict.items():
if isinstance(value, Enum):
request_dict[key] = value.value
# Debug log for request
logging.debug(
f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
@@ -455,7 +837,7 @@ class SynchronousOperation(Generic[T, R]):
logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
# Make the request
# Make the request with built-in retry
resp = client.request(
method=self.endpoint.method.value,
path=self.endpoint.path,
@@ -476,8 +858,18 @@ class SynchronousOperation(Generic[T, R]):
# Parse and return the response
return self._parse_response(resp)
except LocalNetworkError as e:
# Propagate specific network error types
logging.error(f"[ERROR] Local network error: {str(e)}")
raise
except ApiServerError as e:
# Propagate API server errors
logging.error(f"[ERROR] API server error: {str(e)}")
raise
except Exception as e:
logging.error(f"[DEBUG] API Exception: {str(e)}")
logging.error(f"[ERROR] API Exception: {str(e)}")
raise Exception(str(e))
def _parse_response(self, resp):
@@ -511,12 +903,19 @@ class PollingOperation(Generic[T, R]):
failed_statuses: list,
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] = None,
result_url_extractor: Callable[[R], str] = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
comfy_api_key: Optional[str] = None,
auth_kwargs: Optional[Dict[str,str]] = None,
poll_interval: float = 5.0,
max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
):
self.poll_endpoint = poll_endpoint
self.request = request
@@ -527,12 +926,19 @@ class PollingOperation(Generic[T, R]):
self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
self.poll_interval = poll_interval
self.max_poll_attempts = max_poll_attempts
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration
# Polling configuration
self.status_extractor = status_extractor or (
lambda x: getattr(x, "status", None)
)
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses
@@ -548,11 +954,46 @@ class PollingOperation(Generic[T, R]):
base_url=self.api_base,
auth_token=self.auth_token,
comfy_api_key=self.comfy_api_key,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
retry_backoff_factor=self.retry_backoff_factor,
)
return self._poll_until_complete(client)
except LocalNetworkError as e:
# Provide clear message for local network issues
raise Exception(
f"Polling failed due to local network issues. Please check your internet connection. "
f"Details: {str(e)}"
) from e
except ApiServerError as e:
# Provide clear message for API server issues
raise Exception(
f"Polling failed due to API server issues. The service may be experiencing problems. "
f"Please try again later. Details: {str(e)}"
) from e
except Exception as e:
raise Exception(f"Error during polling: {str(e)}")
def _display_text_on_node(self, text: str):
"""Sends text to the client which will be displayed on the node in the UI"""
if not self.node_id:
return
PromptServer.instance.send_progress_text(text, self.node_id)
def _display_time_progress_on_node(self, time_completed: int):
if not self.node_id:
return
if self.estimated_duration is not None:
estimated_time_remaining = max(
0, int(self.estimated_duration) - int(time_completed)
)
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
else:
message = f"Task in progress: {time_completed:.0f}s"
self._display_text_on_node(message)
def _check_task_status(self, response: R) -> TaskStatus:
"""Check task status using the status extractor function"""
try:
@@ -569,10 +1010,13 @@ class PollingOperation(Generic[T, R]):
def _poll_until_complete(self, client: ApiClient) -> R:
"""Poll until the task is complete"""
poll_count = 0
consecutive_errors = 0
max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
if self.progress_extractor:
progress = utils.ProgressBar(PROGRESS_BAR_MAX)
while True:
while poll_count < self.max_poll_attempts:
try:
poll_count += 1
logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
@@ -599,8 +1043,12 @@ class PollingOperation(Generic[T, R]):
data=request_dict,
)
# Successfully got a response, reset consecutive error count
consecutive_errors = 0
# Parse response
response_obj = self.poll_endpoint.response_model.model_validate(resp)
# Check if task is complete
status = self._check_task_status(response_obj)
logging.debug(f"[DEBUG] Task Status: {status}")
@@ -612,7 +1060,15 @@ class PollingOperation(Generic[T, R]):
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
if status == TaskStatus.COMPLETED:
logging.debug("[DEBUG] Task completed successfully")
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
else:
message = "Task completed successfully!"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
@@ -628,8 +1084,43 @@ class PollingOperation(Generic[T, R]):
logging.debug(
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
)
for i in range(int(self.poll_interval)):
time_completed = (poll_count * self.poll_interval) + i
self._display_time_progress_on_node(time_completed)
time.sleep(1)
except (LocalNetworkError, ApiServerError) as e:
# For network-related errors, increment error count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
) from e
# Log the error but continue polling
logging.warning(
f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
time.sleep(self.poll_interval)
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
logging.error(f"[DEBUG] Polling error: {str(e)}")
raise Exception(f"Error while polling: {str(e)}")
logging.warning(
f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
f"Will retry in {self.poll_interval} seconds."
)
time.sleep(self.poll_interval)
# If we've exhausted all polling attempts
raise Exception(
f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
f"The operation may still be running on the server but is taking longer than expected."
)
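The retry behaviour added above is plain exponential backoff gated on a small set of status codes: with the defaults (`retry_delay=1.0`, `retry_backoff_factor=2.0`, `max_retries=3`) a failing call waits roughly 1 s, 2 s, then 4 s before giving up. A compact sketch of that schedule, independent of the `ApiClient` class; the URL is a placeholder:

```python
# Sketch: exponential backoff over retryable HTTP status codes.
import time
import requests

RETRY_STATUS_CODES = (408, 429, 500, 502, 503, 504)  # same defaults as above

def get_with_backoff(url: str, max_retries: int = 3, retry_delay: float = 1.0,
                     backoff_factor: float = 2.0) -> requests.Response:
    for attempt in range(max_retries + 1):
        response = requests.get(url, timeout=30)
        if response.status_code not in RETRY_STATUS_CODES or attempt == max_retries:
            response.raise_for_status()  # raise on the final failure, return otherwise
            return response
        delay = retry_delay * (backoff_factor ** attempt)  # 1s, 2s, 4s, ...
        time.sleep(delay)

# get_with_backoff("https://api.example.com/health")  # placeholder endpoint
```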


@@ -0,0 +1,125 @@
import os
import datetime
import json
import logging
import folder_paths
# Get the logger instance
logger = logging.getLogger(__name__)
def get_log_directory():
"""
Ensures the API log directory exists within ComfyUI's temp directory
and returns its path.
"""
base_temp_dir = folder_paths.get_temp_directory()
log_dir = os.path.join(base_temp_dir, "api_logs")
try:
os.makedirs(log_dir, exist_ok=True)
except Exception as e:
logger.error(f"Error creating API log directory {log_dir}: {e}")
# Fallback to base temp directory if sub-directory creation fails
return base_temp_dir
return log_dir
def _format_data_for_logging(data):
"""Helper to format data (dict, str, bytes) for logging."""
if isinstance(data, bytes):
try:
return data.decode('utf-8') # Try to decode as text
except UnicodeDecodeError:
return f"[Binary data of length {len(data)} bytes]"
elif isinstance(data, (dict, list)):
try:
return json.dumps(data, indent=2, ensure_ascii=False)
except TypeError:
return str(data) # Fallback for non-serializable objects
return str(data)
def log_request_response(
operation_id: str,
request_method: str,
request_url: str,
request_headers: dict | None = None,
request_params: dict | None = None,
request_data: any = None,
response_status_code: int | None = None,
response_headers: dict | None = None,
response_content: any = None,
error_message: str | None = None
):
"""
Logs API request and response details to a file in the temp/api_logs directory.
"""
log_dir = get_log_directory()
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
filepath = os.path.join(log_dir, filename)
log_content = []
log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
log_content.append(f"Operation ID: {operation_id}")
log_content.append("-" * 30 + " REQUEST " + "-" * 30)
log_content.append(f"Method: {request_method}")
log_content.append(f"URL: {request_url}")
if request_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
if request_params:
log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
if request_data:
log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
if response_status_code is not None:
log_content.append(f"Status Code: {response_status_code}")
if response_headers:
log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
if response_content:
log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
if error_message:
log_content.append(f"Error:\n{error_message}")
try:
with open(filepath, "w", encoding="utf-8") as f:
f.write("\n".join(log_content))
logger.debug(f"API log saved to: {filepath}")
except Exception as e:
logger.error(f"Error writing API log to {filepath}: {e}")
if __name__ == '__main__':
# Example usage (for testing the logger directly)
logger.setLevel(logging.DEBUG)
# Mock folder_paths for direct execution if not running within ComfyUI full context
if not hasattr(folder_paths, 'get_temp_directory'):
class MockFolderPaths:
def get_temp_directory(self):
# Create a local temp dir for testing if needed
p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
os.makedirs(p, exist_ok=True)
return p
folder_paths = MockFolderPaths()
log_request_response(
operation_id="test_operation_get",
request_method="GET",
request_url="https://api.example.com/test",
request_headers={"Authorization": "Bearer testtoken"},
request_params={"param1": "value1"},
response_status_code=200,
response_content={"message": "Success!"}
)
log_request_response(
operation_id="test_operation_post_error",
request_method="POST",
request_url="https://api.example.com/submit",
request_data={"key": "value", "nested": {"num": 123}},
error_message="Connection timed out"
)
log_request_response(
operation_id="test_binary_response",
request_method="GET",
request_url="https://api.example.com/image.png",
response_status_code=200,
response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
)


@@ -0,0 +1,57 @@
from __future__ import annotations
from enum import Enum
from typing import Optional, List
from pydantic import BaseModel, Field
class Rodin3DGenerateRequest(BaseModel):
seed: int = Field(..., description="seed_")
tier: str = Field(..., description="Tier of generation.")
material: str = Field(..., description="The material type.")
quality: str = Field(..., description="The generation quality of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="str LIST")
subscription_key: str = Field(..., description="subscription key")
class Rodin3DGenerateResponse(BaseModel):
message: Optional[str] = Field(None, description="Return message.")
prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
submit_time: Optional[str] = Field(None, description="Submit Time")
uuid: Optional[str] = Field(None, description="Task str")
jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
class JobStatus(str, Enum):
"""
Status for jobs
"""
Done = "Done"
Failed = "Failed"
Generating = "Generating"
Waiting = "Waiting"
class Rodin3DCheckStatusRequest(BaseModel):
subscription_key: str = Field(..., description="subscription from generate endpoint")
class JobItem(BaseModel):
uuid: str = Field(..., description="uuid")
status: JobStatus = Field(...,description="Status Currently")
class Rodin3DCheckStatusResponse(BaseModel):
jobs: List[JobItem] = Field(..., description="Job status List")
class Rodin3DDownloadRequest(BaseModel):
task_uuid: str = Field(..., description="Task str")
class RodinResourceItem(BaseModel):
url: str = Field(..., description="Download Url")
name: str = Field(..., description="File name with ext")
class Rodin3DDownloadResponse(BaseModel):
list: List[RodinResourceItem] = Field(..., description="Source List")
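These are plain Pydantic v2 models, so responses from the Rodin endpoints can be validated directly from decoded JSON. A hedged sketch, assuming the classes above are in scope; the payload is invented, not a real API response:

```python
# Sketch: validate an invented status payload against the models above.
payload = {
    "jobs": [
        {"uuid": "job-1", "status": "Done"},
        {"uuid": "job-2", "status": "Generating"},
    ]
}

status = Rodin3DCheckStatusResponse.model_validate(payload)
all_done = all(job.status == JobStatus.Done for job in status.jobs)
print(all_done)  # False while job-2 is still generating
```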


@@ -0,0 +1,275 @@
from __future__ import annotations
from comfy_api_nodes.apis import (
TripoModelVersion,
TripoTextureQuality,
)
from enum import Enum
from typing import Optional, List, Dict, Any, Union
from pydantic import BaseModel, Field, RootModel
class TripoStyle(str, Enum):
PERSON_TO_CARTOON = "person:person2cartoon"
ANIMAL_VENOM = "animal:venom"
OBJECT_CLAY = "object:clay"
OBJECT_STEAMPUNK = "object:steampunk"
OBJECT_CHRISTMAS = "object:christmas"
OBJECT_BARBIE = "object:barbie"
GOLD = "gold"
ANCIENT_BRONZE = "ancient_bronze"
NONE = "None"
class TripoTaskType(str, Enum):
TEXT_TO_MODEL = "text_to_model"
IMAGE_TO_MODEL = "image_to_model"
MULTIVIEW_TO_MODEL = "multiview_to_model"
TEXTURE_MODEL = "texture_model"
REFINE_MODEL = "refine_model"
ANIMATE_PRERIGCHECK = "animate_prerigcheck"
ANIMATE_RIG = "animate_rig"
ANIMATE_RETARGET = "animate_retarget"
STYLIZE_MODEL = "stylize_model"
CONVERT_MODEL = "convert_model"
class TripoTextureAlignment(str, Enum):
ORIGINAL_IMAGE = "original_image"
GEOMETRY = "geometry"
class TripoOrientation(str, Enum):
ALIGN_IMAGE = "align_image"
DEFAULT = "default"
class TripoOutFormat(str, Enum):
GLB = "glb"
FBX = "fbx"
class TripoTopology(str, Enum):
BIP = "bip"
QUAD = "quad"
class TripoSpec(str, Enum):
MIXAMO = "mixamo"
TRIPO = "tripo"
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
CLIMB = "preset:climb"
JUMP = "preset:jump"
RUN = "preset:run"
SLASH = "preset:slash"
SHOOT = "preset:shoot"
HURT = "preset:hurt"
FALL = "preset:fall"
TURN = "preset:turn"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
VOXEL = "voxel"
VORONOI = "voronoi"
MINECRAFT = "minecraft"
class TripoConvertFormat(str, Enum):
GLTF = "GLTF"
USDZ = "USDZ"
FBX = "FBX"
OBJ = "OBJ"
STL = "STL"
_3MF = "3MF"
class TripoTextureFormat(str, Enum):
BMP = "BMP"
DPX = "DPX"
HDR = "HDR"
JPEG = "JPEG"
OPEN_EXR = "OPEN_EXR"
PNG = "PNG"
TARGA = "TARGA"
TIFF = "TIFF"
WEBP = "WEBP"
class TripoTaskStatus(str, Enum):
QUEUED = "queued"
RUNNING = "running"
SUCCESS = "success"
FAILED = "failed"
CANCELLED = "cancelled"
UNKNOWN = "unknown"
BANNED = "banned"
EXPIRED = "expired"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
file_token: str
class TripoUrlReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
url: str
class TripoObjectStorage(BaseModel):
bucket: str
key: str
class TripoObjectReference(BaseModel):
type: str
object: TripoObjectStorage
class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
image_seed: Optional[int] = Field(None, description='The seed for the text')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoImageToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
class TripoRefineModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
class TripoAnimatePrerigcheckRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
class TripoAnimateRigRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
class TripoAnimateRetargetRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
animation: TripoAnimation = Field(..., description='The animation to apply')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
class TripoStylizeModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
original_model_task_id: str = Field(..., description='The task ID of the original model')
block_size: Optional[int] = Field(80, description='The block size for stylization')
class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(4096, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
class TripoTaskRequest(RootModel):
root: Union[
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimatePrerigcheckRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoStylizeModelRequest,
TripoConvertModelRequest
]
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoTask = Field(..., description='The task data')
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: Dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')
frozen: float = Field(..., description='The frozen balance')
class TripoBalanceResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoBalanceData = Field(..., description='The balance data')
class TripoErrorResponse(BaseModel):
code: int = Field(..., description='The error code')
message: str = Field(..., description='The error message')
suggestion: str = Field(..., description='The suggestion for fixing the error')
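Aside: a minimal sketch (not from the changeset) of how these request models compose, assuming pydantic v2 semantics for RootModel; the prompt and style values are arbitrary examples.

# Illustrative only: build a text-to-model request and wrap it in the task union.
example_task = TripoTaskRequest(
    root=TripoTextToModelRequest(
        prompt="a small wooden rowing boat",
        texture=True,
        pbr=True,
        texture_quality=TripoTextureQuality.standard,
        style=TripoStyle.OBJECT_CLAY,
    )
)
# exclude_none drops unset optional fields from the payload.
payload = example_task.model_dump(exclude_none=True)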

View File

@@ -1,5 +1,6 @@
import io
from inspect import cleandoc
from typing import Union, Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
@@ -8,6 +9,7 @@ from comfy_api_nodes.apis.bfl_api import (
BFLFluxCannyImageRequest,
BFLFluxDepthImageRequest,
BFLFluxProGenerateRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProUltraGenerateRequest,
BFLFluxProGenerateResponse,
)
@@ -30,6 +32,7 @@ import requests
import torch
import base64
import time
from server import PromptServer
def convert_mask_to_image(mask: torch.Tensor):
@@ -42,14 +45,19 @@ def convert_mask_to_image(mask: torch.Tensor):
def handle_bfl_synchronous_operation(
operation: SynchronousOperation, timeout_bfl_calls=360
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = operation.execute()
return _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)
def _poll_until_generated(polling_url: str, timeout=360):
def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
@@ -61,11 +69,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
request = requests.Request(method=HttpMethod.GET, url=polling_url)
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)
response = requests.Session().send(request.prepare())
if response.status_code == 200:
result = response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
img_response = requests.get(img_url)
return process_image_response(img_response)
elif result["status"] in [
@@ -180,6 +198,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -212,6 +231,7 @@ class FluxProUltraImageNode(ComfyNodeABC):
seed=0,
image_prompt=None,
image_prompt_strength=0.1,
unique_id: Union[str, None] = None,
**kwargs,
):
if image_prompt is None:
@@ -246,10 +266,162 @@ class FluxProUltraImageNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
class FluxKontextProImageNode(ComfyNodeABC):
"""
Edits images using Flux.1 Kontext [pro] via API, based on prompt and aspect ratio.
"""
MINIMUM_RATIO = 1 / 4
MAXIMUM_RATIO = 4 / 1
MINIMUM_RATIO_STR = "1:4"
MAXIMUM_RATIO_STR = "4:1"
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Prompt for the image generation - specify what and how to edit.",
},
),
"aspect_ratio": (
IO.STRING,
{
"default": "16:9",
"tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
},
),
"guidance": (
IO.FLOAT,
{
"default": 3.0,
"min": 0.1,
"max": 99.0,
"step": 0.1,
"tooltip": "Guidance strength for the image generation process"
},
),
"steps": (
IO.INT,
{
"default": 50,
"min": 1,
"max": 150,
"tooltip": "Number of steps for the image generation process"
},
),
"seed": (
IO.INT,
{
"default": 1234,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "The random seed used for creating the noise.",
},
),
"prompt_upsampling": (
IO.BOOLEAN,
{
"default": False,
"tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
},
),
},
"optional": {
"input_image": (IO.IMAGE,),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
def VALIDATE_INPUTS(cls, aspect_ratio: str):
try:
validate_aspect_ratio(
aspect_ratio,
minimum_ratio=cls.MINIMUM_RATIO,
maximum_ratio=cls.MAXIMUM_RATIO,
minimum_ratio_str=cls.MINIMUM_RATIO_STR,
maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
)
except Exception as e:
return str(e)
return True
RETURN_TYPES = (IO.IMAGE,)
DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
FUNCTION = "api_call"
API_NODE = True
CATEGORY = "api node/image/BFL"
BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate"
def api_call(
self,
prompt: str,
aspect_ratio: str,
guidance: float,
steps: int,
input_image: Optional[torch.Tensor]=None,
seed=0,
prompt_upsampling=False,
unique_id: Union[str, None] = None,
**kwargs,
):
if input_image is None:
validate_string(prompt, strip_whitespace=False)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=self.BFL_PATH,
method=HttpMethod.POST,
request_model=BFLFluxKontextProGenerateRequest,
response_model=BFLFluxProGenerateResponse,
),
request=BFLFluxKontextProGenerateRequest(
prompt=prompt,
prompt_upsampling=prompt_upsampling,
guidance=round(guidance, 1),
steps=steps,
seed=seed,
aspect_ratio=validate_aspect_ratio(
aspect_ratio,
minimum_ratio=self.MINIMUM_RATIO,
maximum_ratio=self.MAXIMUM_RATIO,
minimum_ratio_str=self.MINIMUM_RATIO_STR,
maximum_ratio_str=self.MAXIMUM_RATIO_STR,
),
input_image=(
input_image
if input_image is None
else convert_image_to_base64(input_image)
)
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
class FluxKontextMaxImageNode(FluxKontextProImageNode):
"""
Edits images using Flux.1 Kontext [max] via API, based on prompt and aspect ratio.
"""
DESCRIPTION = cleandoc(__doc__ or "")
BFL_PATH = "/proxy/bfl/flux-kontext-max/generate"
class FluxProImageNode(ComfyNodeABC):
"""
@@ -320,6 +492,7 @@ class FluxProImageNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -338,6 +511,7 @@ class FluxProImageNode(ComfyNodeABC):
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
unique_id: Union[str, None] = None,
**kwargs,
):
image_prompt = (
@@ -363,7 +537,7 @@ class FluxProImageNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -457,11 +631,11 @@ class FluxProExpandNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -483,6 +657,7 @@ class FluxProExpandNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
image = convert_image_to_base64(image)
@@ -508,7 +683,7 @@ class FluxProExpandNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -568,11 +743,11 @@ class FluxProFillNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -591,13 +766,14 @@ class FluxProFillNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
# prepare mask
mask = resize_mask_to_image(mask, image)
mask = convert_image_to_base64(convert_mask_to_image(mask))
# make sure image will have alpha channel removed
image = convert_image_to_base64(image[:,:,:,:3])
image = convert_image_to_base64(image[:, :, :, :3])
operation = SynchronousOperation(
endpoint=ApiEndpoint(
@@ -617,7 +793,7 @@ class FluxProFillNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -702,11 +878,11 @@ class FluxProCannyNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -727,9 +903,10 @@ class FluxProCannyNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
control_image = convert_image_to_base64(control_image[:, :, :, :3])
preprocessed_image = None
# scale canny threshold between 0-500, to match BFL's API
@@ -765,7 +942,7 @@ class FluxProCannyNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -830,11 +1007,11 @@ class FluxProDepthNode(ComfyNodeABC):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -853,6 +1030,7 @@ class FluxProDepthNode(ComfyNodeABC):
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
@@ -880,7 +1058,7 @@ class FluxProDepthNode(ComfyNodeABC):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)
@@ -889,6 +1067,8 @@ class FluxProDepthNode(ComfyNodeABC):
NODE_CLASS_MAPPINGS = {
"FluxProUltraImageNode": FluxProUltraImageNode,
# "FluxProImageNode": FluxProImageNode,
"FluxKontextProImageNode": FluxKontextProImageNode,
"FluxKontextMaxImageNode": FluxKontextMaxImageNode,
"FluxProExpandNode": FluxProExpandNode,
"FluxProFillNode": FluxProFillNode,
"FluxProCannyNode": FluxProCannyNode,
@@ -899,6 +1079,8 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
"FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
# "FluxProImageNode": "Flux 1.1 [pro] Image",
"FluxKontextProImageNode": "Flux.1 Kontext [pro] Image",
"FluxKontextMaxImageNode": "Flux.1 Kontext [max] Image",
"FluxProExpandNode": "Flux.1 Expand Image",
"FluxProFillNode": "Flux.1 Fill Image",
"FluxProCannyNode": "Flux.1 Canny Control Image",

View File

@@ -0,0 +1,446 @@
"""
API Nodes for Gemini Multimodal LLM Usage via Remote API
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
import os
from enum import Enum
from typing import Optional, Literal
import torch
import folder_paths
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
from comfy_api_nodes.apis import (
GeminiContent,
GeminiGenerateContentRequest,
GeminiGenerateContentResponse,
GeminiInlineData,
GeminiPart,
GeminiMimeType,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
)
from comfy_api_nodes.apinode_utils import (
validate_string,
audio_to_base64_string,
video_to_base64_string,
tensor_to_base64_string,
)
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
class GeminiModel(str, Enum):
"""
Gemini Model Names allowed by comfy-api
"""
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
def get_gemini_endpoint(
model: GeminiModel,
) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
"""
Get the API endpoint for a given Gemini model.
Args:
model: The Gemini model to use, either as enum or string value.
Returns:
ApiEndpoint configured for the specific Gemini model.
"""
if isinstance(model, str):
model = GeminiModel(model)
return ApiEndpoint(
path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
method=HttpMethod.POST,
request_model=GeminiGenerateContentRequest,
response_model=GeminiGenerateContentResponse,
)
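Aside: a brief usage sketch of the helper above (not from the changeset); the resulting path follows the GEMINI_BASE_ENDPOINT proxy routing.

# Illustrative only.
endpoint = get_gemini_endpoint(GeminiModel.gemini_2_5_pro_preview_05_06)
# endpoint.path == "/proxy/vertexai/gemini/gemini-2.5-pro-preview-05-06"
# Plain string values are also accepted and coerced to the enum:
endpoint = get_gemini_endpoint("gemini-2.5-flash-preview-04-17")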
class GeminiNode(ComfyNodeABC):
"""
Node to generate text responses from a Gemini model.
This node allows users to interact with Google's Gemini AI models, providing
multimodal inputs (text, images, audio, video, files) to generate coherent
text responses. The node works with the latest Gemini models, handling the
API communication and response parsing.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
},
),
"model": (
IO.COMBO,
{
"tooltip": "The Gemini model to use for generating responses.",
"options": [model.value for model in GeminiModel],
"default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
},
),
"seed": (
IO.INT,
{
"default": 42,
"min": 0,
"max": 0xFFFFFFFFFFFFFFFF,
"control_after_generate": True,
"tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
},
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"audio": (
IO.AUDIO,
{
"tooltip": "Optional audio to use as context for the model.",
"default": None,
},
),
"video": (
IO.VIDEO,
{
"tooltip": "Optional video to use as context for the model.",
"default": None,
},
),
"files": (
"GEMINI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
RETURN_TYPES = ("STRING",)
FUNCTION = "api_call"
CATEGORY = "api node/text/Gemini"
API_NODE = True
def get_parts_from_response(
self, response: GeminiGenerateContentResponse
) -> list[GeminiPart]:
"""
Extract all parts from the Gemini API response.
Args:
response: The API response from Gemini.
Returns:
List of response parts from the first candidate.
"""
return response.candidates[0].content.parts
def get_parts_by_type(
self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
) -> list[GeminiPart]:
"""
Filter response parts by their type.
Args:
response: The API response from Gemini.
part_type: Type of parts to extract ("text" or a MIME type).
Returns:
List of response parts matching the requested type.
"""
parts = []
for part in self.get_parts_from_response(response):
if part_type == "text" and hasattr(part, "text") and part.text:
parts.append(part)
elif (
hasattr(part, "inlineData")
and part.inlineData
and part.inlineData.mimeType == part_type
):
parts.append(part)
# Skip parts that don't match the requested type
return parts
def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
"""
Extract and concatenate all text parts from the response.
Args:
response: The API response from Gemini.
Returns:
Combined text from all text parts in the response.
"""
parts = self.get_parts_by_type(response, "text")
return "\n".join([part.text for part in parts])
def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
"""
Convert video input to Gemini API compatible parts.
Args:
video_input: Video tensor from ComfyUI.
**kwargs: Additional arguments to pass to the conversion function.
Returns:
List of GeminiPart objects containing the encoded video.
"""
from comfy_api.util import VideoContainer, VideoCodec
base_64_string = video_to_base64_string(
video_input,
container_format=VideoContainer.MP4,
codec=VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
Args:
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
Returns:
List of GeminiPart objects containing the encoded audio.
"""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = {
"waveform": audio_input["waveform"][batch_index].unsqueeze(0),
"sample_rate": audio_input["sample_rate"],
}
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
"""
Convert image tensor input to Gemini API compatible parts.
Args:
image_input: Batch of image tensors from ComfyUI.
Returns:
List of GeminiPart objects containing the encoded images.
"""
image_parts: list[GeminiPart] = []
for image_index in range(image_input.shape[0]):
image_as_b64 = tensor_to_base64_string(
image_input[image_index].unsqueeze(0)
)
image_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.image_png,
data=image_as_b64,
)
)
)
return image_parts
def create_text_part(self, text: str) -> GeminiPart:
"""
Create a text part for the Gemini API request.
Args:
text: The text content to include in the request.
Returns:
A GeminiPart object with the text content.
"""
return GeminiPart(text=text)
def api_call(
self,
prompt: str,
model: GeminiModel,
images: Optional[IO.IMAGE] = None,
audio: Optional[IO.AUDIO] = None,
video: Optional[IO.VIDEO] = None,
files: Optional[list[GeminiPart]] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
validate_string(prompt, strip_whitespace=False)
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [self.create_text_part(prompt)]
# Add other modal parts
if images is not None:
image_parts = self.create_image_parts(images)
parts.extend(image_parts)
if audio is not None:
parts.extend(self.create_audio_parts(audio))
if video is not None:
parts.extend(self.create_video_parts(video))
if files is not None:
parts.extend(files)
# Create response
response = SynchronousOperation(
endpoint=get_gemini_endpoint(model),
request=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role="user",
parts=parts,
)
]
),
auth_kwargs=kwargs,
).execute()
# Get result output
output_text = self.get_text_from_response(response)
if unique_id and output_text:
PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
return (output_text or "Empty response from Gemini model...",)
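Aside: a hedged sketch (not from the changeset) of how the parts list assembled in api_call maps onto the request model; the image line is commented out because it assumes an upstream IMAGE input.

# Illustrative only: a text-plus-image request body, mirroring api_call above.
node = GeminiNode()
parts = [node.create_text_part("Describe this image in one sentence.")]
# parts.extend(node.create_image_parts(images))  # images: a [B, H, W, C] tensor from ComfyUI
request = GeminiGenerateContentRequest(
    contents=[GeminiContent(role="user", parts=parts)]
)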
class GeminiInputFiles(ComfyNodeABC):
"""
Loads and formats input files for use with the Gemini API.
This node allows users to include text (.txt) and PDF (.pdf) files as input
context for the Gemini model. Files are converted to the appropriate format
required by the API and can be chained together to include multiple files
in a single request.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
"""
For details about the supported file input types, see:
https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
"""
input_dir = folder_paths.get_input_directory()
input_files = [
f
for f in os.scandir(input_dir)
if f.is_file()
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
),
},
"optional": {
"GEMINI_INPUT_FILES": (
"GEMINI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
),
},
}
DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
RETURN_TYPES = ("GEMINI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/Gemini"
def create_file_part(self, file_path: str) -> GeminiPart:
mime_type = (
GeminiMimeType.pdf
if file_path.endswith(".pdf")
else GeminiMimeType.text_plain
)
# Use base64 string directly, not the data URI
with open(file_path, "rb") as f:
file_content = f.read()
import base64
base64_str = base64.b64encode(file_content).decode("utf-8")
return GeminiPart(
inlineData=GeminiInlineData(
mimeType=mime_type,
data=base64_str,
)
)
def prepare_files(
self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = []
) -> tuple[list[GeminiPart]]:
"""
Loads and formats input files for Gemini API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_file_part(file_path)
files = [input_file_content] + GEMINI_INPUT_FILES
return (files,)
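Aside: chaining is additive; a brief sketch (not from the changeset; the file names are hypothetical and must exist in the ComfyUI input directory).

# Illustrative only.
loader = GeminiInputFiles()
(first_batch,) = loader.prepare_files("notes.txt")
# A second call prepends its own file to the batch passed in:
(combined,) = loader.prepare_files("paper.pdf", GEMINI_INPUT_FILES=first_batch)
# combined == [<paper.pdf part>, <notes.txt part>]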
NODE_CLASS_MAPPINGS = {
"GeminiNode": GeminiNode,
"GeminiInputFiles": GeminiInputFiles,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"GeminiNode": "Google Gemini",
"GeminiInputFiles": "Gemini Input Files",
}

View File

@@ -23,6 +23,7 @@ from comfy_api_nodes.apinode_utils import (
bytesio_to_image_tensor,
resize_mask_to_image,
)
from server import PromptServer
V1_V1_RES_MAP = {
"Auto":"AUTO",
@@ -232,6 +233,19 @@ def download_and_process_images(image_urls):
return stacked_tensors
def display_image_urls_on_node(image_urls, node_id):
if node_id and image_urls:
if len(image_urls) == 1:
PromptServer.instance.send_progress_text(
f"Generated Image URL:\n{image_urls[0]}", node_id
)
else:
urls_text = "Generated Image URLs:\n" + "\n".join(
f"{i+1}. {url}" for i, url in enumerate(image_urls)
)
PromptServer.instance.send_progress_text(urls_text, node_id)
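Aside: a short sketch (not from the changeset) of the progress text this helper emits; the URLs and node id are placeholders.

# Illustrative only: with two URLs the node would receive
# "Generated Image URLs:\n1. https://example.com/a.png\n2. https://example.com/b.png"
display_image_urls_on_node(
    ["https://example.com/a.png", "https://example.com/b.png"], node_id="12"
)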
class IdeogramV1(ComfyNodeABC):
"""
Generates images using the Ideogram V1 model.
@@ -304,6 +318,7 @@ class IdeogramV1(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -322,6 +337,7 @@ class IdeogramV1(ComfyNodeABC):
seed=0,
negative_prompt="",
num_images=1,
unique_id=None,
**kwargs,
):
# Determine the model based on turbo setting
@@ -361,6 +377,7 @@ class IdeogramV1(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
@@ -460,6 +477,7 @@ class IdeogramV2(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -481,6 +499,7 @@ class IdeogramV2(ComfyNodeABC):
negative_prompt="",
num_images=1,
color_palette="",
unique_id=None,
**kwargs,
):
aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
@@ -534,6 +553,7 @@ class IdeogramV2(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
class IdeogramV3(ComfyNodeABC):
@@ -623,6 +643,7 @@ class IdeogramV3(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -643,6 +664,7 @@ class IdeogramV3(ComfyNodeABC):
seed=0,
num_images=1,
rendering_speed="BALANCED",
unique_id=None,
**kwargs,
):
# Check if both image and mask are provided for editing mode
@@ -762,6 +784,7 @@ class IdeogramV3(ComfyNodeABC):
if not image_urls:
raise Exception("No image URLs were generated in the response")
display_image_urls_on_node(image_urls, unique_id)
return (download_and_process_images(image_urls),)
@@ -776,4 +799,3 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"IdeogramV2": "Ideogram V2",
"IdeogramV3": "Ideogram V3",
}

View File

@@ -6,6 +6,7 @@ For source of truth on the allowed permutations of request fields, please refere
from __future__ import annotations
from typing import Optional, TypeVar, Any
from collections.abc import Callable
import math
import logging
@@ -64,6 +65,12 @@ from comfy_api_nodes.apinode_utils import (
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api_nodes.util.validation_utils import (
validate_image_dimensions,
validate_image_aspect_ratio,
validate_video_dimensions,
validate_video_duration,
)
from comfy_api.input.basic_types import AudioInput
from comfy_api.input.video_types import VideoInput
from comfy_api.input_impl import VideoFromFile
@@ -79,13 +86,20 @@ PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on"
PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
MAX_PROMPT_LENGTH_T2V = 2500
MAX_PROMPT_LENGTH_I2V = 500
MAX_PROMPT_LENGTH_IMAGE_GEN = 500
MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200
MAX_PROMPT_LENGTH_LIP_SYNC = 120
AVERAGE_DURATION_T2V = 319
AVERAGE_DURATION_I2V = 164
AVERAGE_DURATION_LIP_SYNC = 455
AVERAGE_DURATION_VIRTUAL_TRY_ON = 19
AVERAGE_DURATION_IMAGE_GEN = 32
AVERAGE_DURATION_VIDEO_EFFECTS = 320
AVERAGE_DURATION_VIDEO_EXTEND = 320
R = TypeVar("R")
@@ -95,7 +109,13 @@ class KlingApiError(Exception):
pass
def poll_until_finished(auth_kwargs: dict[str,str], api_endpoint: ApiEndpoint[Any, R]) -> R:
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, R],
result_url_extractor: Optional[Callable[[R], str]] = None,
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> R:
"""Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
@@ -109,6 +129,9 @@ def poll_until_finished(auth_kwargs: dict[str,str], api_endpoint: ApiEndpoint[An
else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=result_url_extractor,
estimated_duration=estimated_duration,
node_id=node_id,
).execute()
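Aside: the three new parameters default to None, so existing two-argument callers keep working; a sketch of both call shapes (not from the changeset; endpoint, kwargs, and unique_id stand in for values built as in the get_response methods below).

# Illustrative only.
legacy_response = poll_until_finished(kwargs, endpoint)
tracked_response = poll_until_finished(
    kwargs,
    endpoint,
    result_url_extractor=get_video_url_from_response,
    estimated_duration=AVERAGE_DURATION_T2V,
    node_id=unique_id,
)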
@@ -192,23 +215,8 @@ def validate_input_image(image: torch.Tensor) -> None:
See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
"""
if len(image.shape) == 4:
height, width = image.shape[1], image.shape[2]
elif len(image.shape) == 3:
height, width = image.shape[0], image.shape[1]
else:
raise ValueError("Invalid image tensor shape.")
# Ensure minimum resolution is met
if height < 300:
raise ValueError("Image height must be at least 300px")
if width < 300:
raise ValueError("Image width must be at least 300px")
# Ensure aspect ratio is within acceptable range
aspect_ratio = width / height
if aspect_ratio < 1 / 2.5 or aspect_ratio > 2.5:
raise ValueError("Image aspect ratio must be between 1:2.5 and 2.5:1")
validate_image_dimensions(image, min_width=300, min_height=300)
validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
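Aside: the shared validation_utils helpers can express other constraints with the same keyword arguments; a hedged usage sketch (not from the changeset; the 512px / 2:1 bounds are arbitrary examples, not Kling requirements).

# Illustrative only.
def validate_hypothetical_input(image: torch.Tensor) -> None:
    validate_image_dimensions(image, min_width=512, min_height=512)
    validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2, max_aspect_ratio=2.0)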
def get_camera_control_input_config(
@@ -227,7 +235,9 @@ def get_camera_control_input_config(
def get_video_from_response(response) -> KlingVideoResult:
"""Returns the first video object from the Kling video generation task result."""
"""Returns the first video object from the Kling video generation task result.
Will raise an error if the response is not valid.
"""
video = response.data.task_result.videos[0]
logging.info(
"Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
@@ -235,12 +245,37 @@ def get_video_from_response(response) -> KlingVideoResult:
return video
def get_video_url_from_response(response) -> Optional[str]:
"""Returns the first video url from the Kling video generation task result.
Will not raise an error if the response is not valid.
"""
if response and is_valid_video_response(response):
return str(get_video_from_response(response).url)
else:
return None
def get_images_from_response(response) -> list[KlingImageResult]:
"""Returns the list of image objects from the Kling image generation task result.
Will raise an error if the response is not valid.
"""
images = response.data.task_result.images
logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
return images
def get_images_urls_from_response(response) -> Optional[str]:
"""Returns the list of image urls from the Kling image generation task result.
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
"""
if response and is_valid_image_response(response):
images = get_images_from_response(response)
image_urls = [str(image.url) for image in images]
return "\n".join(image_urls)
else:
return None
def video_result_to_node_output(
video: KlingVideoResult,
) -> tuple[VideoFromFile, str, str]:
@@ -312,6 +347,7 @@ class KlingCameraControls(KlingNodeBase):
RETURN_TYPES = ("CAMERA_CONTROL",)
RETURN_NAMES = ("camera_control",)
FUNCTION = "main"
API_NODE = False # This is just a helper node, it doesn't make an API call
@classmethod
def VALIDATE_INPUTS(
@@ -421,6 +457,7 @@ class KlingTextToVideoNode(KlingNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -428,7 +465,9 @@ class KlingTextToVideoNode(KlingNodeBase):
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Text to Video Node"
def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingText2VideoResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingText2VideoResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
@@ -437,6 +476,9 @@ class KlingTextToVideoNode(KlingNodeBase):
request_model=EmptyRequest,
response_model=KlingText2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
node_id=node_id,
)
def api_call(
@@ -449,6 +491,7 @@ class KlingTextToVideoNode(KlingNodeBase):
camera_control: Optional[KlingCameraControl] = None,
model_name: Optional[str] = None,
duration: Optional[str] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
@@ -478,7 +521,9 @@ class KlingTextToVideoNode(KlingNodeBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -528,6 +573,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -540,6 +586,7 @@ class KlingCameraControlT2VNode(KlingTextToVideoNode):
cfg_scale: float,
aspect_ratio: str,
camera_control: Optional[KlingCameraControl] = None,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
@@ -613,6 +660,7 @@ class KlingImage2VideoNode(KlingNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -620,7 +668,9 @@ class KlingImage2VideoNode(KlingNodeBase):
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Image to Video Node"
def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingImage2VideoResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingImage2VideoResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
@@ -629,6 +679,9 @@ class KlingImage2VideoNode(KlingNodeBase):
request_model=KlingImage2VideoRequest,
response_model=KlingImage2VideoResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
node_id=node_id,
)
def api_call(
@@ -643,6 +696,7 @@ class KlingImage2VideoNode(KlingNodeBase):
duration: str,
camera_control: Optional[KlingCameraControl] = None,
end_frame: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
@@ -681,7 +735,9 @@ class KlingImage2VideoNode(KlingNodeBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -734,6 +790,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -747,6 +804,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
cfg_scale: float,
aspect_ratio: str,
camera_control: KlingCameraControl,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
@@ -759,6 +817,7 @@ class KlingCameraControlI2VNode(KlingImage2VideoNode):
prompt=prompt,
negative_prompt=negative_prompt,
camera_control=camera_control,
unique_id=unique_id,
**kwargs,
)
@@ -830,6 +889,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -844,6 +904,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
cfg_scale: float,
aspect_ratio: str,
mode: str,
unique_id: Optional[str] = None,
**kwargs,
):
mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
@@ -859,6 +920,7 @@ class KlingStartEndFrameNode(KlingImage2VideoNode):
aspect_ratio=aspect_ratio,
duration=duration,
end_frame=end_frame,
unique_id=unique_id,
**kwargs,
)
@@ -892,6 +954,7 @@ class KlingVideoExtendNode(KlingNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -899,7 +962,9 @@ class KlingVideoExtendNode(KlingNodeBase):
RETURN_NAMES = ("VIDEO", "video_id", "duration")
DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingVideoExtendResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoExtendResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
@@ -908,6 +973,9 @@ class KlingVideoExtendNode(KlingNodeBase):
request_model=EmptyRequest,
response_model=KlingVideoExtendResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
node_id=node_id,
)
def api_call(
@@ -916,6 +984,7 @@ class KlingVideoExtendNode(KlingNodeBase):
negative_prompt: str,
cfg_scale: float,
video_id: str,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
@@ -939,7 +1008,9 @@ class KlingVideoExtendNode(KlingNodeBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -952,7 +1023,9 @@ class KlingVideoEffectsBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingVideoEffectsResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVideoEffectsResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
@@ -961,6 +1034,9 @@ class KlingVideoEffectsBase(KlingNodeBase):
request_model=EmptyRequest,
response_model=KlingVideoEffectsResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
node_id=node_id,
)
def api_call(
@@ -972,6 +1048,7 @@ class KlingVideoEffectsBase(KlingNodeBase):
image_1: torch.Tensor,
image_2: Optional[torch.Tensor] = None,
mode: Optional[KlingVideoGenMode] = None,
unique_id: Optional[str] = None,
**kwargs,
):
if dual_character:
@@ -1009,7 +1086,9 @@ class KlingVideoEffectsBase(KlingNodeBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -1053,6 +1132,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -1068,6 +1148,7 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
model_name: KlingCharacterEffectModelName,
mode: KlingVideoGenMode,
duration: KlingVideoGenDuration,
unique_id: Optional[str] = None,
**kwargs,
):
video, _, duration = super().api_call(
@@ -1078,10 +1159,12 @@ class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
duration=duration,
image_1=image_left,
image_2=image_right,
unique_id=unique_id,
**kwargs,
)
return video, duration
class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
"""Kling Single Image Video Effect Node"""
@@ -1117,6 +1200,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -1128,6 +1212,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
effect_scene: KlingSingleImageEffectsScene,
model_name: KlingSingleImageEffectModelName,
duration: KlingVideoGenDuration,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
@@ -1136,6 +1221,7 @@ class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
model_name=model_name,
duration=duration,
image_1=image,
unique_id=unique_id,
**kwargs,
)
@@ -1146,6 +1232,17 @@ class KlingLipSyncBase(KlingNodeBase):
RETURN_TYPES = ("VIDEO", "STRING", "STRING")
RETURN_NAMES = ("VIDEO", "video_id", "duration")
def validate_lip_sync_video(self, video: VideoInput):
"""
Validates the input video adheres to the expectations of the Kling Lip Sync API:
- Video length does not exceed 10s and is not shorter than 2s
- Height and width dimensions should both be between 720px and 1920px
See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip
"""
validate_video_dimensions(video, 720, 1920)
validate_video_duration(video, 2, 10)
def validate_text(self, text: str):
if not text:
raise ValueError("Text is required")
@@ -1154,7 +1251,9 @@ class KlingLipSyncBase(KlingNodeBase):
f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
)
def get_response(self, task_id: str, auth_kwargs: dict[str,str]) -> KlingLipSyncResponse:
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingLipSyncResponse:
"""Polls the Kling API endpoint until the task reaches a terminal state."""
return poll_until_finished(
auth_kwargs,
@@ -1164,6 +1263,9 @@ class KlingLipSyncBase(KlingNodeBase):
request_model=EmptyRequest,
response_model=KlingLipSyncResponse,
),
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_LIP_SYNC,
node_id=node_id,
)
def api_call(
@@ -1175,10 +1277,12 @@ class KlingLipSyncBase(KlingNodeBase):
text: Optional[str] = None,
voice_speed: Optional[float] = None,
voice_id: Optional[str] = None,
**kwargs
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile, str, str]:
if text:
self.validate_text(text)
self.validate_lip_sync_video(video)
# Upload video to Comfy API and get download URL
video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
@@ -1217,7 +1321,9 @@ class KlingLipSyncBase(KlingNodeBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_video_result_response(final_response)
video = get_video_from_response(final_response)
@@ -1243,16 +1349,18 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."
DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
def api_call(
self,
video: VideoInput,
audio: AudioInput,
voice_language: str,
unique_id: Optional[str] = None,
**kwargs,
):
return super().api_call(
@@ -1260,6 +1368,7 @@ class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
audio=audio,
voice_language=voice_language,
mode="audio2video",
unique_id=unique_id,
**kwargs,
)
@@ -1352,10 +1461,11 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."
DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
def api_call(
self,
@@ -1363,6 +1473,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
text: str,
voice: str,
voice_speed: float,
unique_id: Optional[str] = None,
**kwargs,
):
voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
@@ -1373,6 +1484,7 @@ class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
voice_id=voice_id,
voice_speed=voice_speed,
mode="text2video",
unique_id=unique_id,
**kwargs,
)
@@ -1413,13 +1525,14 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human."
DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
def get_response(
self, task_id: str, auth_kwargs: dict[str,str] = None
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> KlingVirtualTryOnResponse:
return poll_until_finished(
auth_kwargs,
@@ -1429,6 +1542,9 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
request_model=EmptyRequest,
response_model=KlingVirtualTryOnResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
node_id=node_id,
)
def api_call(
@@ -1436,6 +1552,7 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
human_image: torch.Tensor,
cloth_image: torch.Tensor,
model_name: KlingVirtualTryOnModelName,
unique_id: Optional[str] = None,
**kwargs,
):
initial_operation = SynchronousOperation(
@@ -1457,7 +1574,9 @@ class KlingVirtualTryOnNode(KlingImageGenerationBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)
images = get_images_from_response(final_response)
@@ -1528,13 +1647,17 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
def get_response(
self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]],
node_id: Optional[str] = None,
) -> KlingImageGenerationsResponse:
return poll_until_finished(
auth_kwargs,
@@ -1544,6 +1667,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
request_model=EmptyRequest,
response_model=KlingImageGenerationsResponse,
),
result_url_extractor=get_images_urls_from_response,
estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
node_id=node_id,
)
def api_call(
@@ -1557,6 +1683,7 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
n: int,
aspect_ratio: KlingImageGenAspectRatio,
image: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
):
self.validate_prompt(prompt, negative_prompt)
@@ -1589,7 +1716,9 @@ class KlingImageGenerationNode(KlingImageGenerationBase):
validate_task_creation_response(task_creation_response)
task_id = task_creation_response.data.task_id
final_response = self.get_response(task_id, auth_kwargs=kwargs)
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
validate_image_result_response(final_response)
images = get_images_from_response(final_response)

View File

@@ -36,11 +36,20 @@ from comfy_api_nodes.apinode_utils import (
process_image_response,
validate_string,
)
from server import PromptServer
import requests
import torch
from io import BytesIO
LUMA_T2V_AVERAGE_DURATION = 105
LUMA_I2V_AVERAGE_DURATION = 100
def image_result_url_extractor(response: LumaGeneration):
return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None
def video_result_url_extractor(response: LumaGeneration):
return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None
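Aside: both extractors return None instead of raising when assets are missing, so a pending or failed generation simply produces no progress URL; a short sketch (not from the changeset; the bare object is a stand-in for a partial LumaGeneration).

# Illustrative only.
class _PartialGeneration:
    pass  # no "assets" attribute yet

assert image_result_url_extractor(_PartialGeneration()) is None
assert video_result_url_extractor(_PartialGeneration()) is None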
class LumaReferenceNode(ComfyNodeABC):
"""
@@ -204,6 +213,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -217,6 +227,7 @@ class LumaImageGenerationNode(ComfyNodeABC):
image_luma_ref: LumaReferenceChain = None,
style_image: torch.Tensor = None,
character_image: torch.Tensor = None,
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=True, min_length=3)
@@ -271,6 +282,8 @@ class LumaImageGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -353,6 +366,7 @@ class LumaImageModifyNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -363,6 +377,7 @@ class LumaImageModifyNode(ComfyNodeABC):
image: torch.Tensor,
image_weight: float,
seed,
unique_id: str = None,
**kwargs,
):
# first, upload image
@@ -399,6 +414,8 @@ class LumaImageModifyNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=image_result_url_extractor,
node_id=unique_id,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -473,6 +490,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -486,6 +504,7 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
loop: bool,
seed,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, min_length=3)
@@ -512,6 +531,9 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
)
response_api: LumaGeneration = operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
@@ -522,6 +544,9 @@ class LumaTextToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_T2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()
@@ -597,6 +622,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -611,6 +637,7 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
first_image: torch.Tensor = None,
last_image: torch.Tensor = None,
luma_concepts: LumaConceptChain = None,
unique_id: str = None,
**kwargs,
):
if first_image is None and last_image is None:
@@ -642,6 +669,9 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
)
response_api: LumaGeneration = operation.execute()
if unique_id:
PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/luma/generations/{response_api.id}",
@@ -652,6 +682,9 @@ class LumaImageToVideoGenerationNode(ComfyNodeABC):
completed_statuses=[LumaState.completed],
failed_statuses=[LumaState.failed],
status_extractor=lambda x: x.state,
result_url_extractor=video_result_url_extractor,
node_id=unique_id,
estimated_duration=LUMA_I2V_AVERAGE_DURATION,
auth_kwargs=kwargs,
)
response_poll = operation.execute()

View File

@@ -1,3 +1,7 @@
from typing import Union
import logging
import torch
from comfy.comfy_types.node_typing import IO
from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api_nodes.apis import (
@@ -20,16 +24,19 @@ from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
validate_string,
)
from server import PromptServer
import torch
import logging
I2V_AVERAGE_DURATION = 114
T2V_AVERAGE_DURATION = 234
class MinimaxTextToVideoNode:
"""
Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
@@ -68,6 +75,7 @@ class MinimaxTextToVideoNode:
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -85,6 +93,7 @@ class MinimaxTextToVideoNode:
model="T2V-01",
image: torch.Tensor=None, # used for ImageToVideo
subject: torch.Tensor=None, # used for SubjectToVideo
unique_id: Union[str, None]=None,
**kwargs,
):
'''
@@ -138,6 +147,8 @@ class MinimaxTextToVideoNode:
completed_statuses=["Success"],
failed_statuses=["Fail"],
status_extractor=lambda x: x.status.value,
estimated_duration=self.AVERAGE_DURATION,
node_id=unique_id,
auth_kwargs=kwargs,
)
task_result = video_generate_operation.execute()
@@ -164,6 +175,12 @@ class MinimaxTextToVideoNode:
f"No video was found in the response. Full response: {file_result.model_dump()}"
)
logging.info(f"Generated video URL: {file_url}")
if unique_id:
if hasattr(file_result.file, "backup_download_url"):
message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
else:
message = f"Result URL: {file_url}"
PromptServer.instance.send_progress_text(message, unique_id)
video_io = download_url_to_bytesio(file_url)
if video_io is None:
@@ -178,6 +195,8 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = I2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
@@ -223,6 +242,7 @@ class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -239,6 +259,8 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API.
"""
AVERAGE_DURATION = T2V_AVERAGE_DURATION
@classmethod
def INPUT_TYPES(s):
return {
@@ -282,6 +304,7 @@ class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
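The three MiniMax nodes share one api_call; the subclasses only override the AVERAGE_DURATION class attribute (234 s for text-to-video and subject-to-video, 114 s for image-to-video), which the shared method reads through self when building the PollingOperation. A stripped-down sketch of that inheritance pattern:

T2V_AVERAGE_DURATION = 234
I2V_AVERAGE_DURATION = 114

class BaseVideoNode:
    AVERAGE_DURATION = T2V_AVERAGE_DURATION

    def estimated_duration(self) -> int:
        # the shared api_call reads the subclass override via `self`
        return self.AVERAGE_DURATION

class ImageToVideoNode(BaseVideoNode):
    AVERAGE_DURATION = I2V_AVERAGE_DURATION

assert BaseVideoNode().estimated_duration() == 234
assert ImageToVideoNode().estimated_duration() == 114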

View File

@@ -1,29 +1,86 @@
import io
from typing import TypedDict, Optional
import json
import os
import time
import re
import uuid
from enum import Enum
from inspect import cleandoc
import numpy as np
import torch
from PIL import Image
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
from server import PromptServer
import folder_paths
from comfy_api_nodes.apis import (
OpenAIImageGenerationRequest,
OpenAIImageEditRequest,
OpenAIImageGenerationResponse,
OpenAICreateResponse,
OpenAIResponse,
CreateModelResponseProperties,
Item,
Includable,
OutputContent,
InputImageContent,
Detail,
InputTextContent,
InputMessage,
InputMessageContentList,
InputContent,
InputFileContent,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
downscale_image_tensor,
validate_and_cast_response,
validate_string,
tensor_to_base64_string,
text_filepath_to_data_uri,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
class HistoryEntry(TypedDict):
"""Type definition for a single history entry in the chat."""
prompt: str
response: str
response_id: str
timestamp: float
class ChatHistory(TypedDict):
"""Type definition for the chat history dictionary."""
__annotations__: dict[str, list[HistoryEntry]]
class SupportedOpenAIModel(str, Enum):
o4_mini = "o4-mini"
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
class OpenAIDalle2(ComfyNodeABC):
"""
@@ -96,6 +153,7 @@ class OpenAIDalle2(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -113,7 +171,8 @@ class OpenAIDalle2(ComfyNodeABC):
mask=None,
n=1,
size="1024x1024",
**kwargs
unique_id=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-2"
@@ -176,7 +235,7 @@ class OpenAIDalle2(ComfyNodeABC):
response = operation.execute()
img_tensor = validate_and_cast_response(response)
img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@@ -242,6 +301,7 @@ class OpenAIDalle3(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -258,7 +318,8 @@ class OpenAIDalle3(ComfyNodeABC):
style="natural",
quality="standard",
size="1024x1024",
**kwargs
unique_id=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
model = "dall-e-3"
@@ -284,7 +345,7 @@ class OpenAIDalle3(ComfyNodeABC):
response = operation.execute()
img_tensor = validate_and_cast_response(response)
img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
@@ -375,6 +436,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -394,12 +456,13 @@ class OpenAIGPTImage1(ComfyNodeABC):
mask=None,
n=1,
size="1024x1024",
**kwargs
unique_id=None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
model = "gpt-image-1"
path = "/proxy/openai/images/generations"
content_type="application/json"
content_type = "application/json"
request_class = OpenAIImageGenerationRequest
img_binaries = []
mask_binary = None
@@ -408,7 +471,7 @@ class OpenAIGPTImage1(ComfyNodeABC):
if image is not None:
path = "/proxy/openai/images/edits"
request_class = OpenAIImageEditRequest
content_type ="multipart/form-data"
content_type = "multipart/form-data"
batch_size = image.shape[0]
@@ -476,21 +539,470 @@ class OpenAIGPTImage1(ComfyNodeABC):
response = operation.execute()
img_tensor = validate_and_cast_response(response)
img_tensor = validate_and_cast_response(response, node_id=unique_id)
return (img_tensor,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
class OpenAITextNode(ComfyNodeABC):
"""
Base class for OpenAI text generation nodes.
"""
RETURN_TYPES = (IO.STRING,)
FUNCTION = "api_call"
CATEGORY = "api node/text/OpenAI"
API_NODE = True
class OpenAIChatNode(OpenAITextNode):
"""
Node to generate text responses from an OpenAI model.
"""
def __init__(self) -> None:
"""Initialize the chat node with a new session ID and empty history."""
self.current_session_id: str = str(uuid.uuid4())
self.history: dict[str, list[HistoryEntry]] = {}
self.previous_response_id: Optional[str] = None
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"prompt": (
IO.STRING,
{
"multiline": True,
"default": "",
"tooltip": "Text inputs to the model, used to generate a response.",
},
),
"persist_context": (
IO.BOOLEAN,
{
"default": True,
"tooltip": "Persist chat context between calls (multi-turn conversation)",
},
),
"model": model_field_to_node_input(
IO.COMBO,
OpenAICreateResponse,
"model",
enum_type=SupportedOpenAIModel,
),
},
"optional": {
"images": (
IO.IMAGE,
{
"default": None,
"tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
},
),
"files": (
"OPENAI_INPUT_FILES",
{
"default": None,
"tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
},
),
"advanced_options": (
"OPENAI_CHAT_CONFIG",
{
"default": None,
"tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Generate text responses from an OpenAI model."
def get_result_response(
self,
response_id: str,
include: Optional[list[Includable]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
) -> OpenAIResponse:
"""
Retrieve a model response with the given ID from the OpenAI API.
Args:
response_id (str): The ID of the response to retrieve.
include (Optional[List[Includable]]): Additional fields to include
in the response. See the `include` parameter for Response
creation above for more information.
"""
return PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"{RESPONSES_ENDPOINT}/{response_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=OpenAIResponse,
query_params={"include": include},
),
completed_statuses=["completed"],
failed_statuses=["failed"],
status_extractor=lambda response: response.status,
auth_kwargs=auth_kwargs,
).execute()
def get_message_content_from_response(
self, response: OpenAIResponse
) -> list[OutputContent]:
"""Extract message content from the API response."""
for output in response.output:
if output.root.type == "message":
return output.root.content
raise TypeError("No output message found in response")
def get_text_from_message_content(
self, message_content: list[OutputContent]
) -> str:
"""Extract text content from message content."""
for content_item in message_content:
if content_item.root.type == "output_text":
return str(content_item.root.text)
return "No text output found in response"
def get_history_text(self, session_id: str) -> str:
"""Convert the entire history for a given session to JSON string."""
return json.dumps(self.history[session_id])
def display_history_on_node(self, session_id: str, node_id: str) -> None:
"""Display formatted chat history on the node UI."""
render_spec = {
"node_id": node_id,
"component": "ChatHistoryWidget",
"props": {
"history": self.get_history_text(session_id),
},
}
PromptServer.instance.send_sync(
"display_component",
render_spec,
)
def add_to_history(
self, session_id: str, prompt: str, output_text: str, response_id: str
) -> None:
"""Add a new entry to the chat history."""
if session_id not in self.history:
self.history[session_id] = []
self.history[session_id].append(
{
"prompt": prompt,
"response": output_text,
"response_id": response_id,
"timestamp": time.time(),
}
)
def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
"""Extract text output from the API response."""
message_contents = self.get_message_content_from_response(response)
return self.get_text_from_message_content(message_contents)
def generate_new_session_id(self) -> str:
"""Generate a new unique session ID."""
return str(uuid.uuid4())
def get_session_id(self, persist_context: bool) -> str:
"""Get the current or generate a new session ID based on context persistence."""
return (
self.current_session_id
if persist_context
else self.generate_new_session_id()
)
def tensor_to_input_image_content(
self, image: torch.Tensor, detail_level: Detail = "auto"
) -> InputImageContent:
"""Convert a tensor to an input image content object."""
return InputImageContent(
detail=detail_level,
image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
type="input_image",
)
def create_input_message_contents(
self,
prompt: str,
image: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
) -> InputMessageContentList:
"""Create a list of input message contents from prompt and optional image."""
content_list: list[InputContent] = [
InputTextContent(text=prompt, type="input_text"),
]
if image is not None:
for i in range(image.shape[0]):
content_list.append(
self.tensor_to_input_image_content(image[i].unsqueeze(0))
)
if files is not None:
content_list.extend(files)
return InputMessageContentList(
root=content_list,
)
def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
"""Extract response ID from prompt if it exists."""
parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
return parsed_id.group(1) if parsed_id else None
def strip_response_tag_from_prompt(self, prompt: str) -> str:
"""Remove the response ID tag from the prompt."""
return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
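The two helpers above let a prompt carry a <starting_point_id:...> tag that rewinds the conversation; the special value "start" (handled in api_call below) resets it entirely. A self-contained rerun of the same regex logic:

import re

STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"

prompt = "<starting_point_id:resp_123> continue from that answer"
match = re.search(STARTING_POINT_ID_PATTERN, prompt)
print(match.group(1))                                  # resp_123
print(re.sub(STARTING_POINT_ID_PATTERN, "", prompt))   # " continue from that answer"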
def delete_history_after_response_id(
self, new_start_id: str, session_id: str
) -> None:
"""Delete history entries after a specific response ID."""
if session_id not in self.history:
return
new_history = []
i = 0
while (
i < len(self.history[session_id])
and self.history[session_id][i]["response_id"] != new_start_id
):
new_history.append(self.history[session_id][i])
i += 1
# Since it's the new starting point (not the response being edited), we include it as well
if i < len(self.history[session_id]):
new_history.append(self.history[session_id][i])
self.history[session_id] = new_history
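delete_history_after_response_id keeps entries up to and including the new starting point and drops everything after it, which is why the comment notes the starting point itself is retained. The same loop, restated standalone with a three-entry history:

def truncate_after(entries: list[dict], new_start_id: str) -> list[dict]:
    kept, i = [], 0
    while i < len(entries) and entries[i]["response_id"] != new_start_id:
        kept.append(entries[i])
        i += 1
    if i < len(entries):          # include the new starting point itself
        kept.append(entries[i])
    return kept

history = [{"response_id": "a"}, {"response_id": "b"}, {"response_id": "c"}]
assert truncate_after(history, "b") == history[:2]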
def api_call(
self,
prompt: str,
persist_context: bool,
model: SupportedOpenAIModel,
unique_id: Optional[str] = None,
images: Optional[torch.Tensor] = None,
files: Optional[list[InputFileContent]] = None,
advanced_options: Optional[CreateModelResponseProperties] = None,
**kwargs,
) -> tuple[str]:
# Validate inputs
validate_string(prompt, strip_whitespace=False)
session_id = self.get_session_id(persist_context)
response_id_override = self.parse_response_id_from_prompt(prompt)
if response_id_override:
is_starting_from_beginning = response_id_override == "start"
if is_starting_from_beginning:
self.history[session_id] = []
previous_response_id = None
else:
previous_response_id = response_id_override
self.delete_history_after_response_id(response_id_override, session_id)
prompt = self.strip_response_tag_from_prompt(prompt)
elif persist_context:
previous_response_id = self.previous_response_id
else:
previous_response_id = None
# Create response
create_response = SynchronousOperation(
endpoint=ApiEndpoint(
path=RESPONSES_ENDPOINT,
method=HttpMethod.POST,
request_model=OpenAICreateResponse,
response_model=OpenAIResponse,
),
request=OpenAICreateResponse(
input=[
Item(
root=InputMessage(
content=self.create_input_message_contents(
prompt, images, files
),
role="user",
)
),
],
store=True,
stream=False,
model=model,
previous_response_id=previous_response_id,
**(
advanced_options.model_dump(exclude_none=True)
if advanced_options
else {}
),
),
auth_kwargs=kwargs,
).execute()
response_id = create_response.id
# Get result output
result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
output_text = self.parse_output_text_from_response(result_response)
# Update history
self.add_to_history(session_id, prompt, output_text, response_id)
self.display_history_on_node(session_id, unique_id)
self.previous_response_id = response_id
return (output_text,)
class OpenAIInputFiles(ComfyNodeABC):
"""
Loads and formats input files for OpenAI API.
"""
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
"""
For details about the supported file input types, see:
https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
"""
input_dir = folder_paths.get_input_directory()
input_files = [
f
for f in os.scandir(input_dir)
if f.is_file()
and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
and f.stat().st_size < 32 * 1024 * 1024
]
input_files = sorted(input_files, key=lambda x: x.name)
input_files = [f.name for f in input_files]
return {
"required": {
"file": (
IO.COMBO,
{
"tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
"options": input_files,
"default": input_files[0] if input_files else None,
},
),
},
"optional": {
"OPENAI_INPUT_FILES": (
"OPENAI_INPUT_FILES",
{
"tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
"default": None,
},
),
},
}
DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
RETURN_TYPES = ("OPENAI_INPUT_FILES",)
FUNCTION = "prepare_files"
CATEGORY = "api node/text/OpenAI"
def create_input_file_content(self, file_path: str) -> InputFileContent:
return InputFileContent(
file_data=text_filepath_to_data_uri(file_path),
filename=os.path.basename(file_path),
type="input_file",
)
def prepare_files(
self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []
) -> tuple[list[InputFileContent]]:
"""
Loads and formats input files for OpenAI API.
"""
file_path = folder_paths.get_annotated_filepath(file)
input_file_content = self.create_input_file_content(file_path)
files = [input_file_content] + OPENAI_INPUT_FILES
return (files,)
class OpenAIChatConfig(ComfyNodeABC):
"""Allows setting additional configuration for the OpenAI Chat Node."""
RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
FUNCTION = "configure"
DESCRIPTION = (
"Allows specifying advanced configuration options for the OpenAI Chat Nodes."
)
CATEGORY = "api node/text/OpenAI"
@classmethod
def INPUT_TYPES(cls) -> InputTypeDict:
return {
"required": {
"truncation": (
IO.COMBO,
{
"options": ["auto", "disabled"],
"default": "auto",
"tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error",
},
),
},
"optional": {
"max_output_tokens": model_field_to_node_input(
IO.INT,
OpenAICreateResponse,
"max_output_tokens",
min=16,
default=4096,
max=16384,
tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
),
"instructions": model_field_to_node_input(
IO.STRING, OpenAICreateResponse, "instructions", multiline=True
),
},
}
def configure(
self,
truncation: bool,
instructions: Optional[str] = None,
max_output_tokens: Optional[int] = None,
) -> tuple[CreateModelResponseProperties]:
"""
Configure advanced options for the OpenAI Chat Node.
Note:
While `top_p` and `temperature` are listed as properties in the
spec, they are not supported for all models (e.g., o4-mini).
They are not exposed as inputs at all to avoid having to manually
remove depending on model choice.
"""
return (
CreateModelResponseProperties(
instructions=instructions,
truncation=truncation,
max_output_tokens=max_output_tokens,
),
)
NODE_CLASS_MAPPINGS = {
"OpenAIDalle2": OpenAIDalle2,
"OpenAIDalle3": OpenAIDalle3,
"OpenAIGPTImage1": OpenAIGPTImage1,
"OpenAIChatNode": OpenAIChatNode,
"OpenAIInputFiles": OpenAIInputFiles,
"OpenAIChatConfig": OpenAIChatConfig,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"OpenAIDalle2": "OpenAI DALL·E 2",
"OpenAIDalle3": "OpenAI DALL·E 3",
"OpenAIGPTImage1": "OpenAI GPT Image 1",
"OpenAIChatNode": "OpenAI Chat",
"OpenAIInputFiles": "OpenAI Chat Input Files",
"OpenAIChatConfig": "OpenAI Chat Advanced Options",
}

View File

@@ -6,40 +6,42 @@ Pika API docs: https://pika-827374fb.mintlify.app/api-reference
from __future__ import annotations
import io
from typing import Optional, TypeVar
import logging
import torch
from typing import Optional, TypeVar
import numpy as np
import torch
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
from comfy_api.input_impl import VideoFromFile
from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput
from comfy_api_nodes.apinode_utils import (
download_url_to_video_output,
tensor_to_bytesio,
)
from comfy_api_nodes.apis import (
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaGenerateResponse,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaVideoResponse,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
IngredientsMode,
PikaDurationEnum,
PikaResolutionEnum,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaBodyGenerate22C2vGenerate22PikascenesPost,
PikaBodyGenerate22I2vGenerate22I2vPost,
PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
PikaBodyGenerate22T2vGenerate22T2vPost,
PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
PikaDurationEnum,
Pikaffect,
PikaGenerateResponse,
PikaResolutionEnum,
PikaVideoResponse,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
tensor_to_bytesio,
download_url_to_video_output,
HttpMethod,
PollingOperation,
SynchronousOperation,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
R = TypeVar("R")
@@ -121,7 +123,10 @@ class PikaNodeBase(ComfyNodeABC):
RETURN_TYPES = ("VIDEO",)
def poll_for_task_status(
self, task_id: str, auth_kwargs: Optional[dict[str,str]] = None
self,
task_id: str,
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> PikaGenerateResponse:
polling_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
@@ -141,13 +146,19 @@ class PikaNodeBase(ComfyNodeABC):
response.progress if hasattr(response, "progress") else None
),
auth_kwargs=auth_kwargs,
result_url_extractor=lambda response: (
response.url if hasattr(response, "url") else None
),
node_id=node_id,
estimated_duration=60
)
return polling_operation.execute()
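The two inline lambdas above only guard against responses that lack a progress or url field before handing them to the poller. Written out as named helpers (illustrative only), they amount to:

from typing import Optional

def pika_progress_extractor(response) -> Optional[float]:
    # mirrors: response.progress if hasattr(response, "progress") else None
    return getattr(response, "progress", None)

def pika_result_url_extractor(response) -> Optional[str]:
    # mirrors: response.url if hasattr(response, "url") else None
    return getattr(response, "url", None)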
def execute_task(
self,
initial_operation: SynchronousOperation[R, PikaGenerateResponse],
auth_kwargs: Optional[dict[str,str]] = None,
auth_kwargs: Optional[dict[str, str]] = None,
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
"""Executes the initial operation then polls for the task status until it is completed.
@@ -195,6 +206,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -208,7 +220,8 @@ class PikaImageToVideoV2_2(PikaNodeBase):
seed: int,
resolution: str,
duration: int,
**kwargs
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert image to BytesIO
image_bytes_io = tensor_to_bytesio(image)
@@ -238,7 +251,7 @@ class PikaImageToVideoV2_2(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaTextToVideoNodeV2_2(PikaNodeBase):
@@ -262,6 +275,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -275,6 +289,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
resolution: str,
duration: int,
aspect_ratio: float,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
@@ -296,7 +311,7 @@ class PikaTextToVideoNodeV2_2(PikaNodeBase):
content_type="application/x-www-form-urlencoded",
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaScenesV2_2(PikaNodeBase):
@@ -340,6 +355,7 @@ class PikaScenesV2_2(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -354,6 +370,7 @@ class PikaScenesV2_2(PikaNodeBase):
duration: int,
ingredients_mode: str,
aspect_ratio: float,
unique_id: str,
image_ingredient_1: Optional[torch.Tensor] = None,
image_ingredient_2: Optional[torch.Tensor] = None,
image_ingredient_3: Optional[torch.Tensor] = None,
@@ -403,7 +420,7 @@ class PikaScenesV2_2(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikAdditionsNode(PikaNodeBase):
@@ -439,10 +456,11 @@ class PikAdditionsNode(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what youd like to add to create a seamlessly integrated result."
DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result."
def api_call(
self,
@@ -451,6 +469,7 @@ class PikAdditionsNode(PikaNodeBase):
prompt_text: str,
negative_prompt: str,
seed: int,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert video to BytesIO
@@ -487,7 +506,7 @@ class PikAdditionsNode(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaSwapsNode(PikaNodeBase):
@@ -532,6 +551,7 @@ class PikaSwapsNode(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -546,6 +566,7 @@ class PikaSwapsNode(PikaNodeBase):
prompt_text: str,
negative_prompt: str,
seed: int,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
# Convert video to BytesIO
@@ -592,7 +613,7 @@ class PikaSwapsNode(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaffectsNode(PikaNodeBase):
@@ -637,6 +658,7 @@ class PikaffectsNode(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -649,6 +671,7 @@ class PikaffectsNode(PikaNodeBase):
prompt_text: str,
negative_prompt: str,
seed: int,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
@@ -670,7 +693,7 @@ class PikaffectsNode(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
class PikaStartEndFrameNode2_2(PikaNodeBase):
@@ -689,6 +712,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -703,6 +727,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
seed: int,
resolution: str,
duration: int,
unique_id: str,
**kwargs,
) -> tuple[VideoFromFile]:
@@ -733,7 +758,7 @@ class PikaStartEndFrameNode2_2(PikaNodeBase):
auth_kwargs=kwargs,
)
return self.execute_task(initial_operation, auth_kwargs=kwargs)
return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
NODE_CLASS_MAPPINGS = {

View File

@@ -1,5 +1,5 @@
from inspect import cleandoc
from typing import Optional
from comfy_api_nodes.apis.pixverse_api import (
PixverseTextVideoRequest,
PixverseImageVideoRequest,
@@ -34,11 +34,22 @@ import requests
from io import BytesIO
AVERAGE_DURATION_T2V = 32
AVERAGE_DURATION_I2V = 30
AVERAGE_DURATION_T2T = 52
def get_video_url_from_response(
response: PixverseGenerationStatusResponse,
) -> Optional[str]:
if response.Resp is None or response.Resp.url is None:
return None
return str(response.Resp.url)
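get_video_url_from_response returns a URL only once both Resp and Resp.url are populated, so the poller simply shows nothing until the generation finishes. A quick check with stand-in response objects (the URL is hypothetical):

from types import SimpleNamespace

done = SimpleNamespace(Resp=SimpleNamespace(url="https://example.com/out.mp4"))
pending = SimpleNamespace(Resp=None)

print(get_video_url_from_response(done))     # https://example.com/out.mp4
print(get_video_url_from_response(pending))  # None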
def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
# first, upload image to Pixverse and get image id to use in actual generation call
files = {
"image": tensor_to_bytesio(image)
}
files = {"image": tensor_to_bytesio(image)}
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/pixverse/image/upload",
@@ -54,7 +65,9 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
response_upload: PixverseImageUploadResponse = operation.execute()
if response_upload.Resp is None:
raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'")
raise Exception(
f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
)
return response_upload.Resp.img_id
@@ -73,7 +86,7 @@ class PixverseTemplateNode:
def INPUT_TYPES(s):
return {
"required": {
"template": (list(pixverse_templates.keys()), ),
"template": (list(pixverse_templates.keys()),),
}
}
@@ -87,7 +100,7 @@ class PixverseTemplateNode:
class PixverseTextToVideoNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
@@ -108,9 +121,7 @@ class PixverseTextToVideoNode(ComfyNodeABC):
"tooltip": "Prompt for the video generation",
},
),
"aspect_ratio": (
[ratio.value for ratio in PixverseAspectRatio],
),
"aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
"quality": (
[resolution.value for resolution in PixverseQuality],
{
@@ -143,12 +154,13 @@ class PixverseTextToVideoNode(ComfyNodeABC):
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
}
)
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -160,8 +172,9 @@ class PixverseTextToVideoNode(ComfyNodeABC):
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
pixverse_template: int=None,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
@@ -205,19 +218,27 @@ class PixverseTextToVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()
vid_response = requests.get(response_poll.Resp.url)
return (VideoFromFile(BytesIO(vid_response.content)),)
class PixverseImageToVideoNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
@@ -230,9 +251,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
def INPUT_TYPES(s):
return {
"required": {
"image": (
IO.IMAGE,
),
"image": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
@@ -273,12 +292,13 @@ class PixverseImageToVideoNode(ComfyNodeABC):
PixverseIO.TEMPLATE,
{
"tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
}
)
},
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -290,8 +310,9 @@ class PixverseImageToVideoNode(ComfyNodeABC):
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
pixverse_template: int=None,
negative_prompt: str = None,
pixverse_template: int = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
@@ -337,9 +358,16 @@ class PixverseImageToVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_I2V,
)
response_poll = operation.execute()
@@ -349,7 +377,7 @@ class PixverseImageToVideoNode(ComfyNodeABC):
class PixverseTransitionVideoNode(ComfyNodeABC):
"""
Generates videos synchronously based on prompt and output_size.
Generates videos based on prompt and output_size.
"""
RETURN_TYPES = (IO.VIDEO,)
@@ -362,12 +390,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
def INPUT_TYPES(s):
return {
"required": {
"first_frame": (
IO.IMAGE,
),
"last_frame": (
IO.IMAGE,
),
"first_frame": (IO.IMAGE,),
"last_frame": (IO.IMAGE,),
"prompt": (
IO.STRING,
{
@@ -408,6 +432,7 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -420,7 +445,8 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
duration_seconds: int,
motion_mode: str,
seed,
negative_prompt: str=None,
negative_prompt: str = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False)
@@ -467,9 +493,16 @@ class PixverseTransitionVideoNode(ComfyNodeABC):
response_model=PixverseGenerationStatusResponse,
),
completed_statuses=[PixverseStatus.successful],
failed_statuses=[PixverseStatus.contents_moderation, PixverseStatus.failed, PixverseStatus.deleted],
failed_statuses=[
PixverseStatus.contents_moderation,
PixverseStatus.failed,
PixverseStatus.deleted,
],
status_extractor=lambda x: x.Resp.status,
auth_kwargs=kwargs,
node_id=unique_id,
result_url_extractor=get_video_url_from_response,
estimated_duration=AVERAGE_DURATION_T2V,
)
response_poll = operation.execute()

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from inspect import cleandoc
from typing import Optional
from comfy.utils import ProgressBar
from comfy_extras.nodes_images import SVG # Added
from comfy.comfy_types.node_typing import IO
@@ -29,6 +30,8 @@ from comfy_api_nodes.apinode_utils import (
resize_mask_to_image,
validate_string,
)
from server import PromptServer
import torch
from io import BytesIO
from PIL import UnidentifiedImageError
@@ -388,6 +391,7 @@ class RecraftTextToImageNode:
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -400,6 +404,7 @@ class RecraftTextToImageNode:
recraft_style: RecraftStyle = None,
negative_prompt: str = None,
recraft_controls: RecraftControls = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, max_length=1000)
@@ -436,8 +441,15 @@ class RecraftTextToImageNode:
)
response: RecraftImageGenerationResponse = operation.execute()
images = []
urls = []
for data in response.data:
with handle_recraft_image_output():
if unique_id and data.url:
urls.append(data.url)
urls_string = '\n'.join(urls)
PromptServer.instance.send_progress_text(
f"Result URL: {urls_string}", unique_id
)
image = bytesio_to_image_tensor(
download_url_to_bytesio(data.url, timeout=1024)
)
@@ -763,6 +775,7 @@ class RecraftTextToVectorNode:
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -775,6 +788,7 @@ class RecraftTextToVectorNode:
seed,
negative_prompt: str = None,
recraft_controls: RecraftControls = None,
unique_id: Optional[str] = None,
**kwargs,
):
validate_string(prompt, strip_whitespace=False, max_length=1000)
@@ -809,7 +823,14 @@ class RecraftTextToVectorNode:
)
response: RecraftImageGenerationResponse = operation.execute()
svg_data = []
urls = []
for data in response.data:
if unique_id and data.url:
urls.append(data.url)
# Print result on each iteration in case of error
PromptServer.instance.send_progress_text(
f"Result URL: {' '.join(urls)}", unique_id
)
svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
return (SVG(svg_data),)
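Both Recraft nodes accumulate result URLs and re-send the joined list on every loop iteration, so whatever has arrived so far is still visible on the node if a later download fails. A standalone sketch of that incremental reporting, with send_progress standing in for PromptServer.instance.send_progress_text:

from typing import Optional

def send_progress(text: str, node_id: str) -> None:
    """Stand-in for PromptServer.instance.send_progress_text."""
    print(f"[{node_id}] {text}")

def report_incrementally(items: list[dict], unique_id: Optional[str]) -> None:
    urls = []
    for data in items:
        if unique_id and data.get("url"):
            urls.append(data["url"])
            # re-send the full list each time, as the Recraft loops do
            send_progress("Result URL: " + "\n".join(urls), unique_id)

report_incrementally([{"url": "https://example.com/a.png"}], unique_id="42")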

View File

@@ -0,0 +1,462 @@
"""
ComfyUI X Rodin3D(Deemos) API Nodes
Rodin API docs: https://developer.hyper3d.ai/
"""
from __future__ import annotations
from inspect import cleandoc
from comfy.comfy_types.node_typing import IO
import folder_paths as comfy_paths
import requests
import os
import datetime
import shutil
import time
import io
import logging
import math
from PIL import Image
from comfy_api_nodes.apis.rodin_api import (
Rodin3DGenerateRequest,
Rodin3DGenerateResponse,
Rodin3DCheckStatusRequest,
Rodin3DCheckStatusResponse,
Rodin3DDownloadRequest,
Rodin3DDownloadResponse,
JobStatus,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
)
COMMON_PARAMETERS = {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
"default": "18K-Quad"
}
)
}
def create_task_error(response: Rodin3DGenerateResponse):
"""Check if the response has error"""
return hasattr(response, "error")
class Rodin3DAPI:
"""
Generate 3D Assets using Rodin API
"""
RETURN_TYPES = (IO.STRING,)
RETURN_NAMES = ("3D Model Path",)
CATEGORY = "api node/3d/Rodin"
DESCRIPTION = cleandoc(__doc__ or "")
FUNCTION = "api_call"
API_NODE = True
def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
"""
Converts a PyTorch tensor to a file-like object.
Args:
- tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
where C is the number of channels (3 for RGB), H is height, and W is width.
Returns:
- io.BytesIO: A file-like object containing the image data.
"""
array = tensor.cpu().numpy()
array = (array * 255).astype('uint8')
image = Image.fromarray(array, 'RGB')
original_width, original_height = image.size
original_pixels = original_width * original_height
if original_pixels > max_pixels:
scale = math.sqrt(max_pixels / original_pixels)
new_width = int(original_width * scale)
new_height = int(original_height * scale)
else:
new_width, new_height = original_width, original_height
if new_width != original_width or new_height != original_height:
image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
img_byte_arr = io.BytesIO()
image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
img_byte_arr.seek(0)
return img_byte_arr
def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
all_done = all(job.status == JobStatus.Done for job in response.jobs)
status_list = [str(job.status) for job in response.jobs]
logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
if has_failed:
logging.error(f"[ Rodin3D API - CheckStatus ] Generate Failed: {status_list}, Please try again.")
raise Exception("[ Rodin3D API ] Generate Failed, Please Try again.")
elif all_done:
return "DONE"
else:
return "Generating"
def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
if images == None:
raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5:
raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DGenerateRequest,
response_model=Rodin3DGenerateResponse,
),
request=Rodin3DGenerateRequest(
seed=seed,
tier=tier,
material=material,
quality=quality,
mesh_mode=mesh_mode
),
files=[
(
"images",
open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
)
for image in images if image is not None
],
content_type = "multipart/form-data",
auth_kwargs=kwargs,
)
response = operation.execute()
if create_task_error(response):
error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
logging.error(error_message)
raise Exception(error_message)
logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
subscription_key = response.jobs.subscription_key
task_uuid = response.uuid
logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
return task_uuid, subscription_key
def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
path = "/proxy/rodin/api/v2/status"
poll_operation = PollingOperation(
poll_endpoint=ApiEndpoint(
path = path,
method=HttpMethod.POST,
request_model=Rodin3DCheckStatusRequest,
response_model=Rodin3DCheckStatusResponse,
),
request=Rodin3DCheckStatusRequest(
subscription_key = subscription_key
),
completed_statuses=["DONE"],
failed_statuses=["FAILED"],
status_extractor=self.check_rodin_status,
poll_interval=3.0,
auth_kwargs=kwargs,
)
logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
return poll_operation.execute()
def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
path = "/proxy/rodin/api/v2/download"
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=path,
method=HttpMethod.POST,
request_model=Rodin3DDownloadRequest,
response_model=Rodin3DDownloadResponse,
),
request=Rodin3DDownloadRequest(
task_uuid=uuid
),
auth_kwargs=kwargs
)
return operation.execute()
def GetQualityAndMode(self, PolyCount):
if PolyCount == "200K-Triangle":
mesh_mode = "Raw"
quality = "medium"
else:
mesh_mode = "Quad"
if PolyCount == "4K-Quad":
quality = "extra-low"
elif PolyCount == "8K-Quad":
quality = "low"
elif PolyCount == "18K-Quad":
quality = "medium"
elif PolyCount == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality
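GetQualityAndMode maps the Polygon_count combo onto a (mesh_mode, quality) pair, with 200K-Triangle being the only raw-mesh option. The same mapping written as a lookup table (an illustrative restatement, not what the file ships):

POLY_COUNT_MAP = {
    "4K-Quad": ("Quad", "extra-low"),
    "8K-Quad": ("Quad", "low"),
    "18K-Quad": ("Quad", "medium"),
    "50K-Quad": ("Quad", "high"),
    "200K-Triangle": ("Raw", "medium"),
}

def get_quality_and_mode(poly_count: str) -> tuple[str, str]:
    # unknown values fall back to medium-quality quads, like the if/elif chain
    return POLY_COUNT_MAP.get(poly_count, ("Quad", "medium"))

assert get_quality_and_mode("200K-Triangle") == ("Raw", "medium")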
def DownLoadFiles(self, Url_List):
Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
os.makedirs(Save_path, exist_ok=True)
model_file_path = None
for Item in Url_List.list:
url = Item.url
file_name = Item.name
file_path = os.path.join(Save_path, file_name)
if file_path.endswith(".glb"):
model_file_path = file_path
logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
max_retries = 5
for attempt in range(max_retries):
try:
with requests.get(url, stream=True) as r:
r.raise_for_status()
with open(file_path, "wb") as f:
shutil.copyfileobj(r.raw, f)
break
except Exception as e:
logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}")
if attempt < max_retries - 1:
logging.info("Retrying...")
time.sleep(2)
else:
logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
return model_file_path
class Rodin3D_Regular(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Regular"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Detail(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Detail"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Smooth(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
**COMMON_PARAMETERS
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
**kwargs
):
tier = "Smooth"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
class Rodin3D_Sketch(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed":
(
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
Images,
Seed,
**kwargs
):
tier = "Sketch"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
material_type = "PBR"
quality = "medium"
mesh_mode = "Quad"
task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
self.poll_for_task_status(subscription_key, **kwargs)
Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
model = self.DownLoadFiles(Download_List)
return (model,)
# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {
"Rodin3D_Regular": Rodin3D_Regular,
"Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch,
}
# A dictionary that contains the friendly/humanly readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
}

View File

@@ -0,0 +1,635 @@
"""Runway API Nodes
API Docs:
- https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete
User Guides:
- https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
- https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
- https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
- https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3
"""
from typing import Union, Optional, Any
from enum import Enum
import torch
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
RunwayTaskStatusResponse as TaskStatusResponse,
RunwayTaskStatusEnum as TaskStatus,
RunwayModelEnum as Model,
RunwayDurationEnum as Duration,
RunwayAspectRatioEnum as AspectRatio,
RunwayPromptImageObject,
RunwayPromptImageDetailedObject,
RunwayTextToImageRequest,
RunwayTextToImageResponse,
Model4,
ReferenceImage,
RunwayTextToImageAspectRatioEnum,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_video_output,
image_tensor_pair_to_batch,
validate_string,
download_url_to_image_tensor,
)
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy_api.input_impl import VideoFromFile
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
AVERAGE_DURATION_I2V_SECONDS = 64
AVERAGE_DURATION_FLF_SECONDS = 256
AVERAGE_DURATION_T2I_SECONDS = 41
class RunwayApiError(Exception):
"""Base exception for Runway API errors."""
pass
class RunwayGen4TurboAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
field_1280_720 = "1280:720"
field_720_1280 = "720:1280"
field_1104_832 = "1104:832"
field_832_1104 = "832:1104"
field_960_960 = "960:960"
field_1584_672 = "1584:672"
class RunwayGen3aAspectRatio(str, Enum):
"""Aspect ratios supported for Image to Video API when using gen3a_turbo model."""
field_768_1280 = "768:1280"
field_1280_768 = "1280:768"
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the video URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
return response.output[0]
return None
# TODO: replace with updated image validation utils (upstream)
def validate_input_image(image: torch.Tensor) -> bool:
"""
Validate the input image is within the size limits for the Runway API.
See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
"""
return image.shape[2] < 8000 and image.shape[1] < 8000
def poll_until_finished(
auth_kwargs: dict[str, str],
api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
estimated_duration: Optional[int] = None,
node_id: Optional[str] = None,
) -> TaskStatusResponse:
"""Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
return PollingOperation(
poll_endpoint=api_endpoint,
completed_statuses=[
TaskStatus.SUCCEEDED.value,
],
failed_statuses=[
TaskStatus.FAILED.value,
TaskStatus.CANCELLED.value,
],
status_extractor=lambda response: (response.status.value),
auth_kwargs=auth_kwargs,
result_url_extractor=get_video_url_from_task_status,
estimated_duration=estimated_duration,
node_id=node_id,
progress_extractor=extract_progress_from_task_status,
).execute()
def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
"""Returns the image URL from the task status response if it exists."""
if response.output and len(response.output) > 0:
return response.output[0]
return None
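extract_progress_from_task_status scales Runway's 0-1 progress fraction to the 0-100 range the progress display expects, and returns None when the field is absent. A quick standalone check, with SimpleNamespace standing in for TaskStatusResponse:

from types import SimpleNamespace

print(extract_progress_from_task_status(SimpleNamespace(progress=0.5)))   # 50.0
print(extract_progress_from_task_status(SimpleNamespace(progress=None)))  # None
print(extract_progress_from_task_status(SimpleNamespace()))               # None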
class RunwayVideoGenNode(ComfyNodeABC):
"""Runway Video Node Base."""
RETURN_TYPES = ("VIDEO",)
FUNCTION = "api_call"
CATEGORY = "api node/video/Runway"
API_NODE = True
def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no video data found in response."
)
return True
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
def generate_video(
self,
request: RunwayImageToVideoRequest,
auth_kwargs: dict[str, str],
node_id: Optional[str] = None,
) -> tuple[VideoFromFile]:
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_IMAGE_TO_VIDEO,
method=HttpMethod.POST,
request_model=RunwayImageToVideoRequest,
response_model=RunwayImageToVideoResponse,
),
request=request,
auth_kwargs=auth_kwargs,
)
initial_response = initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
final_response = self.get_response(task_id, auth_kwargs, node_id)
self.validate_response(final_response)
video_url = get_video_url_from_task_status(final_response)
return (download_url_to_video_output(video_url),)
class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen3a Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
"""Runway Image to Video Node using Gen4 Turbo model."""
DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen4TurboAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
# Upload image
download_urls = upload_images_to_comfyapi(
start_frame,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen4_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
)
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayFirstLastFrameNode(RunwayVideoGenNode):
"""Runway First-Last Frame Node."""
DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> RunwayImageToVideoResponse:
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
node_id=node_id,
)
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
),
"start_frame": (
IO.IMAGE,
{"tooltip": "Start frame to be used for the video"},
),
"end_frame": (
IO.IMAGE,
{
"tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
},
),
"duration": model_field_to_node_input(
IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayImageToVideoRequest,
"ratio",
enum_type=RunwayGen3aAspectRatio,
),
"seed": model_field_to_node_input(
IO.INT,
RunwayImageToVideoRequest,
"seed",
control_after_generate=True,
),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"unique_id": "UNIQUE_ID",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
def api_call(
self,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
duration: str,
ratio: str,
seed: int,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[VideoFromFile]:
# Validate inputs
validate_string(prompt, min_length=1)
validate_input_image(start_frame)
validate_input_image(end_frame)
# Upload images
stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
download_urls = upload_images_to_comfyapi(
stacked_input_images,
max_images=2,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 2:
raise RunwayApiError("Failed to upload one or more images to comfy api.")
return self.generate_video(
RunwayImageToVideoRequest(
promptText=prompt,
seed=seed,
model=Model("gen3a_turbo"),
duration=Duration(duration),
ratio=AspectRatio(ratio),
promptImage=RunwayPromptImageObject(
root=[
RunwayPromptImageDetailedObject(
uri=str(download_urls[0]), position="first"
),
RunwayPromptImageDetailedObject(
uri=str(download_urls[1]), position="last"
),
]
),
),
auth_kwargs=kwargs,
node_id=unique_id,
)
class RunwayTextToImageNode(ComfyNodeABC):
"""Runway Text to Image Node."""
RETURN_TYPES = ("IMAGE",)
FUNCTION = "api_call"
CATEGORY = "api node/image/Runway"
API_NODE = True
DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": model_field_to_node_input(
IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
),
"ratio": model_field_to_node_input(
IO.COMBO,
RunwayTextToImageRequest,
"ratio",
enum_type=RunwayTextToImageAspectRatioEnum,
),
},
"optional": {
"reference_image": (
IO.IMAGE,
{"tooltip": "Optional reference image to guide the generation"},
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
"""
Validate the task creation response from the Runway API matches
expected format.
"""
if not bool(response.id):
raise RunwayApiError("Invalid initial response from Runway API.")
return True
def validate_response(self, response: TaskStatusResponse) -> bool:
"""
Validate the successful task status response from the Runway API
matches expected format.
"""
if not response.output or len(response.output) == 0:
raise RunwayApiError(
"Runway task succeeded but no image data found in response."
)
return True
def get_response(
self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return poll_until_finished(
auth_kwargs,
ApiEndpoint(
path=f"{PATH_GET_TASK_STATUS}/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TaskStatusResponse,
),
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
node_id=node_id,
)
def api_call(
self,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
unique_id: Optional[str] = None,
**kwargs,
) -> tuple[torch.Tensor]:
# Validate inputs
validate_string(prompt, min_length=1)
# Prepare reference images if provided
reference_images = None
if reference_image is not None:
validate_input_image(reference_image)
download_urls = upload_images_to_comfyapi(
reference_image,
max_images=1,
mime_type="image/png",
auth_kwargs=kwargs,
)
if len(download_urls) != 1:
raise RunwayApiError("Failed to upload reference image to comfy api.")
reference_images = [ReferenceImage(uri=str(download_urls[0]))]
# Create request
request = RunwayTextToImageRequest(
promptText=prompt,
model=Model4.gen4_image,
ratio=ratio,
referenceImages=reference_images,
)
# Execute initial request
initial_operation = SynchronousOperation(
endpoint=ApiEndpoint(
path=PATH_TEXT_TO_IMAGE,
method=HttpMethod.POST,
request_model=RunwayTextToImageRequest,
response_model=RunwayTextToImageResponse,
),
request=request,
auth_kwargs=kwargs,
)
initial_response = initial_operation.execute()
self.validate_task_created(initial_response)
task_id = initial_response.id
# Poll for completion
final_response = self.get_response(
task_id, auth_kwargs=kwargs, node_id=unique_id
)
self.validate_response(final_response)
# Download and return image
image_url = get_image_url_from_task_status(final_response)
return (download_url_to_image_tensor(image_url),)
NODE_CLASS_MAPPINGS = {
"RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
"RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
"RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
"RunwayTextToImageNode": RunwayTextToImageNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
"RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
"RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
"RunwayTextToImageNode": "Runway Text to Image",
}
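The Runway helpers above define a simple polling contract: a task is terminal once its status is SUCCEEDED, FAILED, or CANCELLED, progress is reported as a 0-1 fraction scaled to a percentage, and the result URL is taken from the first entry of the response's output list. A rough sketch of that contract follows, with a hypothetical fetch_status() callable standing in for the PollingOperation plumbing; it is illustrative only.

import time

def wait_for_runway_task(fetch_status, poll_interval=2.0):
    # fetch_status is a hypothetical callable returning a TaskStatusResponse-like object
    while True:
        status = fetch_status()
        progress = extract_progress_from_task_status(status)
        if progress is not None:
            print(f"progress: {progress:.0f}%")
        if status.status == TaskStatus.SUCCEEDED:
            return get_video_url_from_task_status(status)
        if status.status in (TaskStatus.FAILED, TaskStatus.CANCELLED):
            raise RunwayApiError(f"Task ended in state {status.status.value}")
        time.sleep(poll_interval)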


@@ -0,0 +1,574 @@
import os
from folder_paths import get_output_directory
from comfy_api_nodes.mapper_utils import model_field_to_node_input
from comfy.comfy_types.node_typing import IO
from comfy_api_nodes.apis import (
TripoOrientation,
TripoModelVersion,
)
from comfy_api_nodes.apis.tripo_api import (
TripoTaskType,
TripoStyle,
TripoFileReference,
TripoFileEmptyReference,
TripoUrlReference,
TripoTaskResponse,
TripoTaskStatus,
TripoTextToModelRequest,
TripoImageToModelRequest,
TripoMultiviewToModelRequest,
TripoTextureModelRequest,
TripoRefineModelRequest,
TripoAnimateRigRequest,
TripoAnimateRetargetRequest,
TripoConvertModelRequest,
)
from comfy_api_nodes.apis.client import (
ApiEndpoint,
HttpMethod,
SynchronousOperation,
PollingOperation,
EmptyRequest,
)
from comfy_api_nodes.apinode_utils import (
upload_images_to_comfyapi,
download_url_to_bytesio,
)
def upload_image_to_tripo(image, **kwargs):
urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
def get_model_url_from_response(response: TripoTaskResponse) -> str:
if response.data is not None:
for key in ["pbr_model", "model", "base_model"]:
if getattr(response.data.output, key, None) is not None:
return getattr(response.data.output, key)
raise RuntimeError(f"Failed to get model url from response: {response}")
def poll_until_finished(
kwargs: dict[str, str],
response: TripoTaskResponse,
) -> tuple[str, str]:
"""Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
if response.code != 0:
raise RuntimeError(f"Failed to generate mesh: {response.error}")
task_id = response.data.task_id
response_poll = PollingOperation(
poll_endpoint=ApiEndpoint(
path=f"/proxy/tripo/v2/openapi/task/{task_id}",
method=HttpMethod.GET,
request_model=EmptyRequest,
response_model=TripoTaskResponse,
),
completed_statuses=[TripoTaskStatus.SUCCESS],
failed_statuses=[
TripoTaskStatus.FAILED,
TripoTaskStatus.CANCELLED,
TripoTaskStatus.UNKNOWN,
TripoTaskStatus.BANNED,
TripoTaskStatus.EXPIRED,
],
status_extractor=lambda x: x.data.status,
auth_kwargs=kwargs,
node_id=kwargs["unique_id"],
result_url_extractor=get_model_url_from_response,
progress_extractor=lambda x: x.data.progress,
).execute()
if response_poll.data.status == TripoTaskStatus.SUCCESS:
url = get_model_url_from_response(response_poll)
bytesio = download_url_to_bytesio(url)
# Save the downloaded model file
model_file = f"tripo_model_{task_id}.glb"
with open(os.path.join(get_output_directory(), model_file), "wb") as f:
f.write(bytesio.getvalue())
return model_file, task_id
raise RuntimeError(f"Failed to generate mesh: {response_poll}")
class TripoTextToModelNode:
"""
Generates 3D models synchronously based on a text prompt using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"prompt": ("STRING", {"multiline": True}),
},
"optional": {
"negative_prompt": ("STRING", {"multiline": True}),
"model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"image_seed": ("INT", {"default": 42}),
"model_seed": ("INT", {"default": 42}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if not prompt:
raise RuntimeError("Prompt is required")
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoTextToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoTextToModelRequest(
type=TripoTaskType.TEXT_TO_MODEL,
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
image_seed=image_seed,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoImageToModelNode:
"""
Generates 3D models synchronously based on a single image using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
"style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"model_seed": ("INT", {"default": 42}),
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
style_enum = None if style == "None" else style
if image is None:
raise RuntimeError("Image is required")
tripo_file = upload_image_to_tripo(image, **kwargs)
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoImageToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoImageToModelRequest(
type=TripoTaskType.IMAGE_TO_MODEL,
file=tripo_file,
model_version=model_version,
style=style_enum,
texture=texture,
pbr=pbr,
model_seed=model_seed,
orientation=orientation,
texture_alignment=texture_alignment,
texture_seed=texture_seed,
texture_quality=texture_quality,
face_limit=face_limit,
auto_size=True,
quad=quad
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoMultiviewToModelNode:
"""
Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
"""
AVERAGE_DURATION = 80
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
"optional": {
"image_left": ("IMAGE",),
"image_back": ("IMAGE",),
"image_right": ("IMAGE",),
"model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
"orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"model_seed": ("INT", {"default": 42}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"quad": ("BOOLEAN", {"default": False})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
if image is None:
raise RuntimeError("front image for multiview is required")
images = []
image_dict = {
"image": image,
"image_left": image_left,
"image_back": image_back,
"image_right": image_right
}
if image_left is None and image_back is None and image_right is None:
raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
for image_name in ["image", "image_left", "image_back", "image_right"]:
image_ = image_dict[image_name]
if image_ is not None:
tripo_file = upload_image_to_tripo(image_, **kwargs)
images.append(tripo_file)
else:
images.append(TripoFileEmptyReference())
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoMultiviewToModelRequest,
response_model=TripoTaskResponse,
),
request=TripoMultiviewToModelRequest(
type=TripoTaskType.MULTIVIEW_TO_MODEL,
files=images,
model_version=model_version,
orientation=orientation,
texture=texture,
pbr=pbr,
model_seed=model_seed,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment,
face_limit=face_limit,
quad=quad,
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoTextureNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_task_id": ("MODEL_TASK_ID",),
},
"optional": {
"texture": ("BOOLEAN", {"default": True}),
"pbr": ("BOOLEAN", {"default": True}),
"texture_seed": ("INT", {"default": 42}),
"texture_quality": (["standard", "detailed"], {"default": "standard"}),
"texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 80
def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoTextureModelRequest,
response_model=TripoTaskResponse,
),
request=TripoTextureModelRequest(
original_model_task_id=model_task_id,
texture=texture,
pbr=pbr,
texture_seed=texture_seed,
texture_quality=texture_quality,
texture_alignment=texture_alignment
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRefineNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model_task_id": ("MODEL_TASK_ID", {
"tooltip": "Must be a v1.4 Tripo model"
}),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
RETURN_NAMES = ("model_file", "model task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 240
def generate_mesh(self, model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoRefineModelRequest,
response_model=TripoTaskResponse,
),
request=TripoRefineModelRequest(
draft_model_task_id=model_task_id
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRigNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("MODEL_TASK_ID",),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "RIG_TASK_ID")
RETURN_NAMES = ("model_file", "rig task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 180
def generate_mesh(self, original_model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoAnimateRigRequest,
response_model=TripoTaskResponse,
),
request=TripoAnimateRigRequest(
original_model_task_id=original_model_task_id,
out_format="glb",
spec="tripo"
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoRetargetNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("RIG_TASK_ID",),
"animation": ([
"preset:idle",
"preset:walk",
"preset:climb",
"preset:jump",
"preset:slash",
"preset:shoot",
"preset:hurt",
"preset:fall",
"preset:turn",
],),
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
RETURN_NAMES = ("model_file", "retarget task_id")
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, animation, original_model_task_id, **kwargs):
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoAnimateRetargetRequest,
response_model=TripoTaskResponse,
),
request=TripoAnimateRetargetRequest(
original_model_task_id=original_model_task_id,
animation=animation,
out_format="glb",
bake_animation=True
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
class TripoConversionNode:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
"format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
},
"optional": {
"quad": ("BOOLEAN", {"default": False}),
"face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
"texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
"texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@classmethod
def VALIDATE_INPUTS(cls, input_types):
# Only the socket type of original_model_task_id is checked here; the other
# inputs are still validated normally because they are not taken as arguments.
if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
return True
RETURN_TYPES = ()
FUNCTION = "generate_mesh"
CATEGORY = "api node/3d/Tripo"
API_NODE = True
OUTPUT_NODE = True
AVERAGE_DURATION = 30
def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
if not original_model_task_id:
raise RuntimeError("original_model_task_id is required")
response = SynchronousOperation(
endpoint=ApiEndpoint(
path="/proxy/tripo/v2/openapi/task",
method=HttpMethod.POST,
request_model=TripoConvertModelRequest,
response_model=TripoTaskResponse,
),
request=TripoConvertModelRequest(
original_model_task_id=original_model_task_id,
format=format,
quad=quad if quad else None,
face_limit=face_limit if face_limit != -1 else None,
texture_size=texture_size if texture_size != 4096 else None,
texture_format=texture_format if texture_format != "JPEG" else None
),
auth_kwargs=kwargs,
).execute()
return poll_until_finished(kwargs, response)
NODE_CLASS_MAPPINGS = {
"TripoTextToModelNode": TripoTextToModelNode,
"TripoImageToModelNode": TripoImageToModelNode,
"TripoMultiviewToModelNode": TripoMultiviewToModelNode,
"TripoTextureNode": TripoTextureNode,
"TripoRefineNode": TripoRefineNode,
"TripoRigNode": TripoRigNode,
"TripoRetargetNode": TripoRetargetNode,
"TripoConversionNode": TripoConversionNode,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"TripoTextToModelNode": "Tripo: Text to Model",
"TripoImageToModelNode": "Tripo: Image to Model",
"TripoMultiviewToModelNode": "Tripo: Multiview to Model",
"TripoTextureNode": "Tripo: Texture model",
"TripoRefineNode": "Tripo: Refine Draft model",
"TripoRigNode": "Tripo: Rig model",
"TripoRetargetNode": "Tripo: Retarget rigged model",
"TripoConversionNode": "Tripo: Convert model",
}
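Because every Tripo node returns the task id of the job it creates, the nodes compose by feeding one task id into the next (model task id into the rig node, rig task id into the retarget node). A minimal sketch of that chaining, assuming placeholder credentials supplied the same way the hidden inputs would be:

creds = {
    "auth_token": "<comfy auth token>",   # placeholder
    "comfy_api_key": "<comfy api key>",   # placeholder
    "unique_id": "0",                     # used for progress reporting
}

# text -> base mesh, then rig it, then retarget a preset animation onto the rig
model_file, model_task_id = TripoTextToModelNode().generate_mesh(
    prompt="a low-poly wooden chair", texture=True, pbr=True, **creds)
rigged_file, rig_task_id = TripoRigNode().generate_mesh(model_task_id, **creds)
animated_file, _ = TripoRetargetNode().generate_mesh("preset:walk", rig_task_id, **creds)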


@@ -3,6 +3,7 @@ import logging
import base64
import requests
import torch
from typing import Optional
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api.input_impl.video_types import VideoFromFile
@@ -24,6 +25,8 @@ from comfy_api_nodes.apinode_utils import (
tensor_to_base64_string
)
AVERAGE_DURATION_VIDEO_GEN = 32
def convert_image_to_base64(image: torch.Tensor):
if image is None:
return None
@@ -31,6 +34,22 @@ def convert_image_to_base64(image: torch.Tensor):
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
return tensor_to_base64_string(scaled_image)
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
if (
poll_response.response
and hasattr(poll_response.response, "videos")
and poll_response.response.videos
and len(poll_response.response.videos) > 0
):
video = poll_response.response.videos[0]
else:
return None
if hasattr(video, "gcsUri") and video.gcsUri:
return str(video.gcsUri)
return None
class VeoVideoGenerationNode(ComfyNodeABC):
"""
Generates videos from text prompts using Google's Veo API.
@@ -115,6 +134,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}
@@ -134,6 +154,7 @@ class VeoVideoGenerationNode(ComfyNodeABC):
person_generation="ALLOW",
seed=0,
image=None,
unique_id: Optional[str] = None,
**kwargs,
):
# Prepare the instances for the request
@@ -215,7 +236,10 @@ class VeoVideoGenerationNode(ComfyNodeABC):
operationName=operation_name
),
auth_kwargs=kwargs,
poll_interval=5.0,
result_url_extractor=get_video_url_from_response,
node_id=unique_id,
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
)
# Execute the polling operation


@@ -0,0 +1,100 @@
import logging
from typing import Optional
import torch
from comfy_api.input.video_types import VideoInput
def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
if len(image.shape) == 4:
return image.shape[1], image.shape[2]
elif len(image.shape) == 3:
return image.shape[0], image.shape[1]
else:
raise ValueError("Invalid image tensor shape.")
def validate_image_dimensions(
image: torch.Tensor,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
max_height: Optional[int] = None,
):
height, width = get_image_dimensions(image)
if min_width is not None and width < min_width:
raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
if max_width is not None and width > max_width:
raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Image height must be at least {min_height}px, got {height}px"
)
if max_height is not None and height > max_height:
raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
def validate_image_aspect_ratio(
image: torch.Tensor,
min_aspect_ratio: Optional[float] = None,
max_aspect_ratio: Optional[float] = None,
):
height, width = get_image_dimensions(image)
aspect_ratio = width / height
if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
)
if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
raise ValueError(
f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
)
def validate_video_dimensions(
video: VideoInput,
min_width: Optional[int] = None,
max_width: Optional[int] = None,
min_height: Optional[int] = None,
max_height: Optional[int] = None,
):
try:
width, height = video.get_dimensions()
except Exception as e:
logging.error("Error getting dimensions of video: %s", e)
return
if min_width is not None and width < min_width:
raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
if max_width is not None and width > max_width:
raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
if min_height is not None and height < min_height:
raise ValueError(
f"Video height must be at least {min_height}px, got {height}px"
)
if max_height is not None and height > max_height:
raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
def validate_video_duration(
video: VideoInput,
min_duration: Optional[float] = None,
max_duration: Optional[float] = None,
):
try:
duration = video.get_duration()
except Exception as e:
logging.error("Error getting duration of video: %s", e)
return
epsilon = 0.0001
if min_duration is not None and min_duration - epsilon > duration:
raise ValueError(
f"Video duration must be at least {min_duration}s, got {duration}s"
)
if max_duration is not None and duration > max_duration + epsilon:
raise ValueError(
f"Video duration must be at most {max_duration}s, got {duration}s"
)
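These helpers raise ValueError as soon as an input is out of range, which lets API nodes fail fast before spending an upload or a paid generation. Illustrative usage with a 1280x720 frame (batch-first [B, H, W, C] layout, as assumed by get_image_dimensions):

import torch

frame = torch.zeros(1, 720, 1280, 3)  # B, H, W, C
validate_image_dimensions(frame, min_width=640, max_width=4096, min_height=360, max_height=2160)
validate_image_aspect_ratio(frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0)  # 1280/720 ~ 1.78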

comfy_extras/nodes_apg.py

@@ -0,0 +1,76 @@
import torch
def project(v0, v1):
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
return v0_parallel, v0_orthogonal
class APG:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
"norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize guidance vector to this value, normalization disable at a setting of 0."}),
"momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "sampling/custom_sampling"
def patch(self, model, eta, norm_threshold, momentum):
running_avg = 0
prev_sigma = None
def pre_cfg_function(args):
nonlocal running_avg, prev_sigma
if len(args["conds_out"]) == 1: return args["conds_out"]
cond = args["conds_out"][0]
uncond = args["conds_out"][1]
sigma = args["sigma"][0]
cond_scale = args["cond_scale"]
if prev_sigma is not None and sigma > prev_sigma:
running_avg = 0
prev_sigma = sigma
guidance = cond - uncond
if momentum != 0:
if not torch.is_tensor(running_avg):
running_avg = guidance
else:
running_avg = momentum * running_avg + guidance
guidance = running_avg
if norm_threshold > 0:
guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale = torch.minimum(
torch.ones_like(guidance_norm),
norm_threshold / guidance_norm
)
guidance = guidance * scale
guidance_parallel, guidance_orthogonal = project(guidance, cond)
modified_guidance = guidance_orthogonal + eta * guidance_parallel
modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
return [modified_cond, uncond] + args["conds_out"][2:]
m = model.clone()
m.set_model_sampler_pre_cfg_function(pre_cfg_function)
return (m,)
NODE_CLASS_MAPPINGS = {
"APG": APG,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"APG": "Adaptive Projected Guidance",
}
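The core of APG is the decomposition performed by project(): the guidance vector is split into a component parallel to the (normalized) conditional prediction and an orthogonal remainder, and eta rescales only the parallel part. A small numeric check of that decomposition, using the project() helper defined above with arbitrary example tensors:

import torch

cond = torch.randn(1, 4, 8, 8)
uncond = torch.randn(1, 4, 8, 8)
guidance = cond - uncond

parallel, orthogonal = project(guidance, cond)
# The two components sum back to the original guidance vector.
assert torch.allclose(parallel + orthogonal, guidance, atol=1e-5)

eta = 0.5
modified_guidance = orthogonal + eta * parallel  # eta only scales the parallel part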


@@ -1,5 +1,6 @@
from __future__ import annotations
import av
import torchaudio
import torch
import comfy.model_management
@@ -7,7 +8,6 @@ import folder_paths
import os
import io
import json
import struct
import random
import hashlib
import node_helpers
@@ -90,60 +90,118 @@ class VAEDecodeAudio:
return ({"waveform": audio, "sample_rate": 44100}, )
def create_vorbis_comment_block(comment_dict, last_block):
vendor_string = b'ComfyUI'
vendor_length = len(vendor_string)
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
comments = []
for key, value in comment_dict.items():
comment = f"{key}={value}".encode('utf-8')
comments.append(struct.pack('<I', len(comment)) + comment)
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
user_comment_list_length = len(comments)
user_comments = b''.join(comments)
# Prepare metadata dictionary
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
if last_block:
id = b'\x84'
else:
id = b'\x04'
comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
# Opus supported sample rates
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
return comment_block
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
output_path = os.path.join(full_output_folder, file)
def insert_or_replace_vorbis_comment(flac_io, comment_dict):
if len(comment_dict) == 0:
return flac_io
# Use original sample rate initially
sample_rate = audio["sample_rate"]
flac_io.seek(4)
# Handle Opus sample rate requirements
if format == "opus":
if sample_rate > 48000:
sample_rate = 48000
elif sample_rate not in OPUS_RATES:
# Find the next highest supported rate
for rate in sorted(OPUS_RATES):
if rate > sample_rate:
sample_rate = rate
break
if sample_rate not in OPUS_RATES: # Fallback if still not supported
sample_rate = 48000
blocks = []
last_block = False
# Resample if necessary
if sample_rate != audio["sample_rate"]:
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
while not last_block:
header = flac_io.read(4)
last_block = (header[0] & 0x80) != 0
block_type = header[0] & 0x7F
block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
block_data = flac_io.read(block_length)
# Create in-memory WAV buffer
wav_buffer = io.BytesIO()
torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
wav_buffer.seek(0) # Rewind for reading
if block_type == 4 or block_type == 1:
pass
else:
header = bytes([(header[0] & (~0x80))]) + header[1:]
blocks.append(header + block_data)
# Use PyAV to convert and add metadata
input_container = av.open(wav_buffer)
blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
# Create output with specified format
output_buffer = io.BytesIO()
output_container = av.open(output_buffer, mode='w', format=format)
new_flac_io = io.BytesIO()
new_flac_io.write(b'fLaC')
for block in blocks:
new_flac_io.write(block)
# Set metadata on the container
for key, value in metadata.items():
output_container.metadata[key] = value
new_flac_io.write(flac_io.read())
return new_flac_io
# Set up the output stream with appropriate properties
input_container.streams.audio[0]
if format == "opus":
out_stream = output_container.add_stream("libopus", rate=sample_rate)
if quality == "64k":
out_stream.bit_rate = 64000
elif quality == "96k":
out_stream.bit_rate = 96000
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "192k":
out_stream.bit_rate = 192000
elif quality == "320k":
out_stream.bit_rate = 320000
elif format == "mp3":
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
if quality == "V0":
# TODO: it would be nice to support V3 and V5 as well, but there doesn't seem to be a way to set the qscale level; the property below is a bool
out_stream.codec_context.qscale = 1
elif quality == "128k":
out_stream.bit_rate = 128000
elif quality == "320k":
out_stream.bit_rate = 320000
else: #format == "flac":
out_stream = output_container.add_stream("flac", rate=sample_rate)
# Copy frames from input to output
for frame in input_container.decode(audio=0):
frame.pts = None # Let PyAV handle timestamps
output_container.mux(out_stream.encode(frame))
# Flush encoder
output_container.mux(out_stream.encode(None))
# Close containers
output_container.close()
input_container.close()
# Write the output to file
output_buffer.seek(0)
with open(output_path, 'wb') as f:
f.write(output_buffer.getbuffer())
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
return { "ui": { "audio": results } }
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -153,50 +211,70 @@ class SaveAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_audio"
FUNCTION = "save_flac"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
results: list[FileLocator] = []
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
metadata = {}
if not args.disable_metadata:
if prompt is not None:
metadata["prompt"] = json.dumps(prompt)
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata[x] = json.dumps(extra_pnginfo[x])
class SaveAudioMP3:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.flac"
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
buff = io.BytesIO()
torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
RETURN_TYPES = ()
FUNCTION = "save_mp3"
buff = insert_or_replace_vorbis_comment(buff, metadata)
OUTPUT_NODE = True
with open(os.path.join(full_output_folder, file), 'wb') as f:
f.write(buff.getbuffer())
CATEGORY = "audio"
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
counter += 1
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
return { "ui": { "audio": results } }
class SaveAudioOpus:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
self.type = "output"
self.prefix_append = ""
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_opus"
OUTPUT_NODE = True
CATEGORY = "audio"
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="128k"):
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
@@ -248,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
"SaveAudioMP3": SaveAudioMP3,
"SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"EmptyLatentAudio": "Empty Latent Audio",
"VAEEncodeAudio": "VAE Encode Audio",
"VAEDecodeAudio": "VAE Decode Audio",
"PreviewAudio": "Preview Audio",
"LoadAudio": "Load Audio",
"SaveAudio": "Save Audio (FLAC)",
"SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)",
}


@@ -0,0 +1,218 @@
import nodes
import torch
import numpy as np
from einops import rearrange
import comfy.model_management
MAX_RESOLUTION = nodes.MAX_RESOLUTION
CAMERA_DICT = {
"base_T_norm": 1.5,
"base_angle": np.pi/3,
"Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
"Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
"Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
"Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
"Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
"Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
"Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
"Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
"ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
}
def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
def get_relative_pose(cam_params):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
cam_to_origin = 0
target_cam_c2w = np.array([
[1, 0, 0, 0],
[0, 1, 0, -cam_to_origin],
[0, 0, 1, 0],
[0, 0, 0, 1]
])
abs2rel = target_cam_c2w @ abs_w2cs[0]
ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
ret_poses = np.array(ret_poses, dtype=np.float32)
return ret_poses
"""Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
cam_params = [Camera(cam_param) for cam_param in cam_params]
sample_wh_ratio = width / height
pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
if pose_wh_ratio > sample_wh_ratio:
resized_ori_w = height * pose_wh_ratio
for cam_param in cam_params:
cam_param.fx = resized_ori_w * cam_param.fx / width
else:
resized_ori_h = width / pose_wh_ratio
for cam_param in cam_params:
cam_param.fy = resized_ori_h * cam_param.fy / height
intrinsic = np.asarray([[cam_param.fx * width,
cam_param.fy * height,
cam_param.cx * width,
cam_param.cy * height]
for cam_param in cam_params], dtype=np.float32)
K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
c2ws = get_relative_pose(cam_params)  # defined above
c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
plucker_embedding = plucker_embedding[None]
plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
return plucker_embedding
class Camera(object):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
def __init__(self, entry):
fx, fy, cx, cy = entry[1:5]
self.fx = fx
self.fy = fy
self.cx = cx
self.cy = cy
c2w_mat = np.array(entry[7:]).reshape(4, 4)
self.c2w_mat = c2w_mat
self.w2c_mat = np.linalg.inv(c2w_mat)
def ray_condition(K, c2w, H, W, device):
"""Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
"""
# c2w: B, V, 4, 4
# K: B, V, 4
B = K.shape[0]
j, i = torch.meshgrid(
torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
indexing='ij'
)
i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
zs = torch.ones_like(i) # [B, HxW]
xs = (i - cx) / fx * zs
ys = (j - cy) / fy * zs
zs = zs.expand_as(ys)
directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
rays_o = c2w[..., :3, 3] # B, V, 3
rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
# c2w @ directions
rays_dxo = torch.cross(rays_o, rays_d)
plucker = torch.cat([rays_dxo, rays_d], dim=-1)
plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
# plucker = plucker.permute(0, 1, 4, 2, 3)
return plucker
def get_camera_motion(angle, T, speed, n=81):
def compute_R_form_rad_angle(angles):
theta_x, theta_y, theta_z = angles
Rx = np.array([[1, 0, 0],
[0, np.cos(theta_x), -np.sin(theta_x)],
[0, np.sin(theta_x), np.cos(theta_x)]])
Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
[0, 1, 0],
[-np.sin(theta_y), 0, np.cos(theta_y)]])
Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
[np.sin(theta_z), np.cos(theta_z), 0],
[0, 0, 1]])
R = np.dot(Rz, np.dot(Ry, Rx))
return R
RT = []
for i in range(n):
_angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
R = compute_R_form_rad_angle(_angle)
_T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
_RT = np.concatenate([R,_T], axis=1)
RT.append(_RT)
RT = np.stack(RT)
return RT
class WanCameraEmbedding:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
"width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
},
"optional":{
"speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
"fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
"cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
"cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
}
}
RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
RETURN_NAMES = ("camera_embedding","width","height","length")
FUNCTION = "run"
CATEGORY = "camera"
def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
"""
Use the camera trajectory as extrinsic parameters to calculate Plücker embeddings (Sitzmann et al., 2021).
Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
"""
motion_list = [camera_pose]
speed = speed
angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
T = np.array(CAMERA_DICT[motion_list[0]]["T"])
RT = get_camera_motion(angle, T, speed, length)
trajs=[]
for cp in RT.tolist():
traj=[fx,fy,cx,cy,0,0]
traj.extend(cp[0])
traj.extend(cp[1])
traj.extend(cp[2])
traj.extend([0,0,0,1])
trajs.append(traj)
cam_params = np.array([[float(x) for x in pose] for pose in trajs])
cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
control_camera_video = process_pose_params(cam_params, width=width, height=height)
control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
control_camera_video = torch.concat(
[
torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
control_camera_video[:, :, 1:]
], dim=2
).transpose(1, 2)
# Reshape, transpose, and view into desired shape
b, f, c, h, w = control_camera_video.shape
control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
return (control_camera_video, width, height, length)
NODE_CLASS_MAPPINGS = {
"WanCameraEmbedding": WanCameraEmbedding,
}
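ray_condition() encodes every pixel's camera ray as a 6-dimensional Plücker vector: the cross product of the ray origin with the ray direction, followed by the direction itself. A single-ray numpy sketch of the same encoding, with hypothetical values chosen purely for illustration:

import numpy as np

o = np.array([0.0, 0.2, -1.5])         # ray origin (camera center in world space)
d = np.array([0.1, -0.2, 1.0])
d = d / np.linalg.norm(d)              # normalized ray direction
plucker = np.concatenate([np.cross(o, d), d])  # (o x d, d) -> 6 values per pixel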


@@ -31,6 +31,7 @@ class T5TokenizerOptions:
}
}
CATEGORY = "_for_testing/conditioning"
RETURN_TYPES = ("CLIP",)
FUNCTION = "set_options"


@@ -77,7 +77,7 @@ class HunyuanImageToVideo:
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
"guidance_type": (["v1 (concat)", "v2 (replace)"], )
"guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
},
"optional": {"start_image": ("IMAGE", ),
}}
@@ -101,10 +101,12 @@ class HunyuanImageToVideo:
if guidance_type == "v1 (concat)":
cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
else:
elif guidance_type == "v2 (replace)":
cond = {'guiding_frame_index': 0}
latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
out_latent["noise_mask"] = mask
elif guidance_type == "custom":
cond = {"ref_latent": concat_latent_image}
positive = node_helpers.conditioning_set_values(positive, cond)


@@ -13,6 +13,8 @@ import os
import re
from io import BytesIO
from inspect import cleandoc
import torch
import comfy.utils
from comfy.comfy_types import FileLocator
@@ -74,6 +76,24 @@ class ImageFromBatch:
s = s_in[batch_index:batch_index + length].clone()
return (s,)
class ImageAddNoise:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True, "tooltip": "The random seed used for creating the noise."}),
"strength": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image"
def repeat(self, image, seed, strength):
generator = torch.manual_seed(seed)
s = torch.clip((image + strength * torch.randn(image.size(), generator=generator, device="cpu").to(image)), min=0.0, max=1.0)
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -210,6 +230,186 @@ class SVG:
all_svgs_list.extend(svg_item.data)
return SVG(all_svgs_list)
class ImageStitch:
"""Upstreamed from https://github.com/kijai/ComfyUI-KJNodes"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image1": ("IMAGE",),
"direction": (["right", "down", "left", "up"], {"default": "right"}),
"match_image_size": ("BOOLEAN", {"default": True}),
"spacing_width": (
"INT",
{"default": 0, "min": 0, "max": 1024, "step": 2},
),
"spacing_color": (
["white", "black", "red", "green", "blue"],
{"default": "white"},
),
},
"optional": {
"image2": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stitch"
CATEGORY = "image/transform"
DESCRIPTION = """
Stitches image2 to image1 in the specified direction.
If image2 is not provided, returns image1 unchanged.
Optional spacing can be added between images.
"""
def stitch(
self,
image1,
direction,
match_image_size,
spacing_width,
spacing_color,
image2=None,
):
if image2 is None:
return (image1,)
# Handle batch size differences
if image1.shape[0] != image2.shape[0]:
max_batch = max(image1.shape[0], image2.shape[0])
if image1.shape[0] < max_batch:
image1 = torch.cat(
[image1, image1[-1:].repeat(max_batch - image1.shape[0], 1, 1, 1)]
)
if image2.shape[0] < max_batch:
image2 = torch.cat(
[image2, image2[-1:].repeat(max_batch - image2.shape[0], 1, 1, 1)]
)
# Match image sizes if requested
if match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
aspect_ratio = w2 / h2
if direction in ["left", "right"]:
target_h, target_w = h1, int(h1 * aspect_ratio)
else: # up, down
target_w, target_h = w1, int(w1 / aspect_ratio)
image2 = comfy.utils.common_upscale(
image2.movedim(-1, 1), target_w, target_h, "lanczos", "disabled"
).movedim(1, -1)
# When not matching sizes, pad to align non-concat dimensions
if not match_image_size:
h1, w1 = image1.shape[1:3]
h2, w2 = image2.shape[1:3]
if direction in ["left", "right"]:
# For horizontal concat, pad heights to match
if h1 != h2:
target_h = max(h1, h2)
if h1 < target_h:
pad_h = target_h - h1
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image1 = torch.nn.functional.pad(image1, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
if h2 < target_h:
pad_h = target_h - h2
pad_top, pad_bottom = pad_h // 2, pad_h - pad_h // 2
image2 = torch.nn.functional.pad(image2, (0, 0, 0, 0, pad_top, pad_bottom), mode='constant', value=0.0)
else: # up, down
# For vertical concat, pad widths to match
if w1 != w2:
target_w = max(w1, w2)
if w1 < target_w:
pad_w = target_w - w1
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image1 = torch.nn.functional.pad(image1, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
if w2 < target_w:
pad_w = target_w - w2
pad_left, pad_right = pad_w // 2, pad_w - pad_w // 2
image2 = torch.nn.functional.pad(image2, (0, 0, pad_left, pad_right), mode='constant', value=0.0)
# Ensure same number of channels
if image1.shape[-1] != image2.shape[-1]:
max_channels = max(image1.shape[-1], image2.shape[-1])
if image1.shape[-1] < max_channels:
image1 = torch.cat(
[
image1,
torch.ones(
*image1.shape[:-1],
max_channels - image1.shape[-1],
device=image1.device,
),
],
dim=-1,
)
if image2.shape[-1] < max_channels:
image2 = torch.cat(
[
image2,
torch.ones(
*image2.shape[:-1],
max_channels - image2.shape[-1],
device=image2.device,
),
],
dim=-1,
)
# Add spacing if specified
if spacing_width > 0:
spacing_width = spacing_width + (spacing_width % 2) # Ensure even
color_map = {
"white": 1.0,
"black": 0.0,
"red": (1.0, 0.0, 0.0),
"green": (0.0, 1.0, 0.0),
"blue": (0.0, 0.0, 1.0),
}
color_val = color_map[spacing_color]
if direction in ["left", "right"]:
spacing_shape = (
image1.shape[0],
max(image1.shape[1], image2.shape[1]),
spacing_width,
image1.shape[-1],
)
else:
spacing_shape = (
image1.shape[0],
spacing_width,
max(image1.shape[2], image2.shape[2]),
image1.shape[-1],
)
spacing = torch.full(spacing_shape, 0.0, device=image1.device)
if isinstance(color_val, tuple):
for i, c in enumerate(color_val):
if i < spacing.shape[-1]:
spacing[..., i] = c
if spacing.shape[-1] == 4: # Add alpha
spacing[..., 3] = 1.0
else:
spacing[..., : min(3, spacing.shape[-1])] = color_val
if spacing.shape[-1] == 4:
spacing[..., 3] = 1.0
# Concatenate images
images = [image2, image1] if direction in ["left", "up"] else [image1, image2]
if spacing_width > 0:
images.insert(1, spacing)
concat_dim = 2 if direction in ["left", "right"] else 1
return (torch.cat(images, dim=concat_dim),)
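A minimal usage sketch for the ImageStitch node above, calling stitch() directly on BHWC float tensors in [0, 1]. The import path comfy_extras.nodes_images is taken from this diff, but running it outside a full ComfyUI checkout may require mocking the nodes module the way the test file later in this diff does; the tensors and sizes here are placeholders.
import torch
from comfy_extras.nodes_images import ImageStitch

# Two single-image batches: 64x64 and 32x48 (H x W), RGB, values in [0, 1].
base = torch.rand(1, 64, 64, 3)
extra = torch.rand(1, 32, 48, 3)

node = ImageStitch()
# Stitch to the right, resize image2 to match image1's height,
# and insert a 16-pixel white spacer between the two images.
(stitched,) = node.stitch(base, "right", True, 16, "white", image2=extra)
print(stitched.shape)  # torch.Size([1, 64, 176, 3]) -> 64 + 16 spacing + 96 resized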
class SaveSVGNode:
"""
Save SVG files on disk.
@@ -295,7 +495,9 @@ NODE_CLASS_MAPPINGS = {
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,
"ImageAddNoise": ImageAddNoise,
"SaveAnimatedWEBP": SaveAnimatedWEBP,
"SaveAnimatedPNG": SaveAnimatedPNG,
"SaveSVGNode": SaveSVGNode,
"ImageStitch": ImageStitch,
}


@@ -2,6 +2,10 @@ import nodes
import folder_paths
import os
from comfy.comfy_types import IO
from comfy_api.input_impl import VideoFromFile
def normalize_path(path):
return path.replace('\\', '/')
@@ -12,7 +16,7 @@ class Load3D():
os.makedirs(input_dir, exist_ok=True)
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.mtl', '.fbx', '.stl'))]
files = [normalize_path(os.path.join("3d", f)) for f in os.listdir(input_dir) if f.endswith(('.gltf', '.glb', '.obj', '.fbx', '.stl'))]
return {"required": {
"model_file": (sorted(files), {"file_upload": True}),
@@ -21,8 +25,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -41,7 +45,14 @@ class Load3D():
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info']
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video
class Load3DAnimation():
@classmethod
@@ -59,8 +70,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}}
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA")
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info")
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO)
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video")
FUNCTION = "process"
EXPERIMENTAL = True
@@ -77,7 +88,14 @@ class Load3DAnimation():
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, normal_image, image['camera_info']
video = None
if image['recording'] != "":
recording_video_path = folder_paths.get_annotated_filepath(image['recording'])
video = VideoFromFile(recording_video_path)
return output_image, output_mask, model_file, normal_image, image['camera_info'], video
class Preview3D():
@classmethod


@@ -0,0 +1,360 @@
import re
from comfy.comfy_types.node_typing import IO
class StringConcatenate():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"delimiter": (IO.STRING, {"multiline": False, "default": ""})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, delimiter, **kwargs):
return delimiter.join((string_a, string_b)),
class StringSubstring():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"start": (IO.INT, {}),
"end": (IO.INT, {}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, start, end, **kwargs):
return string[start:end],
class StringLength():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.INT,)
RETURN_NAMES = ("length",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, **kwargs):
length = len(string)
return length,
class CaseConverter():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["UPPERCASE", "lowercase", "Capitalize", "Title Case"]})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
if mode == "UPPERCASE":
result = string.upper()
elif mode == "lowercase":
result = string.lower()
elif mode == "Capitalize":
result = string.capitalize()
elif mode == "Title Case":
result = string.title()
else:
result = string
return result,
class StringTrim():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Both", "Left", "Right"]})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, mode, **kwargs):
if mode == "Both":
result = string.strip()
elif mode == "Left":
result = string.lstrip()
elif mode == "Right":
result = string.rstrip()
else:
result = string
return result,
class StringReplace():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"find": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, find, replace, **kwargs):
result = string.replace(find, replace)
return result,
class StringContains():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"substring": (IO.STRING, {"multiline": True}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("contains",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, substring, case_sensitive, **kwargs):
if case_sensitive:
contains = substring in string
else:
contains = substring.lower() in string.lower()
return contains,
class StringCompare():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string_a": (IO.STRING, {"multiline": True}),
"string_b": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["Starts With", "Ends With", "Equal"]}),
"case_sensitive": (IO.BOOLEAN, {"default": True})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string_a, string_b, mode, case_sensitive, **kwargs):
if case_sensitive:
a = string_a
b = string_b
else:
a = string_a.lower()
b = string_b.lower()
if mode == "Equal":
return a == b,
elif mode == "Starts With":
return a.startswith(b),
elif mode == "Ends With":
return a.endswith(b),
class RegexMatch():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False})
}
}
RETURN_TYPES = (IO.BOOLEAN,)
RETURN_NAMES = ("matches",)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, case_insensitive, multiline, dotall, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
match = re.search(regex_pattern, string, flags)
result = match is not None
except re.error:
result = False
return result,
class RegexExtract():
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"mode": (IO.COMBO, {"options": ["First Match", "All Matches", "First Group", "All Groups"]}),
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False}),
"group_index": (IO.INT, {"default": 1, "min": 0, "max": 100})
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, mode, case_insensitive, multiline, dotall, group_index, **kwargs):
join_delimiter = "\n"
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
try:
if mode == "First Match":
match = re.search(regex_pattern, string, flags)
if match:
result = match.group(0)
else:
result = ""
elif mode == "All Matches":
matches = re.findall(regex_pattern, string, flags)
if matches:
if isinstance(matches[0], tuple):
result = join_delimiter.join([m[0] for m in matches])
else:
result = join_delimiter.join(matches)
else:
result = ""
elif mode == "First Group":
match = re.search(regex_pattern, string, flags)
if match and len(match.groups()) >= group_index:
result = match.group(group_index)
else:
result = ""
elif mode == "All Groups":
matches = re.finditer(regex_pattern, string, flags)
results = []
for match in matches:
if match.groups() and len(match.groups()) >= group_index:
results.append(match.group(group_index))
result = join_delimiter.join(results)
else:
result = ""
except re.error:
result = ""
return result,
class RegexReplace():
DESCRIPTION = "Find and replace text using regex patterns."
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"string": (IO.STRING, {"multiline": True}),
"regex_pattern": (IO.STRING, {"multiline": True}),
"replace": (IO.STRING, {"multiline": True}),
},
"optional": {
"case_insensitive": (IO.BOOLEAN, {"default": True}),
"multiline": (IO.BOOLEAN, {"default": False}),
"dotall": (IO.BOOLEAN, {"default": False, "tooltip": "When enabled, the dot (.) character will match any character including newline characters. When disabled, dots won't match newlines."}),
"count": (IO.INT, {"default": 0, "min": 0, "max": 100, "tooltip": "Maximum number of replacements to make. Set to 0 to replace all occurrences (default). Set to 1 to replace only the first match, 2 for the first two matches, etc."}),
}
}
RETURN_TYPES = (IO.STRING,)
FUNCTION = "execute"
CATEGORY = "utils/string"
def execute(self, string, regex_pattern, replace, case_insensitive=True, multiline=False, dotall=False, count=0, **kwargs):
flags = 0
if case_insensitive:
flags |= re.IGNORECASE
if multiline:
flags |= re.MULTILINE
if dotall:
flags |= re.DOTALL
result = re.sub(regex_pattern, replace, string, count=count, flags=flags)
return result,
NODE_CLASS_MAPPINGS = {
"StringConcatenate": StringConcatenate,
"StringSubstring": StringSubstring,
"StringLength": StringLength,
"CaseConverter": CaseConverter,
"StringTrim": StringTrim,
"StringReplace": StringReplace,
"StringContains": StringContains,
"StringCompare": StringCompare,
"RegexMatch": RegexMatch,
"RegexExtract": RegexExtract,
"RegexReplace": RegexReplace,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"StringConcatenate": "Concatenate",
"StringSubstring": "Substring",
"StringLength": "Length",
"CaseConverter": "Case Converter",
"StringTrim": "Trim",
"StringReplace": "Replace",
"StringContains": "Contains",
"StringCompare": "Compare",
"RegexMatch": "Regex Match",
"RegexExtract": "Regex Extract",
"RegexReplace": "Regex Replace",
}
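A small sketch of driving two of the new string nodes directly from Python, chaining RegexExtract into StringConcatenate. The module path comfy_extras.nodes_string is assumed from the nodes_string.py entry added to init_builtin_extra_nodes later in this diff; the text and pattern are made up.
from comfy_extras.nodes_string import RegexExtract, StringConcatenate

text = "build 0.3.39 released"
# Pull the version number out with a capturing group...
(version,) = RegexExtract().execute(
    text, r"(\d+\.\d+\.\d+)", "First Group",
    case_insensitive=True, multiline=False, dotall=False, group_index=1,
)
# ...then join it onto a label with a space delimiter.
(label,) = StringConcatenate().execute("ComfyUI", version, " ")
print(label)  # "ComfyUI 0.3.39"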


@@ -1,4 +1,5 @@
import torch
from comfy_api.torch_helpers import set_torch_compile_wrapper
class TorchCompileModel:
@classmethod
@@ -14,7 +15,7 @@ class TorchCompileModel:
def patch(self, model, backend):
m = model.clone()
m.add_object_patch("diffusion_model", torch.compile(model=m.get_model_object("diffusion_model"), backend=backend))
set_torch_compile_wrapper(model=m, backend=backend)
return (m, )
NODE_CLASS_MAPPINGS = {


@@ -268,8 +268,9 @@ class WanVaceToVideo:
trim_latent = reference_image.shape[2]
mask = mask.unsqueeze(0)
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": control_video_latent, "vace_mask": mask, "vace_strength": strength})
positive = node_helpers.conditioning_set_values(positive, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"vace_frames": [control_video_latent], "vace_mask": [mask], "vace_strength": [strength]}, append=True)
latent = torch.zeros([batch_size, 16, latent_length, height // 8, width // 8], device=comfy.model_management.intermediate_device())
out_latent = {}
@@ -297,6 +298,90 @@ class TrimVideoLatent:
samples_out["samples"] = s1[:, :, trim_amount:]
return (samples_out,)
class WanCameraImageToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"clip_vision_output": ("CLIP_VISION_OUTPUT", ),
"start_image": ("IMAGE", ),
"camera_conditions": ("WAN_CAMERA_EMBEDDING", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None, camera_conditions=None):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
concat_latent = comfy.latent_formats.Wan21().process_out(concat_latent)
if start_image is not None:
start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
concat_latent_image = vae.encode(start_image[:, :, :, :3])
concat_latent[:,:,:concat_latent_image.shape[2]] = concat_latent_image[:,:,:concat_latent.shape[2]]
positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent})
negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent})
if camera_conditions is not None:
positive = node_helpers.conditioning_set_values(positive, {'camera_conditions': camera_conditions})
negative = node_helpers.conditioning_set_values(negative, {'camera_conditions': camera_conditions})
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
out_latent = {}
out_latent["samples"] = latent
return (positive, negative, out_latent)
class WanPhantomSubjectToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE", ),
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
},
"optional": {"images": ("IMAGE", ),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
FUNCTION = "encode"
CATEGORY = "conditioning/video_models"
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
cond2 = negative
if images is not None:
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
latent_images = []
for i in images:
latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
concat_latent_image = torch.cat(latent_images, dim=2)
positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
out_latent = {}
out_latent["samples"] = latent
return (positive, cond2, negative, out_latent)
NODE_CLASS_MAPPINGS = {
"WanImageToVideo": WanImageToVideo,
@@ -305,4 +390,6 @@ NODE_CLASS_MAPPINGS = {
"WanFirstLastFrameToVideo": WanFirstLastFrameToVideo,
"WanVaceToVideo": WanVaceToVideo,
"TrimVideoLatent": TrimVideoLatent,
"WanCameraImageToVideo": WanCameraImageToVideo,
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
}


@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.3.34"
__version__ = "0.3.39"


@@ -909,7 +909,6 @@ class PromptQueue:
self.currently_running = {}
self.history = {}
self.flags = {}
server.prompt_queue = self
def put(self, item):
with self.mutex:
@@ -954,6 +953,7 @@ class PromptQueue:
self.history[prompt[1]].update(history_result)
self.server.queue_updated()
# Note: slow
def get_current_queue(self):
with self.mutex:
out = []
@@ -961,6 +961,13 @@ class PromptQueue:
out += [x]
return (out, copy.deepcopy(self.queue))
# read-safe as long as queue items are immutable
def get_current_queue_volatile(self):
with self.mutex:
running = [x for x in self.currently_running.values()]
queued = copy.copy(self.queue)
return (running, queued)
def get_tasks_remaining(self):
with self.mutex:
return len(self.queue) + len(self.currently_running)
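The get_current_queue_volatile method added above skips the per-item deepcopy of get_current_queue and only shallow-copies the pending list under the mutex, which is read-safe precisely because queue entries are never mutated in place. A toy illustration of that property, with plain tuples standing in for queue items:
import copy

queue = [(0, "prompt-a"), (1, "prompt-b")]  # immutable entries, like queue items
snapshot = copy.copy(queue)                 # shallow copy: no per-item duplication
queue.append((2, "prompt-c"))               # the live queue keeps changing
print(snapshot)  # [(0, 'prompt-a'), (1, 'prompt-b')] -- unaffected by later pushes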


@@ -1,28 +0,0 @@
import importlib.util
import shutil
import os
import ctypes
import logging
def fix_pytorch_libomp():
"""
Fix PyTorch libomp DLL issue on Windows by copying the correct DLL file if needed.
"""
torch_spec = importlib.util.find_spec("torch")
for folder in torch_spec.submodule_search_locations:
lib_folder = os.path.join(folder, "lib")
test_file = os.path.join(lib_folder, "fbgemm.dll")
dest = os.path.join(lib_folder, "libomp140.x86_64.dll")
if os.path.exists(dest):
break
with open(test_file, "rb") as f:
contents = f.read()
if b"libomp140.x86_64.dll" not in contents:
break
try:
ctypes.cdll.LoadLibrary(test_file)
except FileNotFoundError:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)

main.py

@@ -125,13 +125,6 @@ if __name__ == "__main__":
import cuda_malloc
if args.windows_standalone_build:
try:
from fix_torch import fix_pytorch_libomp
fix_pytorch_libomp()
except:
pass
import comfy.utils
import execution
@@ -267,7 +260,6 @@ def start_comfyui(asyncio_loop=None):
asyncio_loop = asyncio.new_event_loop()
asyncio.set_event_loop(asyncio_loop)
prompt_server = server.PromptServer(asyncio_loop)
q = execution.PromptQueue(prompt_server)
hook_breaker_ac10a0.save_functions()
nodes.init_extra_nodes(init_custom_nodes=not args.disable_all_custom_nodes, init_api_nodes=not args.disable_api_nodes)
@@ -278,7 +270,7 @@ def start_comfyui(asyncio_loop=None):
prompt_server.add_routes()
hijack_progress(prompt_server)
threading.Thread(target=prompt_worker, daemon=True, args=(q, prompt_server,)).start()
threading.Thread(target=prompt_worker, daemon=True, args=(prompt_server.prompt_queue, prompt_server,)).start()
if args.quick_test_for_ci:
exit(0)


@@ -5,12 +5,18 @@ from comfy.cli_args import args
from PIL import ImageFile, UnidentifiedImageError
def conditioning_set_values(conditioning, values={}):
def conditioning_set_values(conditioning, values={}, append=False):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
n[1][k] = values[k]
val = values[k]
if append:
old_val = n[1].get(k, None)
if old_val is not None:
val = old_val + val
n[1][k] = val
c.append(n)
return c
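With append=True, conditioning_set_values above extends an existing list value instead of overwriting it; this is what the WanVaceToVideo hunk earlier in this diff relies on when it wraps vace_frames, vace_mask, and vace_strength in single-element lists. A minimal sketch, with a placeholder tensor/dict pair standing in for real conditioning:
import torch

conditioning = [[torch.zeros(1, 77, 768), {}]]  # toy [cond, options] pair

first = conditioning_set_values(conditioning, {"vace_strength": [1.0]}, append=True)
second = conditioning_set_values(first, {"vace_strength": [0.5]}, append=True)
print(second[0][1]["vace_strength"])  # [1.0, 0.5] -- appended, not replaced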


@@ -1103,16 +1103,7 @@ class unCLIPConditioning:
if strength == 0:
return (conditioning, )
c = []
for t in conditioning:
o = t[1].copy()
x = {"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}
if "unclip_conditioning" in o:
o["unclip_conditioning"] = o["unclip_conditioning"][:] + [x]
else:
o["unclip_conditioning"] = [x]
n = [t[0], o]
c.append(n)
c = node_helpers.conditioning_set_values(conditioning, {"unclip_conditioning": [{"clip_vision_output": clip_vision_output, "strength": strength, "noise_augmentation": noise_augmentation}]}, append=True)
return (c, )
class GLIGENLoader:
@@ -1940,7 +1931,7 @@ class ImagePadForOutpaint:
mask[top:top + d2, left:left + d3] = t
return (new_image, mask)
return (new_image, mask.unsqueeze(0))
NODE_CLASS_MAPPINGS = {
@@ -2070,6 +2061,7 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ImagePadForOutpaint": "Pad Image for Outpainting",
"ImageBatch": "Batch Images",
"ImageCrop": "Image Crop",
"ImageStitch": "Image Stitch",
"ImageBlend": "Image Blend",
"ImageBlur": "Image Blur",
"ImageQuantize": "Image Quantize",
@@ -2261,8 +2253,11 @@ def init_builtin_extra_nodes():
"nodes_optimalsteps.py",
"nodes_hidream.py",
"nodes_fresca.py",
"nodes_apg.py",
"nodes_preview_any.py",
"nodes_ace.py",
"nodes_string.py",
"nodes_camera_trajectory.py",
]
import_failed = []
@@ -2287,6 +2282,10 @@ def init_builtin_api_nodes():
"nodes_pixverse.py",
"nodes_stability.py",
"nodes_pika.py",
"nodes_runway.py",
"nodes_tripo.py",
"nodes_rodin.py",
"nodes_gemini.py",
]
if not load_custom_node(os.path.join(api_nodes_dir, "canary.py"), module_parent="comfy_api_nodes"):


@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.3.34"
version = "0.3.39"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.9"


@@ -1,5 +1,6 @@
comfyui-frontend-package==1.18.10
comfyui-workflow-templates==0.1.14
comfyui-frontend-package==1.21.6
comfyui-workflow-templates==0.1.25
comfyui-embedded-docs==0.2.0
torch
torchsde
torchvision


@@ -101,6 +101,14 @@ prompt_text = """
def queue_prompt(prompt):
p = {"prompt": prompt}
# If the workflow contains API nodes, you can add a Comfy API key to the `extra_data` field of the payload.
# p["extra_data"] = {
# "api_key_comfy_org": "comfyui-87d01e28d*******************************************************" # replace with real key
# }
# See: https://docs.comfy.org/tutorials/api-nodes/overview
# Generate a key here: https://platform.comfy.org/login
data = json.dumps(p).encode('utf-8')
req = request.Request("http://127.0.0.1:8188/prompt", data=data)
request.urlopen(req)


@@ -29,6 +29,7 @@ import comfy.model_management
import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
@@ -159,7 +160,7 @@ class PromptServer():
self.custom_node_manager = CustomNodeManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = None
self.prompt_queue = execution.PromptQueue(self)
self.loop = loop
self.messages = asyncio.Queue()
self.client_session:Optional[aiohttp.ClientSession] = None
@@ -226,7 +227,7 @@ class PromptServer():
return response
@routes.get("/embeddings")
def get_embeddings(self):
def get_embeddings(request):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@@ -282,7 +283,6 @@ class PromptServer():
a.update(f.read())
b.update(image.file.read())
image.file.seek(0)
f.close()
return a.hexdigest() == b.hexdigest()
return False
@@ -390,7 +390,7 @@ class PromptServer():
async def view_image(request):
if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"]
filename,output_dir = folder_paths.annotated_filepath(filename)
filename, output_dir = folder_paths.annotated_filepath(filename)
if not filename:
return web.Response(status=400)
@@ -476,9 +476,8 @@ class PromptServer():
# Get content type from mimetype, defaulting to 'application/octet-stream'
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
# For security, force certain extensions to download instead of display
file_extension = os.path.splitext(filename)[1].lower()
if file_extension in {'.html', '.htm', '.js', '.css'}:
# For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
content_type = 'application/octet-stream' # Forces download
return web.FileResponse(
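The /view hardening above now keys the forced-download decision off the guessed MIME type rather than the file extension. A standalone sketch of that check outside the aiohttp handler; the blocked set is copied from the hunk above and the filenames are made up:
import mimetypes

BLOCKED = {'text/html', 'text/html-sandboxed', 'application/xhtml+xml',
           'text/javascript', 'text/css'}

def view_content_type(filename: str) -> str:
    # Fall back to a generic binary type when the extension is unknown.
    content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
    if content_type in BLOCKED:
        return 'application/octet-stream'  # browsers download instead of rendering
    return content_type

print(view_content_type("preview.png"))   # image/png
print(view_content_type("payload.html"))  # application/octet-stream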
@@ -621,7 +620,7 @@ class PromptServer():
@routes.get("/queue")
async def get_queue(request):
queue_info = {}
current_queue = self.prompt_queue.get_current_queue()
current_queue = self.prompt_queue.get_current_queue_volatile()
queue_info['queue_running'] = current_queue[0]
queue_info['queue_pending'] = current_queue[1]
return web.json_response(queue_info)
@@ -746,6 +745,13 @@ class PromptServer():
web.static('/templates', workflow_templates_path)
])
# Serve embedded documentation from the package
embedded_docs_path = FrontendManager.embedded_docs_path()
if embedded_docs_path:
self.app.add_routes([
web.static('/docs', embedded_docs_path)
])
self.app.add_routes([
web.static('/', self.web_root),
])


@@ -0,0 +1,239 @@
import pytest
import torch
import tempfile
import os
import av
import io
from fractions import Fraction
from comfy_api.input_impl.video_types import VideoFromFile, VideoFromComponents
from comfy_api.util.video_types import VideoComponents
from comfy_api.input.basic_types import AudioInput
from av.error import InvalidDataError
EPSILON = 0.0001
@pytest.fixture
def sample_images():
"""3-frame 2x2 RGB video tensor"""
return torch.rand(3, 2, 2, 3)
@pytest.fixture
def sample_audio():
"""Stereo audio with 44.1kHz sample rate"""
return AudioInput(
{
"waveform": torch.rand(1, 2, 1000),
"sample_rate": 44100,
}
)
@pytest.fixture
def video_components(sample_images, sample_audio):
"""VideoComponents with images, audio, and metadata"""
return VideoComponents(
images=sample_images,
audio=sample_audio,
frame_rate=Fraction(30),
metadata={"test": "metadata"},
)
def create_test_video(width=4, height=4, frames=3, fps=30):
"""Helper to create a temporary video file"""
tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
with av.open(tmp.name, mode="w") as container:
stream = container.add_stream("h264", rate=fps)
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
for i in range(frames):
frame = av.VideoFrame.from_ndarray(
torch.ones(height, width, 3, dtype=torch.uint8).numpy() * (i * 85),
format="rgb24",
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
# Flush
packet = stream.encode(None)
container.mux(packet)
return tmp.name
@pytest.fixture
def simple_video_file():
"""4x4 video with 3 frames at 30fps"""
file_path = create_test_video()
yield file_path
os.unlink(file_path)
def test_video_from_components_get_duration(video_components):
"""Duration calculated correctly from frame count and frame rate"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
expected_duration = 3.0 / 30.0
assert duration == pytest.approx(expected_duration)
def test_video_from_components_get_duration_different_frame_rates(sample_images):
"""Duration correct for different frame rates including fractional"""
# Test with 60 fps
components_60fps = VideoComponents(images=sample_images, frame_rate=Fraction(60))
video_60fps = VideoFromComponents(components_60fps)
assert video_60fps.get_duration() == pytest.approx(3.0 / 60.0)
# Test with fractional frame rate (23.976fps)
components_frac = VideoComponents(
images=sample_images, frame_rate=Fraction(24000, 1001)
)
video_frac = VideoFromComponents(components_frac)
expected_frac = 3.0 / (24000.0 / 1001.0)
assert video_frac.get_duration() == pytest.approx(expected_frac)
def test_video_from_components_get_duration_empty_video():
"""Duration is zero for empty video"""
empty_components = VideoComponents(
images=torch.zeros(0, 2, 2, 3), frame_rate=Fraction(30)
)
video = VideoFromComponents(empty_components)
assert video.get_duration() == 0.0
def test_video_from_components_get_dimensions(video_components):
"""Dimensions returned correctly from image tensor shape"""
video = VideoFromComponents(video_components)
width, height = video.get_dimensions()
assert width == 2
assert height == 2
def test_video_from_file_get_duration(simple_video_file):
"""Duration extracted from file metadata"""
video = VideoFromFile(simple_video_file)
duration = video.get_duration()
assert duration == pytest.approx(0.1, abs=0.01)
def test_video_from_file_get_dimensions(simple_video_file):
"""Dimensions read from stream without decoding frames"""
video = VideoFromFile(simple_video_file)
width, height = video.get_dimensions()
assert width == 4
assert height == 4
def test_video_from_file_bytesio_input():
"""VideoFromFile works with BytesIO input"""
buffer = io.BytesIO()
with av.open(buffer, mode="w", format="mp4") as container:
stream = container.add_stream("h264", rate=30)
stream.width = 2
stream.height = 2
stream.pix_fmt = "yuv420p"
frame = av.VideoFrame.from_ndarray(
torch.zeros(2, 2, 3, dtype=torch.uint8).numpy(), format="rgb24"
)
frame = frame.reformat(format="yuv420p")
packet = stream.encode(frame)
container.mux(packet)
packet = stream.encode(None)
container.mux(packet)
buffer.seek(0)
video = VideoFromFile(buffer)
assert video.get_dimensions() == (2, 2)
assert video.get_duration() == pytest.approx(1 / 30, abs=0.01)
def test_video_from_file_invalid_file_error():
"""InvalidDataError raised for non-video files"""
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as tmp:
tmp.write(b"not a video file")
tmp.flush()
tmp_name = tmp.name
try:
with pytest.raises(InvalidDataError):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_video_from_file_audio_only_error():
"""ValueError raised for audio-only files"""
with tempfile.NamedTemporaryFile(suffix=".m4a", delete=False) as tmp:
tmp_name = tmp.name
try:
with av.open(tmp_name, mode="w") as container:
stream = container.add_stream("aac", rate=44100)
stream.sample_rate = 44100
stream.format = "fltp"
audio_data = torch.zeros(1, 1024).numpy()
audio_frame = av.AudioFrame.from_ndarray(
audio_data, format="fltp", layout="mono"
)
audio_frame.sample_rate = 44100
audio_frame.pts = 0
packet = stream.encode(audio_frame)
container.mux(packet)
for packet in stream.encode(None):
container.mux(packet)
with pytest.raises(ValueError, match="No video stream found"):
video = VideoFromFile(tmp_name)
video.get_dimensions()
finally:
os.unlink(tmp_name)
def test_single_frame_video():
"""Single frame video has correct duration"""
components = VideoComponents(
images=torch.rand(1, 10, 10, 3), frame_rate=Fraction(1)
)
video = VideoFromComponents(components)
assert video.get_duration() == 1.0
@pytest.mark.parametrize(
"frame_rate,expected_fps",
[
(Fraction(24000, 1001), 24000 / 1001),
(Fraction(30000, 1001), 30000 / 1001),
(Fraction(25, 1), 25.0),
(Fraction(50, 2), 25.0),
],
)
def test_fractional_frame_rates(frame_rate, expected_fps):
"""Duration calculated correctly for various fractional frame rates"""
components = VideoComponents(images=torch.rand(100, 4, 4, 3), frame_rate=frame_rate)
video = VideoFromComponents(components)
duration = video.get_duration()
expected_duration = 100.0 / expected_fps
assert duration == pytest.approx(expected_duration)
def test_duration_consistency(video_components):
"""get_duration() consistent with manual calculation from components"""
video = VideoFromComponents(video_components)
duration = video.get_duration()
components = video.get_components()
manual_duration = float(components.images.shape[0] / components.frame_rate)
assert duration == pytest.approx(manual_duration)


@@ -0,0 +1,240 @@
import torch
from unittest.mock import patch, MagicMock
# Mock nodes module to prevent CUDA initialization during import
mock_nodes = MagicMock()
mock_nodes.MAX_RESOLUTION = 16384
with patch.dict('sys.modules', {'nodes': mock_nodes}):
from comfy_extras.nodes_images import ImageStitch
class TestImageStitch:
def create_test_image(self, batch_size=1, height=64, width=64, channels=3):
"""Helper to create test images with specific dimensions"""
return torch.rand(batch_size, height, width, channels)
def test_no_image2_passthrough(self):
"""Test that when image2 is None, image1 is returned unchanged"""
node = ImageStitch()
image1 = self.create_test_image()
result = node.stitch(image1, "right", True, 0, "white", image2=None)
assert len(result) == 1
assert torch.equal(result[0], image1)
def test_basic_horizontal_stitch_right(self):
"""Test basic horizontal stitching to the right"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "right", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 32 + 24 width
def test_basic_horizontal_stitch_left(self):
"""Test basic horizontal stitching to the left"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
result = node.stitch(image1, "left", False, 0, "white", image2)
assert result[0].shape == (1, 32, 56, 3) # 24 + 32 width
def test_basic_vertical_stitch_down(self):
"""Test basic vertical stitching downward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "down", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 32 + 24 height
def test_basic_vertical_stitch_up(self):
"""Test basic vertical stitching upward"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
result = node.stitch(image1, "up", False, 0, "white", image2)
assert result[0].shape == (1, 56, 32, 3) # 24 + 32 height
def test_size_matching_horizontal(self):
"""Test size matching for horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32) # Different aspect ratio
result = node.stitch(image1, "right", True, 0, "white", image2)
# image2 should be resized to match image1's height (64) with preserved aspect ratio
expected_width = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, 64, expected_width, 3)
def test_size_matching_vertical(self):
"""Test size matching for vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=64)
image2 = self.create_test_image(height=32, width=32)
result = node.stitch(image1, "down", True, 0, "white", image2)
# image2 should be resized to match image1's width (64) with preserved aspect ratio
expected_height = 64 + 64 # original + resized (32*64/32 = 64)
assert result[0].shape == (1, expected_height, 64, 3)
def test_padding_for_mismatched_heights_horizontal(self):
"""Test padding when heights don't match in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=64, width=32)
image2 = self.create_test_image(height=48, width=24) # Shorter height
result = node.stitch(image1, "right", False, 0, "white", image2)
# Both images should be padded to height 64
assert result[0].shape == (1, 64, 56, 3) # 32 + 24 width, max(64,48) height
def test_padding_for_mismatched_widths_vertical(self):
"""Test padding when widths don't match in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=64)
image2 = self.create_test_image(height=24, width=48) # Narrower width
result = node.stitch(image1, "down", False, 0, "white", image2)
# Both images should be padded to width 64
assert result[0].shape == (1, 56, 64, 3) # 32 + 24 height, max(64,48) width
def test_spacing_horizontal(self):
"""Test spacing addition in horizontal concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=24)
spacing_width = 16
result = node.stitch(image1, "right", False, spacing_width, "white", image2)
# Expected width: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 32, 72, 3)
def test_spacing_vertical(self):
"""Test spacing addition in vertical concatenation"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=24, width=32)
spacing_width = 16
result = node.stitch(image1, "down", False, spacing_width, "white", image2)
# Expected height: 32 + 16 (spacing) + 24 = 72
assert result[0].shape == (1, 72, 32, 3)
def test_spacing_color_values(self):
"""Test that spacing colors are applied correctly"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Test white spacing
result_white = node.stitch(image1, "right", False, 16, "white", image2)
# Check that spacing region contains white values (close to 1.0)
spacing_region = result_white[0][:, :, 32:48, :] # Middle 16 pixels
assert torch.all(spacing_region >= 0.9) # Should be close to white
# Test black spacing
result_black = node.stitch(image1, "right", False, 16, "black", image2)
spacing_region = result_black[0][:, :, 32:48, :]
assert torch.all(spacing_region <= 0.1) # Should be close to black
def test_odd_spacing_width_made_even(self):
"""Test that odd spacing widths are made even"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
# Use odd spacing width
result = node.stitch(image1, "right", False, 15, "white", image2)
# Should be made even (16), so total width = 32 + 16 + 32 = 80
assert result[0].shape == (1, 32, 80, 3)
def test_batch_size_matching(self):
"""Test that different batch sizes are handled correctly"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=32, width=32)
image2 = self.create_test_image(batch_size=1, height=32, width=32)
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should match larger batch size
assert result[0].shape == (2, 32, 64, 3)
def test_channel_matching_rgb_to_rgba(self):
"""Test that channel differences are handled (RGB + alpha)"""
node = ImageStitch()
image1 = self.create_test_image(channels=3) # RGB
image2 = self.create_test_image(channels=4) # RGBA
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_channel_matching_rgba_to_rgb(self):
"""Test that channel differences are handled (RGBA + RGB)"""
node = ImageStitch()
image1 = self.create_test_image(channels=4) # RGBA
image2 = self.create_test_image(channels=3) # RGB
result = node.stitch(image1, "right", False, 0, "white", image2)
# Should have 4 channels (RGBA)
assert result[0].shape[-1] == 4
def test_all_color_options(self):
"""Test all available color options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
colors = ["white", "black", "red", "green", "blue"]
for color in colors:
result = node.stitch(image1, "right", False, 16, color, image2)
assert result[0].shape == (1, 32, 80, 3) # Basic shape check
def test_all_directions(self):
"""Test all direction options"""
node = ImageStitch()
image1 = self.create_test_image(height=32, width=32)
image2 = self.create_test_image(height=32, width=32)
directions = ["right", "left", "up", "down"]
for direction in directions:
result = node.stitch(image1, direction, False, 0, "white", image2)
expected = (1, 32, 64, 3) if direction in ["right", "left"] else (1, 64, 32, 3)
assert result[0].shape == expected
def test_batch_size_channel_spacing_integration(self):
"""Test integration of batch matching, channel matching, size matching, and spacings"""
node = ImageStitch()
image1 = self.create_test_image(batch_size=2, height=64, width=48, channels=3)
image2 = self.create_test_image(batch_size=1, height=32, width=32, channels=4)
result = node.stitch(image1, "right", True, 8, "red", image2)
# Should handle: batch matching, size matching, channel matching, spacing
assert result[0].shape[0] == 2 # Batch size matched
assert result[0].shape[-1] == 4 # Channels matched to max
assert result[0].shape[1] == 64 # Height from image1 (size matching)
# Width should be: 48 + 8 (spacing) + resized_image2_width
expected_image2_width = int(64 * (32/32)) # Resized to height 64
expected_total_width = 48 + 8 + expected_image2_width
assert result[0].shape[2] == expected_total_width