Compare commits

...

10 Commits

Author SHA1 Message Date
Chenlei Hu
522d923948 nit 2025-03-25 16:47:52 -04:00
Chenlei Hu
c05c9b552b nit 2025-03-25 16:47:42 -04:00
Chenlei Hu
27598702e9 [Type] Annotate graph.get_input_info 2025-03-25 16:44:55 -04:00
comfyanonymous
8edc1f44c1 Support more float8 types. 2025-03-25 05:23:49 -04:00
comfyanonymous
eade1551bb Add Hunyuan3D to readme. 2025-03-24 07:14:32 -04:00
comfyanonymous
581a9991ff Add model merging node for WAN 2.1 2025-03-23 08:06:36 -04:00
comfyanonymous
e471c726e5 Fallback to pytorch attention if sage attention fails. 2025-03-22 15:45:56 -04:00
comfyanonymous
75c1c757d9 ComfyUI version v0.3.27 2025-03-21 20:09:54 -04:00
Chenlei Hu
ce9b084279 [nit] Format error strings (#7345) 2025-03-21 19:08:25 -04:00
Terry Jia
2206246055 support output normal and lineart once (#7290) 2025-03-21 16:24:13 -04:00
10 changed files with 165 additions and 41 deletions

View File

@@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- 3D Models
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system - Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions. - Many optimizations: Only re-executes the parts of the workflow that changes between executions.

View File

@@ -22,13 +22,21 @@ import app.logger
# The path to the requirements.txt file # The path to the requirements.txt file
req_path = Path(__file__).parents[1] / "requirements.txt" req_path = Path(__file__).parents[1] / "requirements.txt"
def frontend_install_warning_message(): def frontend_install_warning_message():
"""The warning message to display when the frontend version is not up to date.""" """The warning message to display when the frontend version is not up to date."""
extra = "" extra = ""
if sys.flags.no_user_site: if sys.flags.no_user_site:
extra = "-s " extra = "-s "
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem" return f"""
Please install the updated requirements.txt file by running:
{sys.executable} {extra}-m pip install -r {req_path}
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
""".strip()
def check_frontend_version(): def check_frontend_version():
@@ -43,7 +51,17 @@ def check_frontend_version():
with open(req_path, "r", encoding="utf-8") as f: with open(req_path, "r", encoding="utf-8") as f:
required_frontend = parse_version(f.readline().split("=")[-1]) required_frontend = parse_version(f.readline().split("=")[-1])
if frontend_version < required_frontend: if frontend_version < required_frontend:
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message())) app.logger.log_startup_warning(
f"""
________________________________________________________________________
WARNING WARNING WARNING WARNING WARNING
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
{frontend_install_warning_message()}
________________________________________________________________________
""".strip()
)
else: else:
logging.info("ComfyUI frontend version: {}".format(frontend_version_str)) logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
except Exception as e: except Exception as e:
@@ -150,9 +168,20 @@ class FrontendManager:
def default_frontend_path(cls) -> str: def default_frontend_path(cls) -> str:
try: try:
import comfyui_frontend_package import comfyui_frontend_package
return str(importlib.resources.files(comfyui_frontend_package) / "static") return str(importlib.resources.files(comfyui_frontend_package) / "static")
except ImportError: except ImportError:
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n") logging.error(
f"""
********** ERROR ***********
comfyui-frontend-package is not installed.
{frontend_install_warning_message()}
********** ERROR ***********
""".strip()
)
sys.exit(-1) sys.exit(-1)
@classmethod @classmethod
@@ -175,7 +204,9 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3) return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod @classmethod
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str: def init_frontend_unsafe(
cls, version_string: str, provider: Optional[FrontEndProvider] = None
) -> str:
""" """
Initializes the frontend for the specified version. Initializes the frontend for the specified version.
@@ -197,12 +228,20 @@ class FrontendManager:
repo_owner, repo_name, version = cls.parse_version_string(version_string) repo_owner, repo_name, version = cls.parse_version_string(version_string)
if version.startswith("v"): if version.startswith("v"):
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v")) expected_path = str(
Path(cls.CUSTOM_FRONTENDS_ROOT)
/ f"{repo_owner}_{repo_name}"
/ version.lstrip("v")
)
if os.path.exists(expected_path): if os.path.exists(expected_path):
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}") logging.info(
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
)
return expected_path return expected_path
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...") logging.info(
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
)
provider = provider or FrontEndProvider(repo_owner, repo_name) provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version) release = provider.get_release(version)

View File

@@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False): def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape: if skip_reshape:
b, _, _, dim_head = q.shape b, _, _, dim_head = q.shape
tensor_layout="HND" tensor_layout = "HND"
else: else:
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
@@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.view(b, -1, heads, dim_head), lambda t: t.view(b, -1, heads, dim_head),
(q, k, v), (q, k, v),
) )
tensor_layout="NHD" tensor_layout = "NHD"
if mask is not None: if mask is not None:
# add a batch dimension if there isn't already one # add a batch dimension if there isn't already one
@@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
try:
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
except Exception as e:
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
if tensor_layout == "NHD":
q, k, v = map(
lambda t: t.transpose(1, 2),
(q, k, v),
)
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
if tensor_layout == "HND": if tensor_layout == "HND":
if not skip_output_reshape: if not skip_output_reshape:
out = ( out = (

View File

@@ -46,6 +46,32 @@ cpu_state = CPUState.GPU
total_vram = 0 total_vram = 0
def get_supported_float8_types():
float8_types = []
try:
float8_types.append(torch.float8_e4m3fn)
except:
pass
try:
float8_types.append(torch.float8_e4m3fnuz)
except:
pass
try:
float8_types.append(torch.float8_e5m2)
except:
pass
try:
float8_types.append(torch.float8_e5m2fnuz)
except:
pass
try:
float8_types.append(torch.float8_e8m0fnu)
except:
pass
return float8_types
FLOAT8_TYPES = get_supported_float8_types()
xpu_available = False xpu_available = False
torch_version = "" torch_version = ""
try: try:
@@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e5m2 return torch.float8_e5m2
fp8_dtype = None fp8_dtype = None
try: if weight_dtype in FLOAT8_TYPES:
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
fp8_dtype = weight_dtype fp8_dtype = weight_dtype
except:
pass
if fp8_dtype is not None: if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive

View File

@@ -1,6 +1,9 @@
import nodes from __future__ import annotations
from typing import Type, Literal
import nodes
from comfy_execution.graph_utils import is_link from comfy_execution.graph_utils import is_link
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception): class DependencyCycleError(Exception):
pass pass
@@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self): def get_original_prompt(self):
return self.original_prompt return self.original_prompt
def get_input_info(class_def, input_name, valid_inputs=None): def get_input_info(
class_def: Type[ComfyNodeABC],
input_name: str,
valid_inputs: InputTypeDict | None = None
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
"""Get the input type, category, and extra info for a given input name.
Arguments:
class_def: The class definition of the node.
input_name: The name of the input to get info for.
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
Returns:
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
"""
valid_inputs = valid_inputs or class_def.INPUT_TYPES() valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None input_info = None
input_category = None input_category = None
@@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes: if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue continue
input_type, input_category, input_info = self.get_input_info(unique_id, input_name) _, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id) node_ids.append(from_node_id)

View File

@@ -21,8 +21,8 @@ class Load3D():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path") RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
FUNCTION = "process" FUNCTION = "process"
EXPERIMENTAL = True EXPERIMENTAL = True
@@ -32,12 +32,16 @@ class Load3D():
def process(self, model_file, image, **kwargs): def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image']) image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask']) mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
load_image_node = nodes.LoadImage() load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path) output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
return output_image, output_mask, model_file, return output_image, output_mask, model_file, normal_image, lineart_image
class Load3DAnimation(): class Load3DAnimation():
@classmethod @classmethod
@@ -55,8 +59,8 @@ class Load3DAnimation():
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
}} }}
RETURN_TYPES = ("IMAGE", "MASK", "STRING") RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
RETURN_NAMES = ("image", "mask", "mesh_path") RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
FUNCTION = "process" FUNCTION = "process"
EXPERIMENTAL = True EXPERIMENTAL = True
@@ -66,12 +70,14 @@ class Load3DAnimation():
def process(self, model_file, image, **kwargs): def process(self, model_file, image, **kwargs):
image_path = folder_paths.get_annotated_filepath(image['image']) image_path = folder_paths.get_annotated_filepath(image['image'])
mask_path = folder_paths.get_annotated_filepath(image['mask']) mask_path = folder_paths.get_annotated_filepath(image['mask'])
normal_path = folder_paths.get_annotated_filepath(image['normal'])
load_image_node = nodes.LoadImage() load_image_node = nodes.LoadImage()
output_image, ignore_mask = load_image_node.load_image(image=image_path) output_image, ignore_mask = load_image_node.load_image(image=image_path)
ignore_image, output_mask = load_image_node.load_image(image=mask_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path)
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
return output_image, output_mask, model_file, return output_image, output_mask, model_file, normal_image
class Preview3D(): class Preview3D():
@classmethod @classmethod

View File

@@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
return {"required": arg_dict} return {"required": arg_dict}
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
CATEGORY = "advanced/model_merging/model_specific"
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
@classmethod
def INPUT_TYPES(s):
arg_dict = { "model1": ("MODEL",),
"model2": ("MODEL",)}
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
arg_dict["patch_embedding."] = argument
arg_dict["time_embedding."] = argument
arg_dict["time_projection."] = argument
arg_dict["text_embedding."] = argument
arg_dict["img_emb."] = argument
for i in range(40):
arg_dict["blocks.{}.".format(i)] = argument
arg_dict["head."] = argument
return {"required": arg_dict}
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
"ModelMergeSD1": ModelMergeSD1, "ModelMergeSD1": ModelMergeSD1,
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
@@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
"ModelMergeLTXV": ModelMergeLTXV, "ModelMergeLTXV": ModelMergeLTXV,
"ModelMergeCosmos7B": ModelMergeCosmos7B, "ModelMergeCosmos7B": ModelMergeCosmos7B,
"ModelMergeCosmos14B": ModelMergeCosmos14B, "ModelMergeCosmos14B": ModelMergeCosmos14B,
"ModelMergeWAN2_1": ModelMergeWAN2_1,
} }

View File

@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.3.26" __version__ = "0.3.27"

View File

@@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
missing_keys = {} missing_keys = {}
for x in inputs: for x in inputs:
input_data = inputs[x] input_data = inputs[x]
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs) _, input_category, input_info = get_input_info(class_def, x, valid_inputs)
def mark_missing(): def mark_missing():
missing_keys[x] = True missing_keys[x] = True
input_data_all[x] = (None,) input_data_all[x] = (None,)
@@ -555,7 +555,7 @@ def validate_inputs(prompt, item, validated):
received_types = {} received_types = {}
for x in valid_inputs: for x in valid_inputs:
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs) input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
assert extra_info is not None assert extra_info is not None
if x not in inputs: if x not in inputs:
if input_category == "required": if input_category == "required":
@@ -571,7 +571,7 @@ def validate_inputs(prompt, item, validated):
continue continue
val = inputs[x] val = inputs[x]
info = (type_input, extra_info) info = (input_type, extra_info)
if isinstance(val, list): if isinstance(val, list):
if len(val) != 2: if len(val) != 2:
error = { error = {
@@ -592,8 +592,8 @@ def validate_inputs(prompt, item, validated):
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
received_type = r[val[1]] received_type = r[val[1]]
received_types[x] = received_type received_types[x] = received_type
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
error = { error = {
"type": "return_type_mismatch", "type": "return_type_mismatch",
"message": "Return type mismatch between linked nodes", "message": "Return type mismatch between linked nodes",
@@ -641,22 +641,22 @@ def validate_inputs(prompt, item, validated):
val = val["__value__"] val = val["__value__"]
inputs[x] = val inputs[x] = val
if type_input == "INT": if input_type == "INT":
val = int(val) val = int(val)
inputs[x] = val inputs[x] = val
if type_input == "FLOAT": if input_type == "FLOAT":
val = float(val) val = float(val)
inputs[x] = val inputs[x] = val
if type_input == "STRING": if input_type == "STRING":
val = str(val) val = str(val)
inputs[x] = val inputs[x] = val
if type_input == "BOOLEAN": if input_type == "BOOLEAN":
val = bool(val) val = bool(val)
inputs[x] = val inputs[x] = val
except Exception as ex: except Exception as ex:
error = { error = {
"type": "invalid_input_type", "type": "invalid_input_type",
"message": f"Failed to convert an input value to a {type_input} value", "message": f"Failed to convert an input value to a {input_type} value",
"details": f"{x}, {val}, {ex}", "details": f"{x}, {val}, {ex}",
"extra_info": { "extra_info": {
"input_name": x, "input_name": x,
@@ -696,18 +696,19 @@ def validate_inputs(prompt, item, validated):
errors.append(error) errors.append(error)
continue continue
if isinstance(type_input, list): if isinstance(input_type, list):
if val not in type_input: combo_options = input_type
if val not in combo_options:
input_config = info input_config = info
list_info = "" list_info = ""
# Don't send back gigantic lists like if they're lots of # Don't send back gigantic lists like if they're lots of
# scanned model filepaths # scanned model filepaths
if len(type_input) > 20: if len(combo_options) > 20:
list_info = f"(list of length {len(type_input)})" list_info = f"(list of length {len(combo_options)})"
input_config = None input_config = None
else: else:
list_info = str(type_input) list_info = str(combo_options)
error = { error = {
"type": "value_not_in_list", "type": "value_not_in_list",

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.3.26" version = "0.3.27"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.9" requires-python = ">=3.9"