Compare commits
23 Commits
| SHA1 |
|---|
| c6812947e9 |
| 9230f65823 |
| 6ab1e6fd4a |
| 07dcbc3a3e |
| 8ae23d8e80 |
| 7df42b9a23 |
| 5d8bbb7281 |
| 2c1d2375d6 |
| 64ccb3c7e3 |
| 9465b23432 |
| bb4416dd5b |
| c0b0da264b |
| c26ca27207 |
| 7c6bb84016 |
| c54d3ed5e6 |
| c7ee4b37a1 |
| 7b70b266d8 |
| 8f60d093ba |
| dafbe321d2 |
| 5f84ea63e8 |
| 843a7ff70c |
| a60620dcea |
| 015f73dc49 |
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat (new file, 2 lines)
@@ -0,0 +1,2 @@
+.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
+pause
@@ -67,6 +67,7 @@ jobs:
 mkdir update
 cp -r ComfyUI/.ci/update_windows/* ./update/
 cp -r ComfyUI/.ci/windows_base_files/* ./
+cp -r ComfyUI/.ci/windows_nightly_base_files/* ./
 
 echo "call update_comfyui.bat nopause
 ..\python_embeded\python.exe -s -m pip install --upgrade --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2
@@ -230,7 +230,7 @@ To use a textual inversion concepts/embeddings in a text prompt put them in the
 
 Use ```--preview-method auto``` to enable previews.
 
-The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth) (for SD1.x and SD2.x) and [taesdxl_decoder.pth](https://github.com/madebyollin/taesd/raw/main/taesdxl_decoder.pth) (for SDXL) models and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI to enable high-quality previews.
+The default installation includes a fast latent preview method that's low-resolution. To enable higher-quality previews with [TAESD](https://github.com/madebyollin/taesd), download the [taesd_decoder.pth, taesdxl_decoder.pth, taesd3_decoder.pth and taef1_decoder.pth](https://github.com/madebyollin/taesd/) and place them in the `models/vae_approx` folder. Once they're installed, restart ComfyUI and launch it with `--preview-method taesd` to enable high-quality previews.
 
 ## How to use TLS/SSL?
 Generate a self-signed certificate (not appropriate for shared/production use) and key by running the command: `openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -sha256 -days 3650 -nodes -subj "/C=XX/ST=StateName/L=CityName/O=CompanyName/OU=CompanySectionName/CN=CommonNameOrHostname"`
@@ -391,7 +391,8 @@ def controlnet_config(sd):
     else:
         operations = comfy.ops.disable_weight_init
 
-    return model_config, operations, load_device, unet_dtype, manual_cast_dtype
+    offload_device = comfy.model_management.unet_offload_device()
+    return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
 
 def controlnet_load_state_dict(control_model, sd):
     missing, unexpected = control_model.load_state_dict(sd, strict=False)
@@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
 
 def load_controlnet_mmdit(sd):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
     num_blocks = comfy.model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
     for k in sd:
         new_sd[k] = sd[k]
 
-    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = comfy.cldm.mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)
 
     latent_format = comfy.latent_formats.SD3()
@@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
 
 
 def load_controlnet_hunyuandit(controlnet_data):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
 
-    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
+    control_model = comfy.ldm.hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
     control_model = controlnet_load_state_dict(control_model, controlnet_data)
 
     latent_format = comfy.latent_formats.SDXL()
@@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
     return control
 
 def load_controlnet_flux_xlabs(sd):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@@ -536,6 +537,7 @@ def load_controlnet(ckpt_path, model=None):
     if manual_cast_dtype is not None:
         controlnet_config["operations"] = comfy.ops.manual_cast
     controlnet_config["dtype"] = unet_dtype
+    controlnet_config["device"] = comfy.model_management.unet_offload_device()
     controlnet_config.pop("out_channels")
     controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
     control_model = comfy.cldm.cldm.ControlNet(**controlnet_config)
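The hunks above construct the ControlNet modules on the UNet offload device so that loading the state dict does not allocate GPU memory. Below is a minimal, self-contained sketch of that pattern using a plain torch module; the device choices and the toy layer are illustrative, not ComfyUI's actual loader.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the offload/load devices used by the loader.
offload_device = torch.device("cpu")
load_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the module directly on the offload device so loading weights never touches VRAM...
control_model = nn.Linear(64, 64, device=offload_device)
control_model.load_state_dict({"weight": torch.zeros(64, 64), "bias": torch.zeros(64)})

# ...and move it to the load device only when it is actually executed.
control_model.to(load_device)
with torch.no_grad():
    out = control_model(torch.randn(1, 64, device=load_device))
```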
@@ -178,7 +178,7 @@ class DoubleStreamBlock(nn.Module):
         txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
 
         if txt.dtype == torch.float16:
-            txt = txt.clip(-65504, 65504)
+            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
 
         return img, txt
 
@@ -233,7 +233,7 @@ class SingleStreamBlock(nn.Module):
         output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
         x += mod.gate * output
         if x.dtype == torch.float16:
-            x = x.clip(-65504, 65504)
+            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
         return x
 
 
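The two hunks above swap `clip` for `torch.nan_to_num` on the float16 activation path. A small illustration of the difference, not part of the diff: `clip` bounds infinities to the float16 range but propagates NaN unchanged, while `nan_to_num` handles both.

```python
import torch

x = torch.tensor([float("nan"), float("inf"), -float("inf"), 1.0], dtype=torch.float32)

# clip() bounds the infinities but leaves the NaN in place.
print(x.clip(-65504, 65504))
# tensor([    nan,  65504., -65504.,      1.])

# nan_to_num() also zeroes the NaN, which is what the blocks above now do
# before keeping the activations in float16.
print(torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504))
# tensor([     0.,  65504., -65504.,      1.])
```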
comfy/lora.py (192 changed lines)
@@ -17,7 +17,10 @@
 """
 
 import comfy.utils
+import comfy.model_management
+import comfy.model_base
 import logging
+import torch
 
 LORA_CLIP_MAP = {
     "mlp.fc1": "mlp_fc1",
@@ -322,3 +325,192 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
 
     return key_map
+
+
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
+    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
+    lora_diff *= alpha
+    weight_calc = weight + lora_diff.type(weight.dtype)
+    weight_norm = (
+        weight_calc.transpose(0, 1)
+        .reshape(weight_calc.shape[1], -1)
+        .norm(dim=1, keepdim=True)
+        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
+        .transpose(0, 1)
+    )
+
+    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
+    if strength != 1.0:
+        weight_calc -= weight
+        weight += strength * (weight_calc)
+    else:
+        weight[:] = weight_calc
+    return weight
+
+def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
+    for p in patches:
+        strength = p[0]
+        v = p[1]
+        strength_model = p[2]
+        offset = p[3]
+        function = p[4]
+        if function is None:
+            function = lambda a: a
+
+        old_weight = None
+        if offset is not None:
+            old_weight = weight
+            weight = weight.narrow(offset[0], offset[1], offset[2])
+
+        if strength_model != 1.0:
+            weight *= strength_model
+
+        if isinstance(v, list):
+            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
+
+        if len(v) == 1:
+            patch_type = "diff"
+        elif len(v) == 2:
+            patch_type = v[0]
+            v = v[1]
+
+        if patch_type == "diff":
+            w1 = v[0]
+            if strength != 0.0:
+                if w1.shape != weight.shape:
+                    logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
+                else:
+                    weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
+        elif patch_type == "lora": #lora/locon
+            mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
+            mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
+            dora_scale = v[4]
+            if v[2] is not None:
+                alpha = v[2] / mat2.shape[0]
+            else:
+                alpha = 1.0
+
+            if v[3] is not None:
+                #locon mid weights, hopefully the math is fine because I didn't properly test it
+                mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
+                final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+                mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
+            try:
+                lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "lokr":
+            w1 = v[0]
+            w2 = v[1]
+            w1_a = v[3]
+            w1_b = v[4]
+            w2_a = v[5]
+            w2_b = v[6]
+            t2 = v[7]
+            dora_scale = v[8]
+            dim = None
+
+            if w1 is None:
+                dim = w1_b.shape[0]
+                w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
+                              comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
+            else:
+                w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
+
+            if w2 is None:
+                dim = w2_b.shape[0]
+                if t2 is None:
+                    w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
+                else:
+                    w2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
+                                      comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
+            else:
+                w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
+
+            if len(w2.shape) == 4:
+                w1 = w1.unsqueeze(2).unsqueeze(2)
+            if v[2] is not None and dim is not None:
+                alpha = v[2] / dim
+            else:
+                alpha = 1.0
+
+            try:
+                lora_diff = torch.kron(w1, w2).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "loha":
+            w1a = v[0]
+            w1b = v[1]
+            if v[2] is not None:
+                alpha = v[2] / w1b.shape[0]
+            else:
+                alpha = 1.0
+
+            w2a = v[3]
+            w2b = v[4]
+            dora_scale = v[7]
+            if v[5] is not None: #cp decomposition
+                t1 = v[5]
+                t2 = v[6]
+                m1 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                  comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
+
+                m2 = torch.einsum('i j k l, j r, i p -> p r k l',
+                                  comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
+                                  comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
+            else:
+                m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
+                              comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
+                m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
+                              comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
+
+            try:
+                lora_diff = (m1 * m2).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        elif patch_type == "glora":
+            if v[4] is not None:
+                alpha = v[4] / v[0].shape[0]
+            else:
+                alpha = 1.0
+
+            dora_scale = v[5]
+
+            a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+            a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+            b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+            b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
+
+            try:
+                lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                if dora_scale is not None:
+                    weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
+                else:
+                    weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+            except Exception as e:
+                logging.error("ERROR {} {} {}".format(patch_type, key, e))
+        else:
+            logging.warning("patch type not recognized {} {}".format(patch_type, key))
+
+        if old_weight is not None:
+            weight = old_weight
+
+    return weight
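With `calculate_weight` now a module-level function, callers no longer need a `ModelPatcher` instance. Below is a minimal usage sketch following the patch-tuple layout visible in the code above, `(strength, value, strength_model, offset, function)`, where a one-element value tuple is treated as a plain "diff" patch. It assumes a working ComfyUI environment; the key name is arbitrary.

```python
import torch
import comfy.lora

weight = torch.zeros(4, 4)
delta = torch.ones(4, 4)

# One "diff" patch: add 0.5 * delta onto the weight, no per-model scaling,
# no offset, no post-processing function.
patches = [(0.5, (delta,), 1.0, None, None)]

patched = comfy.lora.calculate_weight(patches, weight.clone(), "some.layer.weight")
print(patched)  # every entry becomes 0.5: weight + strength * delta
```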
@@ -472,9 +472,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                               'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
                               'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
                               'use_temporal_attention': False, 'use_temporal_resblock': False}
 
+    SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
+                              'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
+                              'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
+                              'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+                              'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
 
     for unet_config in supported_models:
         matches = True
@@ -44,9 +44,14 @@ cpu_state = CPUState.GPU
 
 total_vram = 0
 
-lowvram_available = True
 xpu_available = False
+try:
+    torch_version = torch.version.__version__
+    xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
+except:
+    pass
+
+lowvram_available = True
 if args.deterministic:
     logging.info("Using deterministic algorithms for pytorch")
     torch.use_deterministic_algorithms(True, warn_only=True)
@@ -66,10 +71,10 @@ if args.directml is not None:
 
 try:
     import intel_extension_for_pytorch as ipex
-    if torch.xpu.is_available():
-        xpu_available = True
+    _ = torch.xpu.device_count()
+    xpu_available = torch.xpu.is_available()
 except:
-    pass
+    xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
 
 try:
     if torch.backends.mps.is_available():
@@ -189,7 +194,6 @@ VAE_DTYPES = [torch.float32]
 
 try:
     if is_nvidia():
-        torch_version = torch.version.__version__
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
@@ -321,8 +325,9 @@ class LoadedModel:
             self.model_unload()
             raise e
 
-        if is_intel_xpu() and not args.disable_ipex_optimize:
-            self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
+        if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
+            with torch.no_grad():
+                self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
 
         self.weights_loaded = True
         return self.real_model
@@ -561,7 +566,9 @@ def loaded_models(only_currently_used=False):
 def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
-        if sys.getrefcount(current_loaded_models[i].model) <= 2:
+        #TODO: very fragile function needs improvement
+        num_refs = sys.getrefcount(current_loaded_models[i].model)
+        if num_refs <= 2:
             if not keep_clone_weights_loaded:
                 to_delete = [i] + to_delete
             #TODO: find a less fragile way to do this.
@@ -668,6 +675,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
     if bf16_supported and weight_dtype == torch.bfloat16:
         return None
 
+    fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
    for dt in supported_dtypes:
        if dt == torch.float16 and fp16_supported:
            return torch.float16
@@ -883,7 +891,8 @@ def pytorch_attention_flash_attention():
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
+        if (14, 5) <= macos_version < (14, 7):  # black image bug on recent versions of MacOS
             upcast = True
     except:
         pass
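A small illustration, not part of the diff, of why the hunk parses the macOS version string into an integer tuple instead of comparing strings: tuple comparison is numeric per component, so "14.10" correctly falls outside the affected range, whereas a string comparison would place it before "14.7".

```python
for ver in ["14.4", "14.5", "14.6.1", "14.7", "14.10"]:
    parsed = tuple(int(n) for n in ver.split("."))
    in_bug_range = (14, 5) <= parsed < (14, 7)
    print(ver, parsed, in_bug_range)
# 14.4 -> False, 14.5 -> True, 14.6.1 -> True, 14.7 -> False, 14.10 -> False
# (the string "14.10" would wrongly compare as smaller than "14.7")
```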
@@ -986,16 +995,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if props.major < 6:
         return False
 
-    fp16_works = False
-    #FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
-    #when the model doesn't actually fit on the card
-    #TODO: actually test if GP106 and others have the same type of behavior
+    #FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
     nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
     for x in nvidia_10_series:
         if x in props.name.lower():
-            fp16_works = True
+            return True
 
-    if fp16_works or manual_cast:
+    if manual_cast:
         free_model_memory = maximum_vram_for_weights(device)
         if (not prioritize_performance) or model_params * 4 > free_model_memory:
             return True
@@ -27,30 +27,10 @@ import math
 import comfy.utils
 import comfy.float
 import comfy.model_management
+import comfy.lora
 from comfy.types import UnetWrapperFunction
 
-
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
-    lora_diff *= alpha
-    weight_calc = weight + lora_diff.type(weight.dtype)
-    weight_norm = (
-        weight_calc.transpose(0, 1)
-        .reshape(weight_calc.shape[1], -1)
-        .norm(dim=1, keepdim=True)
-        .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
-        .transpose(0, 1)
-    )
-
-    weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
-    if strength != 1.0:
-        weight_calc -= weight
-        weight += strength * (weight_calc)
-    else:
-        weight[:] = weight_calc
-    return weight
-
-
 def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
     to = model_options["transformer_options"].copy()
 
@@ -92,12 +72,11 @@ def wipe_lowvram_weight(m):
     m.bias_function = None
 
 class LowVramPatch:
-    def __init__(self, key, model_patcher):
+    def __init__(self, key, patches):
         self.key = key
-        self.model_patcher = model_patcher
+        self.patches = patches
     def __call__(self, weight):
-        return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
 
 class ModelPatcher:
     def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
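`LowVramPatch` now captures only the patches dict rather than the whole `ModelPatcher`, and calls the module-level `comfy.lora.calculate_weight`. The sketch below shows how such a callable is used as an on-the-fly `weight_function`; the stand-in `calculate_weight` only handles plain "diff" patches and is not ComfyUI's implementation.

```python
import torch

# Minimal stand-in for comfy.lora.calculate_weight: apply "diff" patches only.
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
    for strength, (diff,), *_ in patches:
        weight = weight + strength * diff.to(weight.dtype)
    return weight

# Mirrors the reworked LowVramPatch: it keeps only the patches dict and is
# invoked lazily with the freshly cast weight of the layer it belongs to.
class LowVramPatch:
    def __init__(self, key, patches):
        self.key = key
        self.patches = patches

    def __call__(self, weight):
        return calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

patches = {"layer.weight": [(0.5, (torch.ones(2, 2),), 1.0, None, None)]}
weight_function = LowVramPatch("layer.weight", patches)
print(weight_function(torch.zeros(2, 2)))  # patches applied on the fly: all 0.5
```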
@@ -329,7 +308,7 @@ class ModelPatcher:
            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
        else:
            temp_weight = weight.to(torch.float32, copy=True)
-        out_weight = self.calculate_weight(self.patches[key], temp_weight, key)
+        out_weight = comfy.lora.calculate_weight(self.patches[key], temp_weight, key)
        out_weight = comfy.float.stochastic_rounding(out_weight, weight.dtype)
        if inplace_update:
            comfy.utils.copy_to_param(self.model, key, out_weight)
@@ -340,12 +319,21 @@ class ModelPatcher:
         mem_counter = 0
         patch_counter = 0
         lowvram_counter = 0
-        load_completely = []
+        loading = []
         for n, m in self.model.named_modules():
+            if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
+                loading.append((comfy.model_management.module_size(m), n, m))
+
+        load_completely = []
+        loading.sort(reverse=True)
+        for x in loading:
+            n = x[1]
+            m = x[2]
+            module_mem = x[0]
+
             lowvram_weight = False
 
             if not full_load and hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
                     lowvram_counter += 1
@@ -360,13 +348,13 @@ class ModelPatcher:
                     if force_patch_weights:
                         self.patch_weight_to_device(weight_key)
                     else:
-                        m.weight_function = LowVramPatch(weight_key, self)
+                        m.weight_function = LowVramPatch(weight_key, self.patches)
                         patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
-                        m.bias_function = LowVramPatch(bias_key, self)
+                        m.bias_function = LowVramPatch(bias_key, self.patches)
                         patch_counter += 1
 
                 m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -377,9 +365,8 @@ class ModelPatcher:
                     wipe_lowvram_weight(m)
 
                 if hasattr(m, "weight"):
-                    mem_used = comfy.model_management.module_size(m)
-                    mem_counter += mem_used
-                    load_completely.append((mem_used, n, m))
+                    mem_counter += module_mem
+                    load_completely.append((module_mem, n, m))
 
         load_completely.sort(reverse=True)
         for x in load_completely:
@@ -428,174 +415,6 @@ class ModelPatcher:
         self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
         return self.model
 
-    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
-        for p in patches:
-            strength = p[0]
-            v = p[1]
-            strength_model = p[2]
-            offset = p[3]
-            function = p[4]
-            if function is None:
-                function = lambda a: a
-
-            old_weight = None
-            if offset is not None:
-                old_weight = weight
-                weight = weight.narrow(offset[0], offset[1], offset[2])
-
-            if strength_model != 1.0:
-                weight *= strength_model
-
-            if isinstance(v, list):
-                v = (self.calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
-
-            if len(v) == 1:
-                patch_type = "diff"
-            elif len(v) == 2:
-                patch_type = v[0]
-                v = v[1]
-
-            if patch_type == "diff":
-                w1 = v[0]
-                if strength != 0.0:
-                    if w1.shape != weight.shape:
-                        logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
-                    else:
-                        weight += function(strength * comfy.model_management.cast_to_device(w1, weight.device, weight.dtype))
-            elif patch_type == "lora": #lora/locon
-                mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
-                mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
-                dora_scale = v[4]
-                if v[2] is not None:
-                    alpha = v[2] / mat2.shape[0]
-                else:
-                    alpha = 1.0
-
-                if v[3] is not None:
-                    #locon mid weights, hopefully the math is fine because I didn't properly test it
-                    mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
-                    final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
-                    mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
-                try:
-                    lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "lokr":
-                w1 = v[0]
-                w2 = v[1]
-                w1_a = v[3]
-                w1_b = v[4]
-                w2_a = v[5]
-                w2_b = v[6]
-                t2 = v[7]
-                dora_scale = v[8]
-                dim = None
-
-                if w1 is None:
-                    dim = w1_b.shape[0]
-                    w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
-                else:
-                    w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
-
-                if w2 is None:
-                    dim = w2_b.shape[0]
-                    if t2 is None:
-                        w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
-                    else:
-                        w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                          comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                          comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
-                                          comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
-                else:
-                    w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
-
-                if len(w2.shape) == 4:
-                    w1 = w1.unsqueeze(2).unsqueeze(2)
-                if v[2] is not None and dim is not None:
-                    alpha = v[2] / dim
-                else:
-                    alpha = 1.0
-
-                try:
-                    lora_diff = torch.kron(w1, w2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "loha":
-                w1a = v[0]
-                w1b = v[1]
-                if v[2] is not None:
-                    alpha = v[2] / w1b.shape[0]
-                else:
-                    alpha = 1.0
-
-                w2a = v[3]
-                w2b = v[4]
-                dora_scale = v[7]
-                if v[5] is not None: #cp decomposition
-                    t1 = v[5]
-                    t2 = v[6]
-                    m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
-
-                    m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                                      comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
-                                      comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
-                else:
-                    m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
-                    m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
-                                  comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
-
-                try:
-                    lora_diff = (m1 * m2).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            elif patch_type == "glora":
-                if v[4] is not None:
-                    alpha = v[4] / v[0].shape[0]
-                else:
-                    alpha = 1.0
-
-                dora_scale = v[5]
-
-                a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
-                a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
-                b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
-                b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
-
-                try:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
-                    if dora_scale is not None:
-                        weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
-                    else:
-                        weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
-                except Exception as e:
-                    logging.error("ERROR {} {} {}".format(patch_type, key, e))
-            else:
-                logging.warning("patch type not recognized {} {}".format(patch_type, key))
-
-            if old_weight is not None:
-                weight = old_weight
-
-        return weight
-
     def unpatch_model(self, device_to=None, unpatch_weights=True):
         if unpatch_weights:
             if self.model.model_lowvram:
@@ -664,10 +483,10 @@ class ModelPatcher:
 
                     m.to(device_to)
                     if weight_key in self.patches:
-                        m.weight_function = LowVramPatch(weight_key, self)
+                        m.weight_function = LowVramPatch(weight_key, self.patches)
                         patch_counter += 1
                     if bias_key in self.patches:
-                        m.bias_function = LowVramPatch(bias_key, self)
+                        m.bias_function = LowVramPatch(bias_key, self.patches)
                         patch_counter += 1
 
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -695,3 +514,7 @@ class ModelPatcher:
 
     def current_loaded_device(self):
         return self.model.device
 
+    def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
+        print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
+        return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
comfy/ops.py (36 changed lines)
@@ -20,31 +20,40 @@ import torch
 import comfy.model_management
 from comfy.cli_args import args
 
-def cast_to(weight, dtype=None, device=None, non_blocking=False):
-    if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
-        return weight
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+    if device is None or weight.device == device:
+        if not copy:
+            if dtype is None or weight.dtype == dtype:
+                return weight
+        return weight.to(dtype=dtype, copy=copy)
+
     r = torch.empty_like(weight, dtype=dtype, device=device)
     r.copy_(weight, non_blocking=non_blocking)
     return r
 
-def cast_to_input(weight, input, non_blocking=False):
-    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+def cast_to_input(weight, input, non_blocking=False, copy=True):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
 
-def cast_bias_weight(s, input=None, dtype=None, device=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
             dtype = input.dtype
+        if bias_dtype is None:
+            bias_dtype = dtype
         if device is None:
             device = input.device
 
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
-        if s.bias_function is not None:
+        has_function = s.bias_function is not None
+        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        if has_function:
             bias = s.bias_function(bias)
-    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
-    if s.weight_function is not None:
+
+    has_function = s.weight_function is not None
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    if has_function:
         weight = s.weight_function(weight)
     return weight, bias
 
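The reworked `cast_to` takes a `copy` flag so that, when a weight or bias patch function is attached, the patch operates on a copy rather than on the stored parameter. The standalone sketch below reproduces the function from the hunk above and demonstrates both behaviours.

```python
import torch

# The new cast_to from the hunk above, reproduced standalone for illustration.
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
    if device is None or weight.device == device:
        if not copy:
            if dtype is None or weight.dtype == dtype:
                return weight
        return weight.to(dtype=dtype, copy=copy)

    r = torch.empty_like(weight, dtype=dtype, device=device)
    r.copy_(weight, non_blocking=non_blocking)
    return r

w = torch.zeros(2, 2)

# Without a patch function there is no copy: the stored tensor itself comes back.
assert cast_to(w) is w

# With copy=True (what cast_bias_weight now passes when a weight_function is set),
# the patch can mutate the returned tensor without touching the stored weight.
patched = cast_to(w, copy=True)
patched += 1.0
assert w.sum().item() == 0.0
```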
@@ -252,7 +261,8 @@ def fp8_linear(self, input):
     if len(input.shape) == 3:
         inn = input.reshape(-1, input.shape[2]).to(dtype)
         non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
-        w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w = w.t()
 
         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -263,8 +273,8 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
 
-        if self.bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
+        if bias is not None:
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
 
@@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        dtype_t5 = None
         if t5_key in state_dict:
             dtype_t5 = state_dict[t5_key].dtype
         return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
@@ -56,6 +56,8 @@ class CacheKeySetID(CacheKeySet):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
+            if not self.dynprompt.has_node(node_id):
+                continue
             node = self.dynprompt.get_node(node_id)
             self.keys[node_id] = (node_id, node["class_type"])
             self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -74,6 +76,8 @@ class CacheKeySetInputSignature(CacheKeySet):
         for node_id in node_ids:
             if node_id in self.keys:
                 continue
+            if not self.dynprompt.has_node(node_id):
+                continue
             node = self.dynprompt.get_node(node_id)
             self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
             self.subcache_keys[node_id] = (node_id, node["class_type"])
@@ -87,6 +91,9 @@ class CacheKeySetInputSignature(CacheKeySet):
         return to_hashable(signature)
 
     def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
+        if not dynprompt.has_node(node_id):
+            # This node doesn't exist -- we can't cache it.
+            return [float("NaN")]
         node = dynprompt.get_node(node_id)
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@@ -112,6 +119,8 @@ class CacheKeySetInputSignature(CacheKeySet):
         return ancestors, order_mapping
 
     def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
+        if not dynprompt.has_node(node_id):
+            return
         inputs = dynprompt.get_node(node_id)["inputs"]
         input_keys = sorted(inputs.keys())
         for key in input_keys:
@@ -47,7 +47,8 @@ class IsChangedCache:
             self.is_changed[node_id] = node["is_changed"]
             return self.is_changed[node_id]
 
-        input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
+        # Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
+        input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
         try:
             is_changed = _map_node_over_list(class_def, input_data_all, "IS_CHANGED")
             node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
@@ -449,7 +450,7 @@ class PromptExecutor:
             "current_outputs": list(current_outputs),
         }
         self.add_message("execution_error", mes, broadcast=False)
 
    def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
        nodes.interrupt_processing(False)
 
@@ -491,6 +492,7 @@ class PromptExecutor:
                 break
 
             result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
+            self.success = result != ExecutionResult.FAILURE
             if result == ExecutionResult.FAILURE:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
                 break
@@ -357,6 +357,25 @@ class TestExecution:
         assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}"
         assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node"
 
+    def test_missing_node_error(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
+        input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
+        mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
+        mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
+        mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
+        # We have multiple outputs. The first is invalid, but the second is valid
+        g.node("SaveImage", images=mix1.out(0))
+        g.node("SaveImage", images=mix2.out(0))
+        g.remove_node("removeme")
+
+        client.run(g)
+
+        # Add back in the missing node to make sure the error doesn't break the server
+        input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
+        client.run(g)
+
     def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
         # Creating the nodes in this specific order previously caused a bug
@@ -450,8 +469,8 @@ class TestExecution:
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
 
-        output1 = g.node("PreviewImage", images=input1.out(0))
-        output2 = g.node("PreviewImage", images=input1.out(0))
+        output1 = g.node("SaveImage", images=input1.out(0))
+        output2 = g.node("SaveImage", images=input1.out(0))
 
         result = client.run(g)
         images1 = result.get_images(output1)
@@ -459,3 +478,22 @@ class TestExecution:
         assert len(images1) == 1, "Should have 1 image"
         assert len(images2) == 1, "Should have 1 image"
 
+
+    # This tests that only constant outputs are used in the call to `IS_CHANGED`
+    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
+        test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
+
+        output = g.node("PreviewImage", images=test_node.out(0))
+
+        result = client.run(g)
+        images = result.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
+
+        result = client.run(g)
+        images = result.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
+        assert not result.did_run(test_node), "The execution should have been cached"
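The expected pixel value of 63 in the new test follows from the constant stub image (0.5) being multiplied by the node's `value` input (0.5) and the result being converted to 8-bit, assuming the float-to-uint8 conversion truncates rather than rounds:

```python
pixel = 0.5 * 0.5                # StubConstantImage value times the node's `value` input
assert pixel == 0.25
assert int(pixel * 255) == 63    # 63.75 truncated to 63 in the saved 8-bit image
```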
@@ -95,6 +95,31 @@ class TestCustomIsChanged:
         else:
             return False
 
+class TestIsChangedWithConstants:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "custom_is_changed"
+
+    CATEGORY = "Testing/Nodes"
+
+    def custom_is_changed(self, image, value):
+        return (image * value,)
+
+    @classmethod
+    def IS_CHANGED(cls, image, value):
+        if image is None:
+            return value
+        else:
+            return image.mean().item() * value
+
 class TestCustomValidation1:
     @classmethod
     def INPUT_TYPES(cls):
@@ -312,6 +337,7 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
     "TestCustomIsChanged": TestCustomIsChanged,
+    "TestIsChangedWithConstants": TestIsChangedWithConstants,
     "TestCustomValidation1": TestCustomValidation1,
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
@@ -325,6 +351,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestLazyMixImages": "Lazy Mix Images",
     "TestVariadicAverage": "Variadic Average",
     "TestCustomIsChanged": "Custom IsChanged",
+    "TestIsChangedWithConstants": "IsChanged With Constants",
     "TestCustomValidation1": "Custom Validation 1",
     "TestCustomValidation2": "Custom Validation 2",
     "TestCustomValidation3": "Custom Validation 3",
@@ -28,6 +28,28 @@ class StubImage:
         elif content == "NOISE":
             return (torch.rand(batch_size, height, width, 3),)
 
+class StubConstantImage:
+    def __init__(self):
+        pass
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
+                "height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
+                "width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "stub_constant_image"
+
+    CATEGORY = "Testing/Stub Nodes"
+
+    def stub_constant_image(self, value, height, width, batch_size):
+        return (torch.ones(batch_size, height, width, 3) * value,)
+
 class StubMask:
     def __init__(self):
         pass
@@ -93,12 +115,14 @@ class StubFloat:
 
 TEST_STUB_NODE_CLASS_MAPPINGS = {
     "StubImage": StubImage,
+    "StubConstantImage": StubConstantImage,
     "StubMask": StubMask,
     "StubInt": StubInt,
     "StubFloat": StubFloat,
 }
 TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
     "StubImage": "Stub Image",
+    "StubConstantImage": "Stub Constant Image",
     "StubMask": "Stub Mask",
     "StubInt": "Stub Int",
     "StubFloat": "Stub Float",