Compare commits

...

15 Commits

Author SHA1 Message Date
Jedrzej Kosinski
c8037ab667 Initial exploration of weight zipper 2025-03-24 03:34:42 -05:00
comfyanonymous
3b19fc76e3 Allow disabling pe in flux code for some other models. 2025-03-18 05:09:25 -04:00
comfyanonymous
50614f1b79 Fix regression with clip vision. 2025-03-17 13:56:11 -04:00
comfyanonymous
6dc7b0bfe3 Add support for giant dinov2 image encoder. 2025-03-17 05:53:54 -04:00
comfyanonymous
e8e990d6b8 Cleanup code. 2025-03-16 06:29:12 -04:00
Jedrzej Kosinski
2e24a15905 Call unpatch_hooks at the start of ModelPatcher.partially_unload (#7253)
* Call unpatch_hooks at the start of ModelPatcher.partially_unload

* Only call unpatch_hooks in partially_unload if lowvram is possible
2025-03-16 06:02:45 -04:00
chaObserv
fd5297131f Guard the edge cases of noise term in er_sde (#7265) 2025-03-16 06:02:25 -04:00
comfyanonymous
55a1b09ddc Allow loading diffusion model files with the "Load Checkpoint" node. 2025-03-15 08:27:49 -04:00
comfyanonymous
3c3988df45 Show a better error message if the VAE is invalid. 2025-03-15 08:26:36 -04:00
Christian Byrne
7ebd8087ff hotfix fe (#7244) 2025-03-15 01:38:10 -04:00
Chenlei Hu
c624c29d66 Update frontend to 1.12.9 (#7236)
* Update frontend to 1.12.9

* Update requirements.txt
2025-03-14 18:17:26 -04:00
comfyanonymous
a2448fc527 Remove useless code. 2025-03-14 18:10:37 -04:00
comfyanonymous
6a0daa79b6 Make the SkipLayerGuidanceDIT node work on WAN. 2025-03-14 10:55:19 -04:00
FeepingCreature
9c98c6358b Tolerate missing @torch.library.custom_op (#7234)
This can happen on Pytorch versions older than 2.4.
2025-03-14 09:51:26 -04:00
FeepingCreature
7aceb9f91c Add --use-flash-attention flag. (#7223)
* Add --use-flash-attention flag.
This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
2025-03-14 03:22:41 -04:00
17 changed files with 495 additions and 35 deletions

View File

@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

View File

@@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
@@ -42,10 +49,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -111,6 +119,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None

View File

@@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@@ -0,0 +1,21 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}

View File

@@ -1419,6 +1419,6 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x

View File

@@ -10,10 +10,11 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape
k_shape = k.shape
q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)

View File

@@ -115,8 +115,11 @@ class Flux(nn.Module):
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
else:
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):

View File

@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
try:
assert mask is None
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch

View File

@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
context,
clip_fea=None,
freqs=None,
transformer_options={},
):
r"""
Forward pass through the diffusion model
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
# arguments
kwargs = dict(
e=e0,
freqs=freqs,
context=context)
for block in self.blocks:
x = block(x, **kwargs)
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, e=e0, freqs=freqs, context=context)
# head
x = self.head(x, e)
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def forward(self, x, timestep, context, clip_fea=None, **kwargs):
def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""

View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
self.current_patcher: 'ModelPatcher' = None
self.current_patcher: ModelPatcher = None
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor
self.zipper_initialized = False
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# handle lowvram zipper initialization, if required
if self.model_lowvram and not self.zipper_initialized:
if self.current_patcher:
self.zipper_initialized = True
with self.current_patcher.use_ejected():
loading = self.current_patcher._load_list_lowvram_only()
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:

View File

@@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
return args.use_sage_attention
def flash_attention_enabled():
return args.use_flash_attention
def xformers_enabled():
global directml_enabled
global cpu_state

View File

@@ -17,7 +17,7 @@
"""
from __future__ import annotations
from typing import Optional, Callable
from typing import Optional, Callable, TYPE_CHECKING
import torch
import copy
import inspect
@@ -26,6 +26,7 @@ import uuid
import collections
import math
import comfy.ops
import comfy.utils
import comfy.float
import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction
if TYPE_CHECKING:
from comfy.model_base import BaseModel
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
self.model = model
self.model: BaseModel = model
if not hasattr(self.model, 'device'):
logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _zipper_dict_lowvram_only(self):
loading = self._load_list_lowvram_only()
def _load_list_lowvram_only(self):
loading = self._load_list()
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@ class ModelPatcher:
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def prepare_teeth(self):
ordered_list = self._load_list_lowvram_only()
prev_i = None
next_i = None
# first, create teeth on modules in list
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.init_tooth(self.load_device, self.offload_device, l[1])
# create teeth linked list
for i in range(len(ordered_list)):
if i+1 == len(ordered_list):
next_i = None
else:
next_i = i+1
m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
if prev_i is not None:
m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
else:
m.zipper_tooth.start = True
if next_i is not None:
m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
prev_i = i
def clean_teeth(self):
ordered_list = self._load_list_lowvram_only()
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.clean_tooth()
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
with self.use_ejected():
self.unpatch_hooks()
@@ -591,6 +632,8 @@ class ModelPatcher:
lowvram_counter = 0
loading = self._load_list()
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
load_completely = []
loading.sort(reverse=True)
for x in loading:
@@ -672,6 +715,7 @@ class ModelPatcher:
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
self.model.zipper_initialized = False
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
@@ -684,6 +728,9 @@ class ModelPatcher:
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
if self.model.model_lowvram:
self.prepare_teeth()
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -715,6 +762,7 @@ class ModelPatcher:
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.clean_teeth()
self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0
@@ -747,6 +795,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
@@ -770,6 +819,10 @@ class ModelPatcher:
move_weight = False
break
if not hooks_unpatched:
self.unpatch_hooks()
hooks_unpatched = True
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
@@ -799,8 +852,10 @@ class ModelPatcher:
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.zipper_initialized = False
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.prepare_teeth()
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
import logging
import comfy.model_management
@@ -56,6 +57,79 @@ class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = []
bias_function = []
zipper_init: dict = None
zipper_tooth: ZipperTooth = None
_zipper_tooth: ZipperTooth = None
def init_tooth(self, load_device, offload_device, key: str=None):
if self.zipper_tooth:
self.clean_tooth()
self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
def clean_tooth(self):
if self.zipper_tooth:
del self.zipper_tooth
self.zipper_tooth = None
def connect_teeth(self):
if self.zipper_init is not None:
self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
self.zipper_dict["prev_zipper_key"] = self.zipper_key
# def zipper_connect(self):
# if self.zipper_dict is not None:
# self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
# self.zipper_dict["prev_zipper_key"] = self.zipper_key
class ZipperTooth:
def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
self.op = op
self.key: str = key
self.weight_preloaded: torch.Tensor = None
self.bias_preloaded: torch.Tensor = None
self.load_device = load_device
self.offload_device = offload_device
self.start = False
self.prev_tooth: ZipperTooth = None
self.next_tooth: ZipperTooth = None
def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
try:
if self.start:
return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
return self.weight_preloaded, self.bias_preloaded
finally:
# if self.prev_tooth:
# self.prev_tooth.offload_previous(0)
self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
# TODO: queue load of tensors
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
if self.op.bias is not None:
self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
if self.next_tooth and teeth_count:
self.next_tooth.preload_next(teeth_count-1)
def offload_previous(self, teeth_count=1):
# TODO: queue offload of tensors
self.weight_preloaded = None
self.bias_preloaded = None
if self.prev_tooth and teeth_count:
self.prev_tooth.offload_previous(teeth_count-1)
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -63,7 +137,11 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
#if self.zipper_init:
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -77,7 +155,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -91,7 +172,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -105,7 +189,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
@@ -119,7 +206,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
@@ -134,7 +224,10 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
bias = None
@@ -156,7 +249,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@@ -177,7 +273,10 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
weight, bias = cast_bias_weight(self, input)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
output_padding, self.groups, self.dilation)
@@ -197,7 +296,10 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
else:
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
from comfy.ops import CastWeightBiasOp
import torch
from functools import partial
import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
import comfy.hooks
import scipy.stats
import numpy
import comfy.ops
def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
#The main sampling function shared by all the samplers
#Returns denoised
def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
do_cleanup = False
if "weight_zipper" not in model_options:
do_cleanup = True
#zipper_dict = {}
model_options["weight_zipper"] = True
loaded_modules = model.current_patcher._load_list_lowvram_only()
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
sum_m = sum([x[0] for x in low_m])
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(m, "comfy_cast_weights"):
m.zipper_tooth = comfy.ops.ZipperTooth
#m.zipper_dict = zipper_dict
m.zipper_key = l[1]
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
if do_cleanup:
zzz = 20
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(l[2], "comfy_cast_weights"):
#m.zipper_dict = None
m.zipper_key = None
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}

View File

@@ -440,6 +440,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@@ -495,6 +499,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
self.throw_exception_if_invalid()
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -525,6 +530,7 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
@@ -553,6 +559,7 @@ class VAE:
return output.movedim(1, -1)
def encode(self, pixel_samples):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5:
@@ -585,6 +592,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
return None
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:

View File

@@ -770,6 +770,7 @@ class VAELoader:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
return (vae,)
class ControlNetLoader:

View File

@@ -1,4 +1,4 @@
comfyui-frontend-package==1.11.8
comfyui-frontend-package==1.12.14
torch
torchsde
torchvision