Compare commits


11 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| comfyanonymous | 4c82741b54 | Support official SD3.5 Controlnets. | 2024-11-26 11:31:25 -05:00 |
| comfyanonymous | 15c39ea757 | Support for the official mochi lora format. | 2024-11-26 03:34:36 -05:00 |
| comfyanonymous | b7143b74ce | Flux inpaint model does not work in fp16. | 2024-11-26 01:33:01 -05:00 |
| comfyanonymous | 61196d8857 | Add option to inference the diffusion model in fp32 and fp64. | 2024-11-25 05:00:23 -05:00 |
| comfyanonymous | b4526d3fc3 | Skip layer guidance now works on hydit model. | 2024-11-24 05:54:30 -05:00 |
| 40476 | 3d802710e7 | Update README.md (#5707) | 2024-11-24 04:12:07 -05:00 |
| spacepxl | 7126ecffde | set LTX min length to 1 for t2i (#5750): At length=1, the LTX model can do txt2img and img2img with no other changes required. | 2024-11-23 21:33:08 -05:00 |
| comfyanonymous | ab885b33ba | Skip layer guidance node now works on LTX-Video. | 2024-11-23 10:33:05 -05:00 |
| comfyanonymous | 839ed3368e | Some improvements to the lowvram unloading. | 2024-11-22 20:59:15 -05:00 |
| comfyanonymous | 6e8cdcd3cb | Fix some tiled VAE decoding issues with LTX-Video. | 2024-11-22 18:00:34 -05:00 |
| comfyanonymous | e5c3f4b87f | LTXV lowvram fixes. | 2024-11-22 17:17:11 -05:00 |
16 changed files with 368 additions and 92 deletions

README.md

@@ -75,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Keybind | Explanation |
 |---------|-------------|
-| Ctrl + Enter | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
-| Ctrl + S | Save workflow |
-| Ctrl + O | Load workflow |
-| Ctrl + A | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
-| Alt + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
-| Q | Toggle visibility of the queue |
-| H | Toggle visibility of history |
-| R | Refresh graph |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
+| `Ctrl` + `S` | Save workflow |
+| `Ctrl` + `O` | Load workflow |
+| `Ctrl` + `A` | Select all nodes |
+| `Alt` + `C` | Collapse/uncollapse selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Delete`/`Backspace` | Delete selected nodes |
+| `Ctrl` + `Backspace` | Delete the current graph |
+| `Space` | Move the canvas around when held and moving the cursor |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
+| `Ctrl` + `D` | Load default graph |
+| `Alt` + `+` | Canvas Zoom in |
+| `Alt` + `-` | Canvas Zoom out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
+| `P` | Pin/Unpin selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
+| `Q` | Toggle visibility of the queue |
+| `H` | Toggle visibility of history |
+| `R` | Refresh graph |
 | Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Shift` + Drag | Move multiple wires at once |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
 
-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` instead for macOS users
 
 # Installing

comfy/cldm/dit_embedder.py (new file, 122 lines)

@@ -0,0 +1,122 @@
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> Tuple[Tensor, List[Tensor]]:
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        block_out = ()

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}


@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.") fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group() fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.") fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.") fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.") fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.") fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")


@@ -35,7 +35,7 @@ import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
+import comfy.cldm.dit_embedder
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -78,6 +78,7 @@ class ControlBase:
         self.concat_mask = False
         self.extra_concat_orig = []
         self.extra_concat = None
+        self.preprocess_image = lambda a: a
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ class ControlBase:
         c.strength_type = self.strength_type
         c.concat_mask = self.concat_mask
         c.extra_concat_orig = self.extra_concat_orig.copy()
+        c.preprocess_image = self.preprocess_image
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -181,7 +183,7 @@ class ControlBase:
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
         super().__init__()
         self.control_model = control_model
         self.load_device = load_device
@@ -196,6 +198,7 @@ class ControlNet(ControlBase):
         self.extra_conds += extra_conds
         self.strength_type = strength_type
         self.concat_mask = concat_mask
+        self.preprocess_image = preprocess_image
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -224,6 +227,7 @@ class ControlNet(ControlBase):
                 if self.latent_format is not None:
                     raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
                 self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+                self.cond_hint = self.preprocess_image(self.cond_hint)
                 if self.vae is not None:
                     loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                     self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
         logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model
 
 def load_controlnet_mmdit(sd, model_options={}):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
     return control
 
+
+class ControlNetSD35(ControlNet):
+    def pre_run(self, model, percent_to_timestep_function):
+        if self.control_model.double_y_emb:
+            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+        else:
+            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+        super().pre_run(model, percent_to_timestep_function)
+
+    def copy(self):
+        c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+
+def load_controlnet_sd35(sd, model_options={}):
+    control_type = -1
+    if "control_type" in sd:
+        control_type = round(sd.pop("control_type").item())
+
+    # blur_cnet = control_type == 0
+    canny_cnet = control_type == 1
+    depth_cnet = control_type == 2
+
+    print(control_type, canny_cnet, depth_cnet)
+    new_sd = {}
+    for k in comfy.utils.MMDIT_MAP_BASIC:
+        if k[1] in sd:
+            new_sd[k[0]] = sd.pop(k[1])
+    for k in sd:
+        new_sd[k] = sd[k]
+    sd = new_sd
+
+    y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
+    depth = y_emb_shape[0] // 64
+    hidden_size = 64 * depth
+    num_heads = depth
+    head_dim = hidden_size // num_heads
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
+
+    load_device = comfy.model_management.get_torch_device()
+    offload_device = comfy.model_management.unet_offload_device()
+    unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
+
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
+                                                               patch_size=2,
+                                                               in_chans=16,
+                                                               num_layers=num_blocks,
+                                                               main_model_double=depth,
+                                                               double_y_emb=y_emb_shape[0] == y_emb_shape[1],
+                                                               attention_head_dim=head_dim,
+                                                               num_attention_heads=num_heads,
+                                                               adm_in_channels=2048,
+                                                               device=offload_device,
+                                                               dtype=unet_dtype,
+                                                               operations=operations)
+
+    control_model = controlnet_load_state_dict(control_model, sd)
+
+    latent_format = comfy.latent_formats.SD3()
+    preprocess_image = lambda a: a
+    if canny_cnet:
+        preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
+    elif depth_cnet:
+        preprocess_image = lambda a: 1.0 - a
+
+    control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
+    return control
+
+
 def load_controlnet_hunyuandit(controlnet_data, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
@@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
         return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
     elif "pos_embed_input.proj.weight" in controlnet_data:
-        return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+        if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
+            return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
+        else:
+            return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
     elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux


@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
         style=None,
         return_dict=False,
         control=None,
-        transformer_options=None,
+        transformer_options={},
     ):
         """
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
         return_dict: bool
             Whether to return a dictionary.
         """
-        #import pdb
-        #pdb.set_trace()
+        patches_replace = transformer_options.get("patches_replace", {})
         encoder_hidden_states = context
         text_states = encoder_hidden_states                   # 2,77,1024
         text_states_t5 = encoder_hidden_states_t5             # 2,256,2048
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
         # Concatenate all extra vectors
         c = t + self.extra_embedder(extra_vec)  # [B, D]
 
+        blocks_replace = patches_replace.get("dit", {})
+
         controls = None
         if control:
             controls = control.get("output", None)
@@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
                     skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
-                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)
             else:
-                x = block(x, c, text_states, freqs_cis_img)  # (N, L, D)
+                skip = None
+
+            if ("double_block", layer) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
+                    return out
+
+                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, c, text_states, freqs_cis_img, skip)  # (N, L, D)
 
             if layer < (self.depth // 2 - 1):
                 skips.append(x)


@@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
         self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
 
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
         x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
@@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module):
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
+        patches_replace = transformer_options.get("patches_replace", {})
+
         indices_grid = self.patchifier.get_grid(
             orig_num_frames=x.shape[2],
             orig_height=x.shape[3],
             orig_width=x.shape[4],
             batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
+            scale_grid=((1 / frame_rate) * 8, 32, 32),
             device=x.device,
         )
@@ -468,18 +470,28 @@ class LTXVModel(torch.nn.Module):
             batch_size, -1, x.shape[-1]
         )
 
-        for block in self.transformer_blocks:
-            x = block(
-                x,
-                context=context,
-                attention_mask=attention_mask,
-                timestep=timestep,
-                pe=pe
-            )
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(
+                    x,
+                    context=context,
+                    attention_mask=attention_mask,
+                    timestep=timestep,
+                    pe=pe
+                )
 
         # 3. Output
         scale_shift_values = (
-            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
         )
         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
         x = self.norm_out(x)
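
This hunk and the HunYuanDiT one above wire the same `patches_replace["dit"]` hook into both models, which is what skip-layer-guidance style nodes rely on: a patch registered under a `("double_block", index)` key receives the block inputs plus the original block as a callback, and may run it, alter its output, or skip it. A hypothetical patch showing the calling convention (the predicate and registration line are assumptions, not code from this commit):

```python
def make_skip_patch(skip=True):
    def patch(args, extra):
        # args carries the block inputs ("img", "txt", "vec", "pe", plus "skip" for hydit);
        # extra["original_block"] invokes the unpatched block.
        if skip:
            return {"img": args["img"]}  # identity: behave as if the block were removed
        return extra["original_block"](args)
    return patch

# Registered under the (kind, index) keys the loops above check, e.g.:
# transformer_options["patches_replace"]["dit"][("double_block", 3)] = make_skip_patch()
```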


@@ -2,6 +2,8 @@ from typing import Tuple, Union
 import torch
 import torch.nn as nn
 
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 class CausalConv3d(nn.Module):
@@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
         width_pad = kernel_size[2] // 2
         padding = (0, height_pad, width_pad)
 
-        self.conv = nn.Conv3d(
+        self.conv = ops.Conv3d(
             in_channels,
             out_channels,
             kernel_size,


@@ -628,10 +628,10 @@ class processor(nn.Module):
         self.register_buffer("channel", torch.empty(128))
 
     def un_normalize(self, x):
-        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
 
     def normalize(self, x):
-        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
 
 class VideoVAE(nn.Module):
     def __init__(self):


@@ -4,7 +4,8 @@ import torch
 from .dual_conv3d import DualConv3d
 from .causal_conv3d import CausalConv3d
 
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 def make_conv_nd(
     dims: Union[int, Tuple[int, int]],
@@ -19,7 +20,7 @@ def make_conv_nd(
     causal=False,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -41,7 +42,7 @@ def make_conv_nd(
                 groups=groups,
                 bias=bias,
             )
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -71,11 +72,11 @@ def make_linear_nd(
     bias=True,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     elif dims == 3 or dims == (2, 1):
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     else:


@@ -62,6 +62,7 @@ def load_lora(lora, to_load):
         diffusers_lora = "{}_lora.up.weight".format(x)
         diffusers2_lora = "{}.lora_B.weight".format(x)
         diffusers3_lora = "{}.lora.up.weight".format(x)
+        mochi_lora = "{}.lora_B".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
         A_name = None
@@ -81,6 +82,10 @@ def load_lora(lora, to_load):
             A_name = diffusers3_lora
             B_name = "{}.lora.down.weight".format(x)
             mid_name = None
+        elif mochi_lora in lora.keys():
+            A_name = mochi_lora
+            B_name = "{}.lora_A".format(x)
+            mid_name = None
         elif transformers_lora in lora.keys():
             A_name = transformers_lora
             B_name ="{}.lora_linear_layer.down.weight".format(x)
@@ -362,6 +367,12 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
             key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
 
+    if isinstance(model, comfy.model_base.GenmoMochi):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["{}".format(key_lora)] = k
+
     return key_map
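
The new branch follows the official Mochi naming, where keys end in bare `.lora_A`/`.lora_B` with no `.weight` suffix, and `model_lora_keys_unet` maps them one-to-one onto `diffusion_model.*` weights. A toy illustration with invented key names (note that `lora_B`, the up matrix, fills the role `A_name` plays in `load_lora`):

```python
lora_sd = {
    "blocks.0.attn.qkv.lora_A": "down matrix, rank x in_features",
    "blocks.0.attn.qkv.lora_B": "up matrix, out_features x rank",
}
x = "blocks.0.attn.qkv"
mochi_lora = "{}.lora_B".format(x)
if mochi_lora in lora_sd:               # same membership test load_lora performs
    A_name = mochi_lora                 # treated as the "up" weight
    B_name = "{}.lora_A".format(x)      # treated as the "down" weight
    print(A_name, B_name)
```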


@@ -628,6 +628,10 @@ def maximum_vram_for_weights(device=None):
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if model_params < 0:
         model_params = 1000000000000000000000
+    if args.fp32_unet:
+        return torch.float32
+    if args.fp64_unet:
+        return torch.float64
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
@@ -674,7 +678,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 # None means no manual cast
 def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
+    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
         return None
 
     fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
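
The resulting forced-dtype precedence in `unet_dtype()` is fp32, then fp64, then bf16, then fp16, then the fp8 variants, and `unet_manual_cast()` now treats both full-precision dtypes as needing no cast. A condensed sketch of just the flag precedence (the real function also weighs model size and device support):

```python
import torch

def pick_forced_unet_dtype(args):
    # First matching flag wins; the CLI flags are mutually exclusive anyway.
    for flag, dtype in (("fp32_unet", torch.float32),
                        ("fp64_unet", torch.float64),
                        ("bf16_unet", torch.bfloat16),
                        ("fp16_unet", torch.float16)):
        if getattr(args, flag, False):
            return dtype
    return None  # fall through to automatic selection
```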


@@ -367,10 +367,7 @@ class ModelPatcher:
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
-    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
-        mem_counter = 0
-        patch_counter = 0
-        lowvram_counter = 0
+    def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
             params = []
@@ -383,6 +380,13 @@ class ModelPatcher:
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                 loading.append((comfy.model_management.module_size(m), n, m, params))
+        return loading
+
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        mem_counter = 0
+        patch_counter = 0
+        lowvram_counter = 0
+        loading = self._load_list()
 
         load_completely = []
         loading.sort(reverse=True)
@@ -514,14 +518,7 @@ class ModelPatcher:
     def partially_unload(self, device_to, memory_to_free=0):
         memory_freed = 0
         patch_counter = 0
-        unload_list = []
-        for n, m in self.model.named_modules():
-            shift_lowvram = False
-            if hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
-                unload_list.append((module_mem, n, m))
-
+        unload_list = self._load_list()
         unload_list.sort()
         for unload in unload_list:
             if memory_to_free < memory_freed:
@@ -529,32 +526,42 @@ class ModelPatcher:
             module_mem = unload[0]
             n = unload[1]
             m = unload[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
+            params = unload[3]
 
+            lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
-                for key in [weight_key, bias_key]:
+                move_weight = True
+                for param in params:
+                    key = "{}.{}".format(n, param)
                     bk = self.backup.get(key, None)
                     if bk is not None:
+                        if not lowvram_possible:
+                            move_weight = False
+                            break
+
                         if bk.inplace_update:
                             comfy.utils.copy_to_param(self.model, key, bk.weight)
                         else:
                             comfy.utils.set_attr_param(self.model, key, bk.weight)
                         self.backup.pop(key)
 
-                m.to(device_to)
-                if weight_key in self.patches:
-                    m.weight_function = LowVramPatch(weight_key, self.patches)
-                    patch_counter += 1
-                if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(bias_key, self.patches)
-                    patch_counter += 1
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+                if move_weight:
+                    m.to(device_to)
+                    if lowvram_possible:
+                        if weight_key in self.patches:
+                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            patch_counter += 1
+                        if bias_key in self.patches:
+                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            patch_counter += 1
 
-                m.prev_comfy_cast_weights = m.comfy_cast_weights
-                m.comfy_cast_weights = True
-                m.comfy_patched_weights = False
-                memory_freed += module_mem
-                logging.debug("freed {}".format(n))
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    m.comfy_patched_weights = False
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))
 
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
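
Both code paths now share `_load_list()`, whose `(module_size, name, module, param_names)` tuples order naturally by size. A toy illustration of the two sort directions (sizes and names invented):

```python
import torch.nn as nn

entries = [
    (512, "blocks.0.attn", nn.Linear(4, 4), ["weight", "bias"]),
    (128, "final_norm", nn.LayerNorm(4), ["weight", "bias"]),
]
entries.sort(reverse=True)  # load(): place the largest modules first
print([name for _, name, _, _ in entries])
entries.sort()              # partially_unload(): free the smallest modules first
print([name for _, name, _, _ in entries])
```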


@@ -269,7 +269,7 @@ class VAE:
             self.latent_dim = 3
             self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-            self.upscale_ratio = 8
+            self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -370,7 +370,9 @@
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
             elif dims == 3:
-                pixel_samples = self.decode_tiled_3d(samples_in)
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
 
             pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
             return pixel_samples
@@ -434,6 +436,12 @@
     def get_sd(self):
         return self.first_stage_model.state_dict()
 
+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
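
`upscale_ratio` is now either a plain integer (image VAEs) or a `(temporal_fn, height, width)` tuple (the LTX video VAE), and the new accessor falls back between the two shapes via the subscript attempt. A stub demonstrating both branches (the class name is invented):

```python
class VAEStub:
    def __init__(self, upscale_ratio):
        self.upscale_ratio = upscale_ratio

    def spacial_compression_decode(self):
        try:
            return self.upscale_ratio[-1]  # tuple form: last entry is the spatial ratio
        except TypeError:                  # int form is not subscriptable
            return self.upscale_ratio

print(VAEStub(8).spacial_compression_decode())  # image VAE: 8
print(VAEStub((lambda a: max(0, a * 8 - 7), 32, 32)).spacial_compression_decode())  # LTX: 32
```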


@@ -659,6 +659,15 @@ class Flux(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
 
+class FluxInpaint(Flux):
+    unet_config = {
+        "image_model": "flux",
+        "guidance_embed": True,
+        "in_channels": 96,
+    }
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
 class FluxSchnell(Flux):
     unet_config = {
         "image_model": "flux",
@@ -731,6 +740,6 @@ class LTXV(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
 
 models += [SVD_img2vid]


@@ -10,7 +10,7 @@ class EmptyLTXVLatentVideo:
     def INPUT_TYPES(s):
         return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                               "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                               "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
 
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"


@@ -301,7 +301,8 @@ class VAEDecodeTiled:
     def decode(self, vae, samples, tile_size, overlap=64):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
return (images, ) return (images, )