Compare commits

...

5 Commits

Author SHA1 Message Date
comfyanonymous
c6812947e9 Fix potential memory leak. 2024-08-26 02:07:32 -04:00
comfyanonymous
9230f65823 Fix some controlnets OOMing when loading. 2024-08-25 05:54:29 -04:00
guill
6ab1e6fd4a [Bug #4529] Fix graph partial validation failure (#4588)
Currently, if a graph partially fails validation (i.e. some outputs are
valid while others have links from missing nodes), the execution loop
could get an exception resulting in server lockup.

This isn't actually possible to reproduce via the default UI, but is a
potential issue for people using the API to construct invalid graphs.
2024-08-24 15:34:58 -04:00
comfyanonymous
07dcbc3a3e Clarify how to use high quality previews. 2024-08-24 02:31:03 -04:00
comfyanonymous
8ae23d8e80 Fix onnx export. 2024-08-23 17:52:47 -04:00
6 changed files with 60 additions and 18 deletions

View File

@@ -230,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
Use ```--preview-method auto``` to enable previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
## How to use TLS/SSL?
Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`

View File

@@ -391,7 +391,8 @@ def controlnet_config(sd):
else:
operations = comfy.ops.disable_weight_init
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
offload_device = comfy.model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
def load_controlnet_mmdit(sd):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = comfy.latent_formats.SD3()
@@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = comfy.latent_formats.SDXL()
@@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
if manual_cast_dtype is not None:
controlnet_config["operations"] = comfy.ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = comfy.model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)

View File

@@ -319,12 +319,21 @@ class ModelPatcher:
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
load_completely = []
loading = []
for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
module_mem = x[0]
lowvram_weight = False
if not full_load and hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
@@ -356,9 +365,8 @@ class ModelPatcher:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
mem_used = comfy.model_management.module_size(m)
mem_counter += mem_used
load_completely.append((mem_used, n, m))
mem_counter += module_mem
load_completely.append((module_mem, n, m))
load_completely.sort(reverse=True)
for x in load_completely:

View File

@@ -20,9 +20,13 @@ import torch
import comfy.model_management
from comfy.cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
return weight
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r

View File

@@ -56,6 +56,8 @@ class CacheKeySetID(CacheKeySet):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -74,6 +76,8 @@ class CacheKeySetInputSignature(CacheKeySet):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -87,6 +91,9 @@ class CacheKeySetInputSignature(CacheKeySet):
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@@ -112,6 +119,8 @@ class CacheKeySetInputSignature(CacheKeySet):
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:

View File

@@ -357,6 +357,25 @@ class TestExecution:
assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
# We have multiple outputs. The first is invalid, but the second is valid
g.node("SaveImage", images=mix1.out(0))
g.node("SaveImage", images=mix2.out(0))
g.remove_node("removeme")
client.run(g)
# Add back in the missing node to make sure the error doesn't break the server
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
client.run(g)
def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
g = builder
# Creating the nodes in this specific order previously caused a bug
@@ -450,8 +469,8 @@ class TestExecution:
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
output1 = g.node("PreviewImage", images=input1.out(0))
output2 = g.node("PreviewImage", images=input1.out(0))
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input1.out(0))
result = client.run(g)
images1 = result.get_images(output1)