Compare commits

...

15 Commits

Author SHA1 Message Date
Jedrzej Kosinski
c8037ab667 Initial exploration of weight zipper 2025-03-24 03:34:42 -05:00
comfyanonymous
3b19fc76e3 Allow disabling pe in flux code for some other models. 2025-03-18 05:09:25 -04:00
comfyanonymous
50614f1b79 Fix regression with clip vision. 2025-03-17 13:56:11 -04:00
comfyanonymous
6dc7b0bfe3 Add support for giant dinov2 image encoder. 2025-03-17 05:53:54 -04:00
comfyanonymous
e8e990d6b8 Cleanup code. 2025-03-16 06:29:12 -04:00
Jedrzej Kosinski
2e24a15905 Call unpatch_hooks at the start of ModelPatcher.partially_unload (#7253)
* Call unpatch_hooks at the start of ModelPatcher.partially_unload

* Only call unpatch_hooks in partially_unload if lowvram is possible
2025-03-16 06:02:45 -04:00
chaObserv
fd5297131f Guard the edge cases of noise term in er_sde (#7265) 2025-03-16 06:02:25 -04:00
comfyanonymous
55a1b09ddc Allow loading diffusion model files with the "Load Checkpoint" node. 2025-03-15 08:27:49 -04:00
comfyanonymous
3c3988df45 Show a better error message if the VAE is invalid. 2025-03-15 08:26:36 -04:00
Christian Byrne
7ebd8087ff hotfix fe (#7244) 2025-03-15 01:38:10 -04:00
Chenlei Hu
c624c29d66 Update frontend to 1.12.9 (#7236)
* Update frontend to 1.12.9

* Update requirements.txt
2025-03-14 18:17:26 -04:00
comfyanonymous
a2448fc527 Remove useless code. 2025-03-14 18:10:37 -04:00
comfyanonymous
6a0daa79b6 Make the SkipLayerGuidanceDIT node work on WAN. 2025-03-14 10:55:19 -04:00
FeepingCreature
9c98c6358b Tolerate missing @torch.library.custom_op (#7234)
This can happen on Pytorch versions older than 2.4.
2025-03-14 09:51:26 -04:00
FeepingCreature
7aceb9f91c Add --use-flash-attention flag. (#7223)
* Add --use-flash-attention flag.
This is useful on AMD systems, as FA builds are still 10% faster than Pytorch cross-attention.
2025-03-14 03:22:41 -04:00
17 changed files with 495 additions and 35 deletions

View File

@@ -106,6 +106,7 @@ attn_group.add_argument("--use-split-cross-attention", action="store_true", help
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")

View File

@@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@@ -34,6 +35,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
IMAGE_ENCODERS = {
"clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
@@ -42,10 +49,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
- self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+ self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -111,6 +119,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
elif "embeddings.patch_embeddings.projection.weight" in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None
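For orientation, the hunk above selects the encoder in two steps: a DINOv2-specific weight name in the state dict picks dino2_giant.json, and that config's "model_type" field picks the class out of IMAGE_ENCODERS. A stripped-down, self-contained sketch of the same dispatch (the registry and config dicts below are stand-ins, not the real ComfyUI objects):

def pick_image_encoder(sd_keys, configs, registry):
    # DINOv2 checkpoints carry this patch-embedding key; CLIP/SigLIP ones do not.
    if "embeddings.patch_embeddings.projection.weight" in sd_keys:
        config = configs["dino2_giant"]
    else:
        config = configs["clip_vitl"]
    # The config's "model_type" then selects the class to instantiate.
    return registry[config.get("model_type", "clip_vision_model")], config

registry = {"clip_vision_model": "CLIPVisionModelProjection", "dinov2": "Dinov2Model"}
configs = {"dino2_giant": {"model_type": "dinov2"}, "clip_vitl": {}}
print(pick_image_encoder({"embeddings.patch_embeddings.projection.weight"}, configs, registry))
# ('Dinov2Model', {'model_type': 'dinov2'})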

View File

@@ -0,0 +1,141 @@
import torch
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
class Dino2AttentionOutput(torch.nn.Module):
def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
super().__init__()
self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
def forward(self, x):
return self.dense(x)
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
class LayerScale(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
class SwiGLUFFN(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
in_features = out_features = dim
hidden_features = int(dim * 4)
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
x = self.weights_in(x)
x1, x2 = x.chunk(2, dim=-1)
x = torch.nn.functional.silu(x1) * x2
return self.weights_out(x)
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
self.mlp = SwiGLUFFN(dim, dtype, device, operations)
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
def forward(self, x, intermediate_output=None):
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
if intermediate_output < 0:
intermediate_output = len(self.layer) + intermediate_output
intermediate = None
for i, l in enumerate(self.layer):
x = l(x, optimized_attention)
if i == intermediate_output:
intermediate = x.clone()
return x, intermediate
class Dino2PatchEmbeddings(torch.nn.Module):
def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
super().__init__()
self.projection = operations.Conv2d(
in_channels=num_channels,
out_channels=dim,
kernel_size=patch_size,
stride=patch_size,
bias=True,
dtype=dtype,
device=device
)
def forward(self, pixel_values):
return self.projection(pixel_values).flatten(2).transpose(1, 2)
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
super().__init__()
patch_size = 14
image_size = 518
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
def forward(self, pixel_values):
x = self.patch_embeddings(pixel_values)
# TODO: mask_token?
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
return x
class Dinov2Model(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
num_layers = config_dict["num_hidden_layers"]
dim = config_dict["hidden_size"]
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
pooled_output = x[:, 0, :]
return x, i, pooled_output, None

View File

@@ -0,0 +1,21 @@
{
"attention_probs_dropout_prob": 0.0,
"drop_path_rate": 0.0,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 1536,
"image_size": 518,
"initializer_range": 0.02,
"layer_norm_eps": 1e-06,
"layerscale_value": 1.0,
"mlp_ratio": 4,
"model_type": "dinov2",
"num_attention_heads": 24,
"num_channels": 3,
"num_hidden_layers": 40,
"patch_size": 14,
"qkv_bias": true,
"use_swiglu_ffn": true,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225]
}
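As a sanity check on how the numbers in this config fit together (assuming the usual DINOv2 token layout of one class token plus one token per 14x14 patch, and the SwiGLU rounding used in SwiGLUFFN above):

hidden_size, num_heads, patch, image = 1536, 24, 14, 518
tokens = (image // patch) ** 2 + 1          # 37 * 37 patches + 1 cls token = 1370,
                                            # matching the position_embeddings length in Dino2Embeddings
head_dim = hidden_size // num_heads         # 1536 / 24 = 64
swiglu_hidden = (int(hidden_size * 4 * 2 / 3) + 7) // 8 * 8   # 2/3 of 4*dim, rounded to a multiple of 8 -> 4096
print(tokens, head_dim, swiglu_hidden)      # 1370 64 4096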

View File

@@ -1419,6 +1419,6 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
old_denoised_d = denoised_d
if s_noise != 0 and sigmas[i + 1] > 0:
- x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt()
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
old_denoised = denoised
return x
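The guarded term is s_noise * sqrt(sigma_{i+1}^2 - r^2 * sigma_i^2); when the value under the square root dips below zero (for example through floating-point error or a large r), sqrt() returns NaN and would corrupt the latent, so nan_to_num(nan=0.0) simply drops the noise contribution for that step. A standalone illustration of the failure mode and the guard (not the sampler itself):

import torch

sigma, sigma_next, r, s_noise = torch.tensor(0.11), torch.tensor(0.10), torch.tensor(0.95), 1.0
term = sigma_next ** 2 - sigma ** 2 * r ** 2       # slightly negative in this edge case
print(term)                                        # tensor(-0.0009...)
print(s_noise * term.sqrt())                       # tensor(nan), would poison x
print(s_noise * term.sqrt().nan_to_num(nan=0.0))   # tensor(0.), noise term is skipped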

View File

@@ -10,8 +10,9 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
q_shape = q.shape
k_shape = k.shape
- q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
- k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
+ if pe is not None:
+ q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
+ k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
@@ -36,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
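Two things change in this hunk: q/k are now cast to pe's dtype instead of being forced to float32, and pe (or freqs_cis) may be None so RoPE can be skipped entirely, matching the "Allow disabling pe" commit. The indexing form itself is unchanged; assuming pe packs one 2x2 rotation matrix [[cos t, -sin t], [sin t, cos t]] per channel pair, as Flux's rope() builds, it is just that rotation applied to each (x0, x1) pair. A small check of that equivalence:

import torch

t = torch.rand(8)                                   # one angle per channel pair
rot = torch.stack([torch.cos(t), -torch.sin(t),
                   torch.sin(t),  torch.cos(t)], dim=-1).reshape(8, 2, 2)
x = torch.rand(8, 2)                                # the (x0, x1) channel pairs
x_ = x.reshape(8, 1, 2)

out_indexed = rot[..., 0] * x_[..., 0] + rot[..., 1] * x_[..., 1]   # form used in apply_rope
out_rotated = torch.einsum("pij,pj->pi", rot, x)                    # explicit 2x2 rotation
print(torch.allclose(out_indexed, out_rotated))     # True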

View File

@@ -115,8 +115,11 @@ class Flux(nn.Module):
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
if img_ids is not None:
ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
else:
pe = None
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):

View File

@@ -24,6 +24,13 @@ if model_management.sage_attention_enabled():
logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
exit(-1)
if model_management.flash_attention_enabled():
try:
from flash_attn import flash_attn_func
except ModuleNotFoundError:
logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
exit(-1)
from comfy.cli_args import args
import comfy.ops
ops = comfy.ops.disable_weight_init
@@ -496,6 +503,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
return out
try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
@flash_attn_wrapper.register_fake
def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
# Output shape is the same as q
return q.new_empty(q.shape)
except AttributeError as error:
FLASH_ATTN_ERROR = error
def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
else:
b, _, dim_head = q.shape
dim_head //= heads
q, k, v = map(
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
(q, k, v),
)
if mask is not None:
# add a batch dimension if there isn't already one
if mask.ndim == 2:
mask = mask.unsqueeze(0)
# add a heads dimension if there isn't already one
if mask.ndim == 3:
mask = mask.unsqueeze(1)
try:
assert mask is None
out = flash_attn_wrapper(
q.transpose(1, 2),
k.transpose(1, 2),
v.transpose(1, 2),
dropout_p=0.0,
causal=False,
).transpose(1, 2)
except Exception as e:
logging.warning(f"Flash Attention failed, using default SDPA: {e}")
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
)
return out
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
@@ -504,6 +568,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
elif model_management.flash_attention_enabled():
logging.info("Using Flash Attention")
optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
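A note on the wrapper above: comfy's attention helpers pass tensors around as (batch, heads, seq, dim_head), while flash_attn_func (per the flash-attn package) takes (batch, seq, heads, dim_head), hence the transpose(1, 2) on the way in and out; and since no mask is handed to the flash path, any masked call drops back to PyTorch SDPA. A standalone sketch of the fallback's shape handling that runs without flash-attn installed:

import torch
import torch.nn.functional as F

def sdpa_fallback(q, k, v, heads):
    b, _, inner = q.shape
    dim_head = inner // heads
    # (b, seq, heads * dim_head) -> (b, heads, seq, dim_head), the layout SDPA expects;
    # flash_attn_func would instead take the (b, seq, heads, dim_head) transpose.
    q, k, v = (t.view(b, -1, heads, dim_head).transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    return out.transpose(1, 2).reshape(b, -1, heads * dim_head)

q = k = v = torch.randn(2, 16, 8 * 64)
print(sdpa_fallback(q, k, v, heads=8).shape)        # torch.Size([2, 16, 512])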

View File

@@ -384,6 +384,7 @@ class WanModel(torch.nn.Module):
context,
clip_fea=None,
freqs=None,
transformer_options={},
):
r"""
Forward pass through the diffusion model
@@ -423,14 +424,18 @@ class WanModel(torch.nn.Module):
context_clip = self.img_emb(clip_fea) # bs x 257 x dim
context = torch.concat([context_clip, context], dim=1)
- # arguments
- kwargs = dict(
- e=e0,
- freqs=freqs,
- context=context)
- for block in self.blocks:
- x = block(x, **kwargs)
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context)
# head
x = self.head(x, e)
@@ -439,7 +444,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
- def forward(self, x, timestep, context, clip_fea=None, **kwargs):
+ def forward(self, x, timestep, context, clip_fea=None, transformer_options={},**kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
@@ -453,7 +458,7 @@ class WanModel(torch.nn.Module):
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
freqs = self.rope_embedder(img_ids).movedim(1, 2)
- return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs)[:, :, :t, :h, :w]
+ return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
r"""

View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
import logging
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -104,7 +105,7 @@ class BaseModel(torch.nn.Module):
self.model_config = model_config
self.manual_cast_dtype = model_config.manual_cast_dtype
self.device = device
- self.current_patcher: 'ModelPatcher' = None
+ self.current_patcher: ModelPatcher = None
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
@@ -128,6 +129,7 @@ class BaseModel(torch.nn.Module):
logging.info("model_type {}".format(model_type.name)) logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels)) logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor self.memory_usage_factor = model_config.memory_usage_factor
self.zipper_initialized = False
def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor( return comfy.patcher_extension.WrapperExecutor.new_class_executor(
@@ -137,6 +139,16 @@ class BaseModel(torch.nn.Module):
).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) ).execute(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs): def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
# handle lowvram zipper initialization, if required
if self.model_lowvram and not self.zipper_initialized:
if self.current_patcher:
self.zipper_initialized = True
with self.current_patcher.use_ejected():
loading = self.current_patcher._load_list_lowvram_only()
return self._apply_model_inner(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
def _apply_model_inner(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
if c_concat is not None:

View File

@@ -930,6 +930,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
return args.use_sage_attention
def flash_attention_enabled():
return args.use_flash_attention
def xformers_enabled():
global directml_enabled
global cpu_state

View File

@@ -17,7 +17,7 @@
""" """
from __future__ import annotations from __future__ import annotations
from typing import Optional, Callable from typing import Optional, Callable, TYPE_CHECKING
import torch import torch
import copy import copy
import inspect import inspect
@@ -26,6 +26,7 @@ import uuid
import collections import collections
import math import math
import comfy.ops
import comfy.utils import comfy.utils
import comfy.float import comfy.float
import comfy.model_management import comfy.model_management
@@ -34,6 +35,9 @@ import comfy.hooks
import comfy.patcher_extension
from comfy.patcher_extension import CallbacksMP, WrappersMP, PatcherInjection
from comfy.comfy_types import UnetWrapperFunction
if TYPE_CHECKING:
from comfy.model_base import BaseModel
def string_to_seed(data):
crc = 0xFFFFFFFF
@@ -201,7 +205,7 @@ class MemoryCounter:
class ModelPatcher:
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
self.size = size
- self.model = model
+ self.model: BaseModel = model
if not hasattr(self.model, 'device'):
logging.debug("Model doesn't have a device attribute.")
self.model.device = offload_device
@@ -568,6 +572,14 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _zipper_dict_lowvram_only(self):
loading = self._load_list_lowvram_only()
def _load_list_lowvram_only(self):
loading = self._load_list()
return [x for x in loading if hasattr(x[2], "prev_comfy_cast_weights")]
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
@@ -583,6 +595,35 @@ class ModelPatcher:
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def prepare_teeth(self):
ordered_list = self._load_list_lowvram_only()
prev_i = None
next_i = None
# first, create teeth on modules in list
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.init_tooth(self.load_device, self.offload_device, l[1])
# create teeth linked list
for i in range(len(ordered_list)):
if i+1 == len(ordered_list):
next_i = None
else:
next_i = i+1
m: comfy.ops.CastWeightBiasOp = ordered_list[i][2]
if prev_i is not None:
m.zipper_tooth.prev_tooth = ordered_list[prev_i][2].zipper_tooth
else:
m.zipper_tooth.start = True
if next_i is not None:
m.zipper_tooth.next_tooth = ordered_list[next_i][2].zipper_tooth
prev_i = i
def clean_teeth(self):
ordered_list = self._load_list_lowvram_only()
for l in ordered_list:
m: comfy.ops.CastWeightBiasOp = l[2]
m.clean_tooth()
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
with self.use_ejected():
self.unpatch_hooks()
@@ -591,6 +632,8 @@ class ModelPatcher:
lowvram_counter = 0
loading = self._load_list()
logging.info(f"total size of _load_list: {sum([x[0] for x in loading])}")
load_completely = []
loading.sort(reverse=True)
for x in loading:
@@ -672,6 +715,7 @@ class ModelPatcher:
if lowvram_counter > 0:
logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self.model.model_lowvram = True
self.model.zipper_initialized = False
else:
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self.model.model_lowvram = False
@@ -684,6 +728,9 @@ class ModelPatcher:
self.model.model_loaded_weight_memory = mem_counter
self.model.current_weight_patches_uuid = self.patches_uuid
if self.model.model_lowvram:
self.prepare_teeth()
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
callback(self, device_to, lowvram_model_memory, force_patch_weights, full_load)
@@ -715,6 +762,7 @@ class ModelPatcher:
move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.clean_teeth()
self.model.model_lowvram = False
self.model.lowvram_patch_counter = 0
@@ -747,6 +795,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
@@ -770,6 +819,10 @@ class ModelPatcher:
move_weight = False
break
if not hooks_unpatched:
self.unpatch_hooks()
hooks_unpatched = True
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
@@ -799,8 +852,10 @@ class ModelPatcher:
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.zipper_initialized = False
self.model.lowvram_patch_counter += patch_counter
self.model.model_loaded_weight_memory -= memory_freed
self.prepare_teeth()
return memory_freed
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):

View File

@@ -16,6 +16,7 @@
along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
from __future__ import annotations
import torch
import logging
import comfy.model_management
@@ -56,6 +57,79 @@ class CastWeightBiasOp:
comfy_cast_weights = False
weight_function = []
bias_function = []
zipper_init: dict = None
zipper_tooth: ZipperTooth = None
_zipper_tooth: ZipperTooth = None
def init_tooth(self, load_device, offload_device, key: str=None):
if self.zipper_tooth:
self.clean_tooth()
self.zipper_tooth = ZipperTooth(self, load_device, offload_device, key)
def clean_tooth(self):
if self.zipper_tooth:
del self.zipper_tooth
self.zipper_tooth = None
def connect_teeth(self):
if self.zipper_init is not None:
self.zipper_init[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
self.zipper_dict["prev_zipper_key"] = self.zipper_key
# def zipper_connect(self):
# if self.zipper_dict is not None:
# self.zipper_dict[self.zipper_key] = (hasattr(self, "prev_comfy_cast_weights"), self.zipper_dict.get("prev_zipper_key", None))
# self.zipper_dict["prev_zipper_key"] = self.zipper_key
class ZipperTooth:
def __init__(self, op: CastWeightBiasOp, load_device, offload_device, key: str=None):
self.op = op
self.key: str = key
self.weight_preloaded: torch.Tensor = None
self.bias_preloaded: torch.Tensor = None
self.load_device = load_device
self.offload_device = offload_device
self.start = False
self.prev_tooth: ZipperTooth = None
self.next_tooth: ZipperTooth = None
def get_bias_weight(self, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
try:
if self.start:
return cast_bias_weight(self.op, input, dtype, device, bias_dtype)
return self.weight_preloaded, self.bias_preloaded
finally:
# if self.prev_tooth:
# self.prev_tooth.offload_previous(0)
self.next_tooth.preload_next(0, input, dtype, device, bias_dtype)
def preload_next(self, teeth_count=1, input: torch.Tensor=None, dtype=None, device=None, bias_dtype=None):
# TODO: queue load of tensors
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
non_blocking = comfy.model_management.device_supports_non_blocking(self.load_device)
if self.op.bias is not None:
self.bias_preloaded = comfy.model_management.cast_to(self.op.bias, bias_dtype, device, non_blocking=non_blocking)
self.weight_preloaded = comfy.model_management.cast_to(self.op.weight, dtype, device, non_blocking=non_blocking)
if self.next_tooth and teeth_count:
self.next_tooth.preload_next(teeth_count-1)
def offload_previous(self, teeth_count=1):
# TODO: queue offload of tensors
self.weight_preloaded = None
self.bias_preloaded = None
if self.prev_tooth and teeth_count:
self.prev_tooth.offload_previous(teeth_count-1)
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -63,6 +137,10 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
#if self.zipper_init:
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
@@ -77,6 +155,9 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
@@ -91,6 +172,9 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
@@ -105,6 +189,9 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)
@@ -119,6 +206,9 @@ class disable_weight_init:
return None
def forward_comfy_cast_weights(self, input):
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
@@ -134,6 +224,9 @@ class disable_weight_init:
def forward_comfy_cast_weights(self, input):
if self.weight is not None:
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
else:
weight = None
@@ -156,6 +249,9 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose2d(
input, weight, bias, self.stride, self.padding,
@@ -177,6 +273,9 @@ class disable_weight_init:
input, output_size, self.stride, self.padding, self.kernel_size,
num_spatial_dims, self.dilation)
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(input)
else:
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.conv_transpose1d(
input, weight, bias, self.stride, self.padding,
@@ -197,6 +296,9 @@ class disable_weight_init:
output_dtype = out_dtype
if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
out_dtype = None
if self.zipper_tooth:
weight, bias = self.zipper_tooth.get_bias_weight(device=input.device, dtype=out_dtype)
else:
weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
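The zipper code in this file is still exploratory (several paths are commented out), but the shape of it is a doubly linked list of "teeth", one per low-VRAM module: asking a tooth for its casted weights also kicks off a preload of the next tooth, and earlier teeth can be offloaded once passed. A stripped-down sketch of that preload-ahead pattern, independent of the comfy classes (load/free stand in for the real cast and offload calls):

class Tooth:
    def __init__(self, name):
        self.name, self.prev, self.next, self.loaded = name, None, None, False

    def get(self):
        if not self.loaded:                 # first tooth, or nothing was prefetched
            self.load()
        if self.next is not None:
            self.next.load()                # prefetch the next module's weights
        if self.prev is not None:
            self.prev.free()                # drop weights we are done with
        return self.name

    def load(self):
        self.loaded = True
        print(f"load {self.name}")

    def free(self):
        self.loaded = False
        print(f"free {self.name}")

teeth = [Tooth(f"layer{i}") for i in range(3)]
for a, b in zip(teeth, teeth[1:]):          # link the chain
    a.next, b.prev = b, a
for t in teeth:                             # walk it like a forward pass
    t.get()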

View File

@@ -6,6 +6,7 @@ if TYPE_CHECKING:
from comfy.model_patcher import ModelPatcher
from comfy.model_base import BaseModel
from comfy.controlnet import ControlBase
from comfy.ops import CastWeightBiasOp
import torch
from functools import partial
import collections
@@ -18,6 +19,7 @@ import comfy.patcher_extension
import comfy.hooks
import scipy.stats
import numpy
import comfy.ops
def add_area_dims(area, num_dims):
@@ -360,15 +362,38 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
#The main sampling function shared by all the samplers
#Returns denoised
- def sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
+ def sampling_function(model: BaseModel, x, timestep, uncond, cond, cond_scale, model_options={}, seed=None):
if math.isclose(cond_scale, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
uncond_ = None
else:
uncond_ = uncond
do_cleanup = False
if "weight_zipper" not in model_options:
do_cleanup = True
#zipper_dict = {}
model_options["weight_zipper"] = True
loaded_modules = model.current_patcher._load_list_lowvram_only()
low_m = [x for x in loaded_modules if hasattr(x[2], "prev_comfy_cast_weights")]
sum_m = sum([x[0] for x in low_m])
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(m, "comfy_cast_weights"):
m.zipper_tooth = comfy.ops.ZipperTooth
#m.zipper_dict = zipper_dict
m.zipper_key = l[1]
conds = [cond, uncond_]
out = calc_cond_batch(model, conds, x, timestep, model_options)
if do_cleanup:
zzz = 20
for l in loaded_modules:
m: CastWeightBiasOp = l[2]
if hasattr(l[2], "comfy_cast_weights"):
#m.zipper_dict = None
m.zipper_key = None
for fn in model_options.get("sampler_pre_cfg_function", []):
args = {"conds":conds, "conds_out": out, "cond_scale": cond_scale, "timestep": timestep,
"input": x, "sigma": timestep, "model": model, "model_options": model_options}

View File

@@ -440,6 +440,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def throw_exception_if_invalid(self):
if self.first_stage_model is None:
raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@@ -495,6 +499,7 @@ class VAE:
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in):
self.throw_exception_if_invalid()
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
@@ -525,6 +530,7 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used)
dims = samples.ndim - 2
@@ -553,6 +559,7 @@ class VAE:
return output.movedim(1, -1)
def encode(self, pixel_samples):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
if self.latent_dim == 3 and pixel_samples.ndim < 5:
@@ -585,6 +592,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@@ -899,7 +907,12 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
if diffusion_model is None:
return None
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
unet_weight_dtype = list(model_config.supported_inference_dtypes)
if model_config.scaled_fp8 is not None:

View File

@@ -770,6 +770,7 @@ class VAELoader:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
vae.throw_exception_if_invalid()
return (vae,)
class ControlNetLoader:

View File

@@ -1,4 +1,4 @@
- comfyui-frontend-package==1.11.8
+ comfyui-frontend-package==1.12.14
torch
torchsde
torchvision