Compare commits
10 Commits
required_f
...
annoate_ge
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
522d923948 | ||
|
|
c05c9b552b | ||
|
|
27598702e9 | ||
|
|
8edc1f44c1 | ||
|
|
eade1551bb | ||
|
|
581a9991ff | ||
|
|
e471c726e5 | ||
|
|
75c1c757d9 | ||
|
|
ce9b084279 | ||
|
|
2206246055 |
@@ -69,6 +69,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
|||||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||||
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
|
||||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||||
|
- 3D Models
|
||||||
|
- [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
|
||||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||||
- Asynchronous Queue system
|
- Asynchronous Queue system
|
||||||
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
|
||||||
|
|||||||
@@ -22,28 +22,46 @@ import app.logger
|
|||||||
# The path to the requirements.txt file
|
# The path to the requirements.txt file
|
||||||
req_path = Path(__file__).parents[1] / "requirements.txt"
|
req_path = Path(__file__).parents[1] / "requirements.txt"
|
||||||
|
|
||||||
|
|
||||||
def frontend_install_warning_message():
|
def frontend_install_warning_message():
|
||||||
"""The warning message to display when the frontend version is not up to date."""
|
"""The warning message to display when the frontend version is not up to date."""
|
||||||
|
|
||||||
extra = ""
|
extra = ""
|
||||||
if sys.flags.no_user_site:
|
if sys.flags.no_user_site:
|
||||||
extra = "-s "
|
extra = "-s "
|
||||||
return f"Please install the updated requirements.txt file by running:\n{sys.executable} {extra}-m pip install -r {req_path}\n\nThis error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.\n\nIf you are on the portable package you can run: update\\update_comfyui.bat to solve this problem"
|
return f"""
|
||||||
|
Please install the updated requirements.txt file by running:
|
||||||
|
{sys.executable} {extra}-m pip install -r {req_path}
|
||||||
|
|
||||||
|
This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
|
||||||
|
|
||||||
def parse_version(version: str) -> tuple[int, int, int]:
|
If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
|
||||||
return tuple(map(int, version.split(".")))
|
""".strip()
|
||||||
|
|
||||||
|
|
||||||
def check_frontend_version():
|
def check_frontend_version():
|
||||||
"""Check if the frontend version is up to date."""
|
"""Check if the frontend version is up to date."""
|
||||||
|
|
||||||
|
def parse_version(version: str) -> tuple[int, int, int]:
|
||||||
|
return tuple(map(int, version.split(".")))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
frontend_version_str = version("comfyui-frontend-package")
|
frontend_version_str = version("comfyui-frontend-package")
|
||||||
frontend_version = parse_version(frontend_version_str)
|
frontend_version = parse_version(frontend_version_str)
|
||||||
with open(req_path, "r", encoding="utf-8") as f:
|
with open(req_path, "r", encoding="utf-8") as f:
|
||||||
required_frontend = parse_version(f.readline().split("=")[-1])
|
required_frontend = parse_version(f.readline().split("=")[-1])
|
||||||
if frontend_version < required_frontend:
|
if frontend_version < required_frontend:
|
||||||
app.logger.log_startup_warning("________________________________________________________________________\nWARNING WARNING WARNING WARNING WARNING\n\nInstalled frontend version {} is lower than the recommended version {}.\n\n{}\n________________________________________________________________________".format('.'.join(map(str, frontend_version)), '.'.join(map(str, required_frontend)), frontend_install_warning_message()))
|
app.logger.log_startup_warning(
|
||||||
|
f"""
|
||||||
|
________________________________________________________________________
|
||||||
|
WARNING WARNING WARNING WARNING WARNING
|
||||||
|
|
||||||
|
Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
|
||||||
|
|
||||||
|
{frontend_install_warning_message()}
|
||||||
|
________________________________________________________________________
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -73,11 +91,6 @@ class FrontEndProvider:
|
|||||||
owner: str
|
owner: str
|
||||||
repo: str
|
repo: str
|
||||||
|
|
||||||
@property
|
|
||||||
def is_official(self) -> bool:
|
|
||||||
"""Check if the provider is the default official one."""
|
|
||||||
return self.owner == "Comfy-Org" and self.repo == "ComfyUI_frontend"
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def folder_name(self) -> str:
|
def folder_name(self) -> str:
|
||||||
return f"{self.owner}_{self.repo}"
|
return f"{self.owner}_{self.repo}"
|
||||||
@@ -148,26 +161,27 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
|
|||||||
zip_ref.extractall(destination_path)
|
zip_ref.extractall(destination_path)
|
||||||
|
|
||||||
|
|
||||||
class FrontendInit(TypedDict):
|
|
||||||
web_root: str
|
|
||||||
""" The path to the initialized frontend. """
|
|
||||||
version: tuple[int, int, int] | None
|
|
||||||
""" The version of the initialized frontend. None for unrecognized version."""
|
|
||||||
|
|
||||||
class FrontendManager:
|
class FrontendManager:
|
||||||
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_default_frontend(cls) -> FrontendInit:
|
def default_frontend_path(cls) -> str:
|
||||||
check_frontend_version()
|
|
||||||
try:
|
try:
|
||||||
import comfyui_frontend_package
|
import comfyui_frontend_package
|
||||||
return FrontendInit(
|
|
||||||
web_root=str(importlib.resources.files(comfyui_frontend_package) / "static"),
|
return str(importlib.resources.files(comfyui_frontend_package) / "static")
|
||||||
version=parse_version(version("comfyui-frontend-package")),
|
|
||||||
)
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logging.error(f"\n\n********** ERROR ***********\n\ncomfyui-frontend-package is not installed. {frontend_install_warning_message()}\n********** ERROR **********\n")
|
logging.error(
|
||||||
|
f"""
|
||||||
|
********** ERROR ***********
|
||||||
|
|
||||||
|
comfyui-frontend-package is not installed.
|
||||||
|
|
||||||
|
{frontend_install_warning_message()}
|
||||||
|
|
||||||
|
********** ERROR ***********
|
||||||
|
""".strip()
|
||||||
|
)
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -190,7 +204,9 @@ class FrontendManager:
|
|||||||
return match_result.group(1), match_result.group(2), match_result.group(3)
|
return match_result.group(1), match_result.group(2), match_result.group(3)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> FrontendInit:
|
def init_frontend_unsafe(
|
||||||
|
cls, version_string: str, provider: Optional[FrontEndProvider] = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Initializes the frontend for the specified version.
|
Initializes the frontend for the specified version.
|
||||||
|
|
||||||
@@ -206,17 +222,26 @@ class FrontendManager:
|
|||||||
main error source might be request timeout or invalid URL.
|
main error source might be request timeout or invalid URL.
|
||||||
"""
|
"""
|
||||||
if version_string == DEFAULT_VERSION_STRING:
|
if version_string == DEFAULT_VERSION_STRING:
|
||||||
return cls.init_default_frontend()
|
check_frontend_version()
|
||||||
|
return cls.default_frontend_path()
|
||||||
|
|
||||||
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
repo_owner, repo_name, version = cls.parse_version_string(version_string)
|
||||||
|
|
||||||
if version.startswith("v"):
|
if version.startswith("v"):
|
||||||
expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
|
expected_path = str(
|
||||||
|
Path(cls.CUSTOM_FRONTENDS_ROOT)
|
||||||
|
/ f"{repo_owner}_{repo_name}"
|
||||||
|
/ version.lstrip("v")
|
||||||
|
)
|
||||||
if os.path.exists(expected_path):
|
if os.path.exists(expected_path):
|
||||||
logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
|
logging.info(
|
||||||
|
f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
|
||||||
|
)
|
||||||
return expected_path
|
return expected_path
|
||||||
|
|
||||||
logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
|
logging.info(
|
||||||
|
f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
|
||||||
|
)
|
||||||
|
|
||||||
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
provider = provider or FrontEndProvider(repo_owner, repo_name)
|
||||||
release = provider.get_release(version)
|
release = provider.get_release(version)
|
||||||
@@ -241,13 +266,10 @@ class FrontendManager:
|
|||||||
if not os.listdir(web_root):
|
if not os.listdir(web_root):
|
||||||
os.rmdir(web_root)
|
os.rmdir(web_root)
|
||||||
|
|
||||||
return FrontendInit(
|
return web_root
|
||||||
web_root=web_root,
|
|
||||||
version=parse_version(semantic_version) if provider.is_official else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def init_frontend(cls, version_string: str) -> FrontendInit:
|
def init_frontend(cls, version_string: str) -> str:
|
||||||
"""
|
"""
|
||||||
Initializes the frontend with the specified version string.
|
Initializes the frontend with the specified version string.
|
||||||
|
|
||||||
@@ -262,4 +284,5 @@ class FrontendManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Failed to initialize frontend: %s", e)
|
logging.error("Failed to initialize frontend: %s", e)
|
||||||
logging.info("Falling back to the default frontend.")
|
logging.info("Falling back to the default frontend.")
|
||||||
return cls.init_default_frontend()
|
check_frontend_version()
|
||||||
|
return cls.default_frontend_path()
|
||||||
|
|||||||
@@ -220,13 +220,6 @@ class ComfyNodeABC(ABC):
|
|||||||
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
"""Flags a node as experimental, informing users that it may change or not work as expected."""
|
||||||
DEPRECATED: bool
|
DEPRECATED: bool
|
||||||
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
|
||||||
REQUIRED_FRONTEND_VERSION: str | None
|
|
||||||
"""The minimum version of the ComfyUI frontend required to load this node.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
REQUIRED_FRONTEND_VERSION = "1.9.7"
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|||||||
@@ -471,7 +471,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
|||||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
|
||||||
if skip_reshape:
|
if skip_reshape:
|
||||||
b, _, _, dim_head = q.shape
|
b, _, _, dim_head = q.shape
|
||||||
tensor_layout="HND"
|
tensor_layout = "HND"
|
||||||
else:
|
else:
|
||||||
b, _, dim_head = q.shape
|
b, _, dim_head = q.shape
|
||||||
dim_head //= heads
|
dim_head //= heads
|
||||||
@@ -479,7 +479,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
lambda t: t.view(b, -1, heads, dim_head),
|
lambda t: t.view(b, -1, heads, dim_head),
|
||||||
(q, k, v),
|
(q, k, v),
|
||||||
)
|
)
|
||||||
tensor_layout="NHD"
|
tensor_layout = "NHD"
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# add a batch dimension if there isn't already one
|
# add a batch dimension if there isn't already one
|
||||||
@@ -489,7 +489,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
|||||||
if mask.ndim == 3:
|
if mask.ndim == 3:
|
||||||
mask = mask.unsqueeze(1)
|
mask = mask.unsqueeze(1)
|
||||||
|
|
||||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
try:
|
||||||
|
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||||
|
if tensor_layout == "NHD":
|
||||||
|
q, k, v = map(
|
||||||
|
lambda t: t.transpose(1, 2),
|
||||||
|
(q, k, v),
|
||||||
|
)
|
||||||
|
return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
|
||||||
|
|
||||||
if tensor_layout == "HND":
|
if tensor_layout == "HND":
|
||||||
if not skip_output_reshape:
|
if not skip_output_reshape:
|
||||||
out = (
|
out = (
|
||||||
|
|||||||
@@ -46,6 +46,32 @@ cpu_state = CPUState.GPU
|
|||||||
|
|
||||||
total_vram = 0
|
total_vram = 0
|
||||||
|
|
||||||
|
def get_supported_float8_types():
|
||||||
|
float8_types = []
|
||||||
|
try:
|
||||||
|
float8_types.append(torch.float8_e4m3fn)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
float8_types.append(torch.float8_e4m3fnuz)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
float8_types.append(torch.float8_e5m2)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
float8_types.append(torch.float8_e5m2fnuz)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
float8_types.append(torch.float8_e8m0fnu)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return float8_types
|
||||||
|
|
||||||
|
FLOAT8_TYPES = get_supported_float8_types()
|
||||||
|
|
||||||
xpu_available = False
|
xpu_available = False
|
||||||
torch_version = ""
|
torch_version = ""
|
||||||
try:
|
try:
|
||||||
@@ -701,11 +727,8 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
|
|||||||
return torch.float8_e5m2
|
return torch.float8_e5m2
|
||||||
|
|
||||||
fp8_dtype = None
|
fp8_dtype = None
|
||||||
try:
|
if weight_dtype in FLOAT8_TYPES:
|
||||||
if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
|
fp8_dtype = weight_dtype
|
||||||
fp8_dtype = weight_dtype
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if fp8_dtype is not None:
|
if fp8_dtype is not None:
|
||||||
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import nodes
|
from __future__ import annotations
|
||||||
|
from typing import Type, Literal
|
||||||
|
|
||||||
|
import nodes
|
||||||
from comfy_execution.graph_utils import is_link
|
from comfy_execution.graph_utils import is_link
|
||||||
|
from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
|
||||||
|
|
||||||
class DependencyCycleError(Exception):
|
class DependencyCycleError(Exception):
|
||||||
pass
|
pass
|
||||||
@@ -54,7 +57,22 @@ class DynamicPrompt:
|
|||||||
def get_original_prompt(self):
|
def get_original_prompt(self):
|
||||||
return self.original_prompt
|
return self.original_prompt
|
||||||
|
|
||||||
def get_input_info(class_def, input_name, valid_inputs=None):
|
def get_input_info(
|
||||||
|
class_def: Type[ComfyNodeABC],
|
||||||
|
input_name: str,
|
||||||
|
valid_inputs: InputTypeDict | None = None
|
||||||
|
) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
|
||||||
|
"""Get the input type, category, and extra info for a given input name.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
class_def: The class definition of the node.
|
||||||
|
input_name: The name of the input to get info for.
|
||||||
|
valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
|
||||||
|
"""
|
||||||
|
|
||||||
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
|
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
|
||||||
input_info = None
|
input_info = None
|
||||||
input_category = None
|
input_category = None
|
||||||
@@ -126,7 +144,7 @@ class TopologicalSort:
|
|||||||
from_node_id, from_socket = value
|
from_node_id, from_socket = value
|
||||||
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
|
||||||
continue
|
continue
|
||||||
input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
|
_, _, input_info = self.get_input_info(unique_id, input_name)
|
||||||
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
|
||||||
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
|
||||||
node_ids.append(from_node_id)
|
node_ids.append(from_node_id)
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ class Load3D():
|
|||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE")
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path")
|
RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart")
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
@@ -32,12 +32,16 @@ class Load3D():
|
|||||||
def process(self, model_file, image, **kwargs):
|
def process(self, model_file, image, **kwargs):
|
||||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||||
|
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||||
|
lineart_path = folder_paths.get_annotated_filepath(image['lineart'])
|
||||||
|
|
||||||
load_image_node = nodes.LoadImage()
|
load_image_node = nodes.LoadImage()
|
||||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||||
|
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||||
|
lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file,
|
return output_image, output_mask, model_file, normal_image, lineart_image
|
||||||
|
|
||||||
class Load3DAnimation():
|
class Load3DAnimation():
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -55,8 +59,8 @@ class Load3DAnimation():
|
|||||||
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
"height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("IMAGE", "MASK", "STRING")
|
RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE")
|
||||||
RETURN_NAMES = ("image", "mask", "mesh_path")
|
RETURN_NAMES = ("image", "mask", "mesh_path", "normal")
|
||||||
|
|
||||||
FUNCTION = "process"
|
FUNCTION = "process"
|
||||||
EXPERIMENTAL = True
|
EXPERIMENTAL = True
|
||||||
@@ -66,12 +70,14 @@ class Load3DAnimation():
|
|||||||
def process(self, model_file, image, **kwargs):
|
def process(self, model_file, image, **kwargs):
|
||||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||||
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
mask_path = folder_paths.get_annotated_filepath(image['mask'])
|
||||||
|
normal_path = folder_paths.get_annotated_filepath(image['normal'])
|
||||||
|
|
||||||
load_image_node = nodes.LoadImage()
|
load_image_node = nodes.LoadImage()
|
||||||
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
output_image, ignore_mask = load_image_node.load_image(image=image_path)
|
||||||
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
ignore_image, output_mask = load_image_node.load_image(image=mask_path)
|
||||||
|
normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path)
|
||||||
|
|
||||||
return output_image, output_mask, model_file,
|
return output_image, output_mask, model_file, normal_image
|
||||||
|
|
||||||
class Preview3D():
|
class Preview3D():
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -244,6 +244,30 @@ class ModelMergeCosmos14B(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
|||||||
|
|
||||||
return {"required": arg_dict}
|
return {"required": arg_dict}
|
||||||
|
|
||||||
|
class ModelMergeWAN2_1(comfy_extras.nodes_model_merging.ModelMergeBlocks):
|
||||||
|
CATEGORY = "advanced/model_merging/model_specific"
|
||||||
|
DESCRIPTION = "1.3B model has 30 blocks, 14B model has 40 blocks. Image to video model has the extra img_emb."
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def INPUT_TYPES(s):
|
||||||
|
arg_dict = { "model1": ("MODEL",),
|
||||||
|
"model2": ("MODEL",)}
|
||||||
|
|
||||||
|
argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
|
||||||
|
|
||||||
|
arg_dict["patch_embedding."] = argument
|
||||||
|
arg_dict["time_embedding."] = argument
|
||||||
|
arg_dict["time_projection."] = argument
|
||||||
|
arg_dict["text_embedding."] = argument
|
||||||
|
arg_dict["img_emb."] = argument
|
||||||
|
|
||||||
|
for i in range(40):
|
||||||
|
arg_dict["blocks.{}.".format(i)] = argument
|
||||||
|
|
||||||
|
arg_dict["head."] = argument
|
||||||
|
|
||||||
|
return {"required": arg_dict}
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
"ModelMergeSD1": ModelMergeSD1,
|
"ModelMergeSD1": ModelMergeSD1,
|
||||||
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
"ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks
|
||||||
@@ -256,4 +280,5 @@ NODE_CLASS_MAPPINGS = {
|
|||||||
"ModelMergeLTXV": ModelMergeLTXV,
|
"ModelMergeLTXV": ModelMergeLTXV,
|
||||||
"ModelMergeCosmos7B": ModelMergeCosmos7B,
|
"ModelMergeCosmos7B": ModelMergeCosmos7B,
|
||||||
"ModelMergeCosmos14B": ModelMergeCosmos14B,
|
"ModelMergeCosmos14B": ModelMergeCosmos14B,
|
||||||
|
"ModelMergeWAN2_1": ModelMergeWAN2_1,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.3.26"
|
__version__ = "0.3.27"
|
||||||
|
|||||||
31
execution.py
31
execution.py
@@ -93,7 +93,7 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e
|
|||||||
missing_keys = {}
|
missing_keys = {}
|
||||||
for x in inputs:
|
for x in inputs:
|
||||||
input_data = inputs[x]
|
input_data = inputs[x]
|
||||||
input_type, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
_, input_category, input_info = get_input_info(class_def, x, valid_inputs)
|
||||||
def mark_missing():
|
def mark_missing():
|
||||||
missing_keys[x] = True
|
missing_keys[x] = True
|
||||||
input_data_all[x] = (None,)
|
input_data_all[x] = (None,)
|
||||||
@@ -555,7 +555,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
received_types = {}
|
received_types = {}
|
||||||
|
|
||||||
for x in valid_inputs:
|
for x in valid_inputs:
|
||||||
type_input, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
|
input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs)
|
||||||
assert extra_info is not None
|
assert extra_info is not None
|
||||||
if x not in inputs:
|
if x not in inputs:
|
||||||
if input_category == "required":
|
if input_category == "required":
|
||||||
@@ -571,7 +571,7 @@ def validate_inputs(prompt, item, validated):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
val = inputs[x]
|
val = inputs[x]
|
||||||
info = (type_input, extra_info)
|
info = (input_type, extra_info)
|
||||||
if isinstance(val, list):
|
if isinstance(val, list):
|
||||||
if len(val) != 2:
|
if len(val) != 2:
|
||||||
error = {
|
error = {
|
||||||
@@ -592,8 +592,8 @@ def validate_inputs(prompt, item, validated):
|
|||||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||||
received_type = r[val[1]]
|
received_type = r[val[1]]
|
||||||
received_types[x] = received_type
|
received_types[x] = received_type
|
||||||
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
|
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, input_type):
|
||||||
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
|
details = f"{x}, received_type({received_type}) mismatch input_type({input_type})"
|
||||||
error = {
|
error = {
|
||||||
"type": "return_type_mismatch",
|
"type": "return_type_mismatch",
|
||||||
"message": "Return type mismatch between linked nodes",
|
"message": "Return type mismatch between linked nodes",
|
||||||
@@ -641,22 +641,22 @@ def validate_inputs(prompt, item, validated):
|
|||||||
val = val["__value__"]
|
val = val["__value__"]
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
|
|
||||||
if type_input == "INT":
|
if input_type == "INT":
|
||||||
val = int(val)
|
val = int(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
if type_input == "FLOAT":
|
if input_type == "FLOAT":
|
||||||
val = float(val)
|
val = float(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
if type_input == "STRING":
|
if input_type == "STRING":
|
||||||
val = str(val)
|
val = str(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
if type_input == "BOOLEAN":
|
if input_type == "BOOLEAN":
|
||||||
val = bool(val)
|
val = bool(val)
|
||||||
inputs[x] = val
|
inputs[x] = val
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
error = {
|
error = {
|
||||||
"type": "invalid_input_type",
|
"type": "invalid_input_type",
|
||||||
"message": f"Failed to convert an input value to a {type_input} value",
|
"message": f"Failed to convert an input value to a {input_type} value",
|
||||||
"details": f"{x}, {val}, {ex}",
|
"details": f"{x}, {val}, {ex}",
|
||||||
"extra_info": {
|
"extra_info": {
|
||||||
"input_name": x,
|
"input_name": x,
|
||||||
@@ -696,18 +696,19 @@ def validate_inputs(prompt, item, validated):
|
|||||||
errors.append(error)
|
errors.append(error)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if isinstance(type_input, list):
|
if isinstance(input_type, list):
|
||||||
if val not in type_input:
|
combo_options = input_type
|
||||||
|
if val not in combo_options:
|
||||||
input_config = info
|
input_config = info
|
||||||
list_info = ""
|
list_info = ""
|
||||||
|
|
||||||
# Don't send back gigantic lists like if they're lots of
|
# Don't send back gigantic lists like if they're lots of
|
||||||
# scanned model filepaths
|
# scanned model filepaths
|
||||||
if len(type_input) > 20:
|
if len(combo_options) > 20:
|
||||||
list_info = f"(list of length {len(type_input)})"
|
list_info = f"(list of length {len(combo_options)})"
|
||||||
input_config = None
|
input_config = None
|
||||||
else:
|
else:
|
||||||
list_info = str(type_input)
|
list_info = str(combo_options)
|
||||||
|
|
||||||
error = {
|
error = {
|
||||||
"type": "value_not_in_list",
|
"type": "value_not_in_list",
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.3.26"
|
version = "0.3.27"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.9"
|
requires-python = ">=3.9"
|
||||||
|
|||||||
47
server.py
47
server.py
@@ -1,4 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -25,12 +24,11 @@ import logging
|
|||||||
|
|
||||||
import mimetypes
|
import mimetypes
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
from comfy.comfy_types.node_typing import ComfyNodeABC
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import node_helpers
|
import node_helpers
|
||||||
from comfyui_version import __version__
|
from comfyui_version import __version__
|
||||||
from app.frontend_management import FrontendInit, FrontendManager, parse_version
|
from app.frontend_management import FrontendManager
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
from app.custom_node_manager import CustomNodeManager
|
from app.custom_node_manager import CustomNodeManager
|
||||||
@@ -148,11 +146,6 @@ def create_origin_only_middleware():
|
|||||||
return origin_only_middleware
|
return origin_only_middleware
|
||||||
|
|
||||||
class PromptServer():
|
class PromptServer():
|
||||||
web_root: str
|
|
||||||
"""The path to the initialized frontend assets."""
|
|
||||||
frontend_version: tuple[int, int, int] | None = None
|
|
||||||
"""The version of the initialized frontend. None for unrecognized version."""
|
|
||||||
|
|
||||||
def __init__(self, loop):
|
def __init__(self, loop):
|
||||||
PromptServer.instance = self
|
PromptServer.instance = self
|
||||||
|
|
||||||
@@ -183,19 +176,12 @@ class PromptServer():
|
|||||||
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
max_upload_size = round(args.max_upload_size * 1024 * 1024)
|
||||||
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
self.app = web.Application(client_max_size=max_upload_size, middlewares=middlewares)
|
||||||
self.sockets = dict()
|
self.sockets = dict()
|
||||||
|
self.web_root = (
|
||||||
if args.front_end_root:
|
FrontendManager.init_frontend(args.front_end_version)
|
||||||
frontend_init = FrontendInit(
|
if args.front_end_root is None
|
||||||
web_root=args.front_end_root,
|
else args.front_end_root
|
||||||
version=None,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
frontend_init = FrontendManager.init_frontend(args.front_end_version)
|
|
||||||
|
|
||||||
self.frontend_version = frontend_init["version"]
|
|
||||||
self.web_root = frontend_init["web_root"]
|
|
||||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||||
|
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
self.routes = routes
|
self.routes = routes
|
||||||
self.last_node_id = None
|
self.last_node_id = None
|
||||||
@@ -601,9 +587,6 @@ class PromptServer():
|
|||||||
with folder_paths.cache_helper:
|
with folder_paths.cache_helper:
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
if not self.node_is_supported(x):
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
out[x] = node_info(x)
|
out[x] = node_info(x)
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -615,11 +598,7 @@ class PromptServer():
|
|||||||
async def get_object_info_node(request):
|
async def get_object_info_node(request):
|
||||||
node_class = request.match_info.get("node_class", None)
|
node_class = request.match_info.get("node_class", None)
|
||||||
out = {}
|
out = {}
|
||||||
if (
|
if (node_class is not None) and (node_class in nodes.NODE_CLASS_MAPPINGS):
|
||||||
node_class is not None
|
|
||||||
and node_class in nodes.NODE_CLASS_MAPPINGS
|
|
||||||
and self.node_is_supported(node_class)
|
|
||||||
):
|
|
||||||
out[node_class] = node_info(node_class)
|
out[node_class] = node_info(node_class)
|
||||||
return web.json_response(out)
|
return web.json_response(out)
|
||||||
|
|
||||||
@@ -884,15 +863,3 @@ class PromptServer():
|
|||||||
logging.warning(traceback.format_exc())
|
logging.warning(traceback.format_exc())
|
||||||
|
|
||||||
return json_data
|
return json_data
|
||||||
|
|
||||||
def node_is_supported(self, node_class: ComfyNodeABC) -> bool:
|
|
||||||
"""Check if the node is supported by the frontend."""
|
|
||||||
# For unrecognized frontend version, we assume the node is supported.
|
|
||||||
if self.frontend_version is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if the node is supported by the frontend.
|
|
||||||
if node_class.REQUIRED_FRONTEND_VERSION is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return parse_version(node_class.REQUIRED_FRONTEND_VERSION) <= self.frontend_version
|
|
||||||
|
|||||||
@@ -69,10 +69,8 @@ def test_get_release_invalid_version(mock_provider):
|
|||||||
|
|
||||||
def test_init_frontend_default():
|
def test_init_frontend_default():
|
||||||
version_string = DEFAULT_VERSION_STRING
|
version_string = DEFAULT_VERSION_STRING
|
||||||
frontend_init = FrontendManager.init_frontend(version_string)
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
assert isinstance(frontend_init, dict)
|
assert frontend_path == FrontendManager.default_frontend_path()
|
||||||
assert "web_root" in frontend_init
|
|
||||||
assert "version" in frontend_init
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_frontend_invalid_version():
|
def test_init_frontend_invalid_version():
|
||||||
@@ -140,47 +138,37 @@ def test_parse_version_string_invalid():
|
|||||||
def test_init_frontend_default_with_mocks():
|
def test_init_frontend_default_with_mocks():
|
||||||
# Arrange
|
# Arrange
|
||||||
version_string = DEFAULT_VERSION_STRING
|
version_string = DEFAULT_VERSION_STRING
|
||||||
mock_path = "/mocked/path"
|
|
||||||
mock_version = (1, 0, 0)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with (
|
with (
|
||||||
patch("app.frontend_management.check_frontend_version") as mock_check,
|
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||||
patch.object(
|
patch.object(
|
||||||
FrontendManager,
|
FrontendManager, "default_frontend_path", return_value="/mocked/path"
|
||||||
"init_default_frontend",
|
|
||||||
return_value={"web_root": mock_path, "version": mock_version},
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
frontend_init = FrontendManager.init_frontend(version_string)
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert frontend_init["web_root"] == mock_path
|
assert frontend_path == "/mocked/path"
|
||||||
assert frontend_init["version"] == mock_version
|
mock_check.assert_called_once()
|
||||||
mock_check.assert_not_called() # check_frontend_version is now called inside init_default_frontend
|
|
||||||
|
|
||||||
|
|
||||||
def test_init_frontend_fallback_on_error():
|
def test_init_frontend_fallback_on_error():
|
||||||
# Arrange
|
# Arrange
|
||||||
version_string = "test-owner/test-repo@1.0.0"
|
version_string = "test-owner/test-repo@1.0.0"
|
||||||
mock_path = "/default/path"
|
|
||||||
mock_version = (1, 0, 0)
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
with (
|
with (
|
||||||
patch.object(
|
patch.object(
|
||||||
FrontendManager,
|
FrontendManager, "init_frontend_unsafe", side_effect=Exception("Test error")
|
||||||
"init_frontend_unsafe",
|
|
||||||
side_effect=Exception("Test error")
|
|
||||||
),
|
),
|
||||||
|
patch("app.frontend_management.check_frontend_version") as mock_check,
|
||||||
patch.object(
|
patch.object(
|
||||||
FrontendManager,
|
FrontendManager, "default_frontend_path", return_value="/default/path"
|
||||||
"init_default_frontend",
|
|
||||||
return_value={"web_root": mock_path, "version": mock_version},
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
frontend_init = FrontendManager.init_frontend(version_string)
|
frontend_path = FrontendManager.init_frontend(version_string)
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert frontend_init["web_root"] == mock_path
|
assert frontend_path == "/default/path"
|
||||||
assert frontend_init["version"] == mock_version
|
mock_check.assert_called_once()
|
||||||
|
|||||||
Reference in New Issue
Block a user