Compare commits

..

21 Commits

Author SHA1 Message Date
comfyanonymous
4c82741b54 Support official SD3.5 Controlnets. 2024-11-26 11:31:25 -05:00
comfyanonymous
15c39ea757 Support for the official mochi lora format. 2024-11-26 03:34:36 -05:00
comfyanonymous
b7143b74ce Flux inpaint model does not work in fp16. 2024-11-26 01:33:01 -05:00
comfyanonymous
61196d8857 Add option to inference the diffusion model in fp32 and fp64. 2024-11-25 05:00:23 -05:00
comfyanonymous
b4526d3fc3 Skip layer guidance now works on hydit model. 2024-11-24 05:54:30 -05:00
40476
3d802710e7 Update README.md (#5707) 2024-11-24 04:12:07 -05:00
spacepxl
7126ecffde set LTX min length to 1 for t2i (#5750)
At length=1, the LTX model can do txt2img and img2img with no other changes required.
2024-11-23 21:33:08 -05:00
comfyanonymous
ab885b33ba Skip layer guidance node now works on LTX-Video. 2024-11-23 10:33:05 -05:00
comfyanonymous
839ed3368e Some improvements to the lowvram unloading. 2024-11-22 20:59:15 -05:00
comfyanonymous
6e8cdcd3cb Fix some tiled VAE decoding issues with LTX-Video. 2024-11-22 18:00:34 -05:00
comfyanonymous
e5c3f4b87f LTXV lowvram fixes. 2024-11-22 17:17:11 -05:00
comfyanonymous
bc6be6c11e Some fixes to the lowvram system. 2024-11-22 16:40:04 -05:00
comfyanonymous
94323a26a7 Remove prints. 2024-11-22 10:51:31 -05:00
comfyanonymous
5818f6cf51 Remove print. 2024-11-22 10:49:15 -05:00
comfyanonymous
0b734de449 Add LTX-Video support to the Readme. 2024-11-22 09:24:20 -05:00
comfyanonymous
5e16f1d24b Support Lightricks LTX-Video model. 2024-11-22 08:46:39 -05:00
comfyanonymous
2fd9c1308a Fix mask issue in some attention functions. 2024-11-22 02:10:09 -05:00
comfyanonymous
8f0009aad0 Support new flux model variants. 2024-11-21 08:38:23 -05:00
comfyanonymous
41444b5236 Add some new weight patching functionality.
Add a way to reshape lora weights.

Allow weight patches to all weight not just .weight and .bias

Add a way for a lora to set a weight to a specific value.
2024-11-21 07:19:17 -05:00
comfyanonymous
772e620e32 Update readme. 2024-11-20 20:42:51 -05:00
comfyanonymous
07f6eeaa13 Fix mask issue with attention_xformers. 2024-11-20 17:07:46 -05:00
31 changed files with 2481 additions and 108 deletions

View File

@@ -39,6 +39,7 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Asynchronous Queue system
@@ -74,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
| Keybind | Explanation |
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
| Ctrl + A | Select all nodes |
| Alt + C | Collapse/uncollapse selected nodes |
| Ctrl + M | Mute/unmute selected nodes |
| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| Delete/Backspace | Delete selected nodes |
| Ctrl + Backspace | Delete the current graph |
| Space | Move the canvas around when held and moving the cursor |
| Ctrl/Shift + Click | Add clicked node to selection |
| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| Shift + Drag | Move multiple selected nodes at the same time |
| Ctrl + D | Load default graph |
| Alt + `+` | Canvas Zoom in |
| Alt + `-` | Canvas Zoom out |
| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
| P | Pin/Unpin selected nodes |
| Ctrl + G | Group selected nodes |
| Q | Toggle visibility of the queue |
| H | Toggle visibility of history |
| R | Refresh graph |
| `Ctrl` + `Enter` | Queue up current graph for generation |
| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
| `Ctrl` + `S` | Save workflow |
| `Ctrl` + `O` | Load workflow |
| `Ctrl` + `A` | Select all nodes |
| `Alt `+ `C` | Collapse/uncollapse selected nodes |
| `Ctrl` + `M` | Mute/unmute selected nodes |
| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
| `Delete`/`Backspace` | Delete selected nodes |
| `Ctrl` + `Backspace` | Delete the current graph |
| `Space` | Move the canvas around when held and moving the cursor |
| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
| `Shift` + `Drag` | Move multiple selected nodes at the same time |
| `Ctrl` + `D` | Load default graph |
| `Alt` + `+` | Canvas Zoom in |
| `Alt` + `-` | Canvas Zoom out |
| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
| `P` | Pin/Unpin selected nodes |
| `Ctrl` + `G` | Group selected nodes |
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
Ctrl can also be replaced with Cmd instead for macOS users
`Ctrl` can also be replaced with `Cmd` instead for macOS users
# Installing
@@ -140,7 +141,7 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1```
```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:

122
comfy/cldm/dit_embedder.py Normal file
View File

@@ -0,0 +1,122 @@
import math
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
class ControlNetEmbedder(nn.Module):
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
attention_head_dim: int,
num_attention_heads: int,
adm_in_channels: int,
num_layers: int,
main_model_double: int,
double_y_emb: bool,
device: torch.device,
dtype: torch.dtype,
pos_embed_max_size: Optional[int] = None,
operations = None,
):
super().__init__()
self.main_model_double = main_model_double
self.dtype = dtype
self.hidden_size = num_attention_heads * attention_head_dim
self.patch_size = patch_size
self.x_embedder = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=pos_embed_max_size is None,
device=device,
dtype=dtype,
operations=operations,
)
self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
self.double_y_emb = double_y_emb
if self.double_y_emb:
self.orig_y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.y_embedder = VectorEmbedder(
self.hidden_size, self.hidden_size, dtype, device, operations=operations
)
else:
self.y_embedder = VectorEmbedder(
adm_in_channels, self.hidden_size, dtype, device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
DismantledBlock(
hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(num_layers)
)
# self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
# TODO double check this logic when 8b
self.use_y_embedder = True
self.controlnet_blocks = nn.ModuleList([])
for _ in range(len(self.transformer_blocks)):
controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
self.controlnet_blocks.append(controlnet_block)
self.pos_embed_input = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=self.hidden_size,
strict_img_size=False,
device=device,
dtype=dtype,
operations=operations,
)
def forward(
self,
x: torch.Tensor,
timesteps: torch.Tensor,
y: Optional[torch.Tensor] = None,
context: Optional[torch.Tensor] = None,
hint = None,
) -> Tuple[Tensor, List[Tensor]]:
x_shape = list(x.shape)
x = self.x_embedder(x)
if not self.double_y_emb:
h = (x_shape[-2] + 1) // self.patch_size
w = (x_shape[-1] + 1) // self.patch_size
x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
c = self.t_embedder(timesteps, dtype=x.dtype)
if y is not None and self.y_embedder is not None:
if self.double_y_emb:
y = self.orig_y_embedder(y)
y = self.y_embedder(y)
c = c + y
x = x + self.pos_embed_input(hint)
block_out = ()
repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
for i in range(len(self.transformer_blocks)):
out = self.transformer_blocks[i](x, c)
if not self.double_y_emb:
x = out
block_out += (self.controlnet_blocks[i](out),) * repeat
return {"output": block_out}

View File

@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")

View File

@@ -23,6 +23,7 @@ class CLIPAttention(torch.nn.Module):
ACTIVATIONS = {"quick_gelu": lambda a: a * torch.sigmoid(1.702 * a),
"gelu": torch.nn.functional.gelu,
"gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
}
class CLIPMLP(torch.nn.Module):
@@ -139,27 +140,35 @@ class CLIPTextModel(torch.nn.Module):
class CLIPVisionEmbeddings(torch.nn.Module):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, dtype=None, device=None, operations=None):
def __init__(self, embed_dim, num_channels=3, patch_size=14, image_size=224, model_type="", dtype=None, device=None, operations=None):
super().__init__()
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
num_patches = (image_size // patch_size) ** 2
if model_type == "siglip_vision_model":
self.class_embedding = None
patch_bias = True
else:
num_patches = num_patches + 1
self.class_embedding = torch.nn.Parameter(torch.empty(embed_dim, dtype=dtype, device=device))
patch_bias = False
self.patch_embedding = operations.Conv2d(
in_channels=num_channels,
out_channels=embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=False,
bias=patch_bias,
dtype=dtype,
device=device
)
num_patches = (image_size // patch_size) ** 2
num_positions = num_patches + 1
self.position_embedding = operations.Embedding(num_positions, embed_dim, dtype=dtype, device=device)
self.position_embedding = operations.Embedding(num_patches, embed_dim, dtype=dtype, device=device)
def forward(self, pixel_values):
embeds = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
return torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1) + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
if self.class_embedding is not None:
embeds = torch.cat([comfy.ops.cast_to_input(self.class_embedding, embeds).expand(pixel_values.shape[0], 1, -1), embeds], dim=1)
return embeds + comfy.ops.cast_to_input(self.position_embedding.weight, embeds)
class CLIPVision(torch.nn.Module):
@@ -170,9 +179,15 @@ class CLIPVision(torch.nn.Module):
heads = config_dict["num_attention_heads"]
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
model_type = config_dict["model_type"]
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], dtype=dtype, device=device, operations=operations)
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.embeddings = CLIPVisionEmbeddings(embed_dim, config_dict["num_channels"], config_dict["patch_size"], config_dict["image_size"], model_type=model_type, dtype=dtype, device=device, operations=operations)
if model_type == "siglip_vision_model":
self.pre_layrnorm = lambda a: a
self.output_layernorm = True
else:
self.pre_layrnorm = operations.LayerNorm(embed_dim)
self.output_layernorm = False
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.post_layernorm = operations.LayerNorm(embed_dim)
@@ -181,14 +196,21 @@ class CLIPVision(torch.nn.Module):
x = self.pre_layrnorm(x)
#TODO: attention_mask?
x, i = self.encoder(x, mask=None, intermediate_output=intermediate_output)
pooled_output = self.post_layernorm(x[:, 0, :])
if self.output_layernorm:
x = self.post_layernorm(x)
pooled_output = x
else:
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
self.vision_model = CLIPVision(config_dict, dtype, device, operations)
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
if "projection_dim" in config_dict:
self.visual_projection = operations.Linear(config_dict["hidden_size"], config_dict["projection_dim"], bias=False)
else:
self.visual_projection = lambda a: a
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)

View File

@@ -16,9 +16,9 @@ class Output:
def __setitem__(self, key, item):
setattr(self, key, item)
def clip_preprocess(image, size=224):
mean = torch.tensor([ 0.48145466,0.4578275,0.40821073], device=image.device, dtype=image.dtype)
std = torch.tensor([0.26862954,0.26130258,0.27577711], device=image.device, dtype=image.dtype)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]):
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
if not (image.shape[2] == size and image.shape[3] == size):
scale = (size / min(image.shape[2], image.shape[3]))
@@ -35,6 +35,8 @@ class ClipVisionModel():
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
@@ -51,7 +53,7 @@ class ClipVisionModel():
def encode_image(self, image):
comfy.model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
@@ -94,7 +96,9 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")

View File

@@ -0,0 +1,13 @@
{
"num_channels": 3,
"hidden_act": "gelu_pytorch_tanh",
"hidden_size": 1152,
"image_size": 384,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"patch_size": 14,
"image_mean": [0.5, 0.5, 0.5],
"image_std": [0.5, 0.5, 0.5]
}

View File

@@ -35,7 +35,7 @@ import comfy.ldm.cascade.controlnet
import comfy.cldm.mmdit
import comfy.ldm.hydit.controlnet
import comfy.ldm.flux.controlnet
import comfy.cldm.dit_embedder
def broadcast_image_to(tensor, target_batch_size, batched_number):
current_batch_size = tensor.shape[0]
@@ -78,6 +78,7 @@ class ControlBase:
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
self.preprocess_image = lambda a: a
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ class ControlBase:
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
c.preprocess_image = self.preprocess_image
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@@ -181,7 +183,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
super().__init__()
self.control_model = control_model
self.load_device = load_device
@@ -196,6 +198,7 @@ class ControlNet(ControlBase):
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
self.preprocess_image = preprocess_image
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@@ -224,6 +227,7 @@ class ControlNet(ControlBase):
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
self.cond_hint = self.preprocess_image(self.cond_hint)
if self.vae is not None:
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
logging.debug("unexpected controlnet keys: {}".format(unexpected))
return control_model
def load_controlnet_mmdit(sd, model_options={}):
new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
return control
class ControlNetSD35(ControlNet):
def pre_run(self, model, percent_to_timestep_function):
if self.control_model.double_y_emb:
missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
else:
missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
super().pre_run(model, percent_to_timestep_function)
def copy(self):
c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
c.control_model = self.control_model
c.control_model_wrapped = self.control_model_wrapped
self.copy_to(c)
return c
def load_controlnet_sd35(sd, model_options={}):
control_type = -1
if "control_type" in sd:
control_type = round(sd.pop("control_type").item())
# blur_cnet = control_type == 0
canny_cnet = control_type == 1
depth_cnet = control_type == 2
print(control_type, canny_cnet, depth_cnet)
new_sd = {}
for k in comfy.utils.MMDIT_MAP_BASIC:
if k[1] in sd:
new_sd[k[0]] = sd.pop(k[1])
for k in sd:
new_sd[k] = sd[k]
sd = new_sd
y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
depth = y_emb_shape[0] // 64
hidden_size = 64 * depth
num_heads = depth
head_dim = hidden_size // num_heads
num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
load_device = comfy.model_management.get_torch_device()
offload_device = comfy.model_management.unet_offload_device()
unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
operations = model_options.get("custom_operations", None)
if operations is None:
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
patch_size=2,
in_chans=16,
num_layers=num_blocks,
main_model_double=depth,
double_y_emb=y_emb_shape[0] == y_emb_shape[1],
attention_head_dim=head_dim,
num_attention_heads=num_heads,
adm_in_channels=2048,
device=offload_device,
dtype=unet_dtype,
operations=operations)
control_model = controlnet_load_state_dict(control_model, sd)
latent_format = comfy.latent_formats.SD3()
preprocess_image = lambda a: a
if canny_cnet:
preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
elif depth_cnet:
preprocess_image = lambda a: 1.0 - a
control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
return control
def load_controlnet_hunyuandit(controlnet_data, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
@@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
else:
return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux

View File

@@ -216,3 +216,7 @@ class Mochi(LatentFormat):
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
latents_std = self.latents_std.to(latent.device, latent.dtype)
return latent * latents_std / self.scale_factor + latents_mean
class LTXV(LatentFormat):
latent_channels = 128

View File

@@ -20,6 +20,7 @@ import comfy.ldm.common_dit
@dataclass
class FluxParams:
in_channels: int
out_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
@@ -29,6 +30,7 @@ class FluxParams:
depth_single_blocks: int
axes_dim: list
theta: int
patch_size: int
qkv_bias: bool
guidance_embed: bool
@@ -43,8 +45,9 @@ class Flux(nn.Module):
self.dtype = dtype
params = FluxParams(**kwargs)
self.params = params
self.in_channels = params.in_channels * 2 * 2
self.out_channels = self.in_channels
self.patch_size = params.patch_size
self.in_channels = params.in_channels * params.patch_size * params.patch_size
self.out_channels = params.out_channels * params.patch_size * params.patch_size
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
@@ -165,7 +168,7 @@ class Flux(nn.Module):
def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

25
comfy/ldm/flux/redux.py Normal file
View File

@@ -0,0 +1,25 @@
import torch
import comfy.ops
ops = comfy.ops.manual_cast
class ReduxImageEncoder(torch.nn.Module):
def __init__(
self,
redux_dim: int = 1152,
txt_in_features: int = 4096,
device=None,
dtype=None,
) -> None:
super().__init__()
self.redux_dim = redux_dim
self.device = device
self.dtype = dtype
self.redux_up = ops.Linear(redux_dim, txt_in_features * 3, dtype=dtype)
self.redux_down = ops.Linear(txt_in_features * 3, txt_in_features, dtype=dtype)
def forward(self, sigclip_embeds) -> torch.Tensor:
projected_x = self.redux_down(torch.nn.functional.silu(self.redux_up(sigclip_embeds)))
return projected_x

View File

@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
style=None,
return_dict=False,
control=None,
transformer_options=None,
transformer_options={},
):
"""
Forward pass of the encoder.
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
return_dict: bool
Whether to return a dictionary.
"""
#import pdb
#pdb.set_trace()
patches_replace = transformer_options.get("patches_replace", {})
encoder_hidden_states = context
text_states = encoder_hidden_states # 2,77,1024
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
# Concatenate all extra vectors
c = t + self.extra_embedder(extra_vec) # [B, D]
blocks_replace = patches_replace.get("dit", {})
controls = None
if control:
controls = control.get("output", None)
@@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
else:
skip = skips.pop()
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
else:
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
skip = None
if ("double_block", layer) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
return out
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
if layer < (self.depth // 2 - 1):
skips.append(x)

View File

@@ -0,0 +1,514 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
from .symmetric_patchifier import SymmetricPatchifier
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
Args
timesteps (torch.Tensor):
a 1-D Tensor of N indices, one per batch element. These may be fractional.
embedding_dim (int):
the dimension of the output.
flip_sin_to_cos (bool):
Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
downscale_freq_shift (float):
Controls the delta between frequencies between dimensions
scale (float):
Scaling factor applied to the embeddings.
max_period (int):
Controls the maximum frequency of the embeddings
Returns
torch.Tensor: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
dtype=None, device=None, operations=None,
):
super().__init__()
self.linear_1 = operations.Linear(in_channels, time_embed_dim, sample_proj_bias, dtype=dtype, device=device)
if cond_proj_dim is not None:
self.cond_proj = operations.Linear(cond_proj_dim, in_channels, bias=False, dtype=dtype, device=device)
else:
self.cond_proj = None
self.act = nn.SiLU()
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = operations.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias, dtype=dtype, device=device)
if post_act_fn is None:
self.post_act = None
# else:
# self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
self.scale = scale
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
scale=self.scale,
)
return t_emb
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim, dtype=dtype, device=device, operations=operations)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
return timesteps_emb
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False, dtype=None, device=None, operations=None):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions, dtype=dtype, device=device, operations=operations
)
self.silu = nn.SiLU()
self.linear = operations.Linear(embedding_dim, 6 * embedding_dim, bias=True, dtype=dtype, device=device)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh", dtype=None, device=None, operations=None):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = operations.Linear(in_features=in_features, out_features=hidden_size, bias=True, dtype=dtype, device=device)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu":
self.act_1 = nn.SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = operations.Linear(in_features=hidden_size, out_features=out_features, bias=True, dtype=dtype, device=device)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
self.proj = operations.Linear(dim_in, dim_out, dtype=dtype, device=device)
def forward(self, x):
return torch.nn.functional.gelu(self.proj(x), approximate="tanh")
class FeedForward(nn.Module):
def __init__(self, dim, dim_out, mult=4, glu=False, dropout=0., dtype=None, device=None, operations=None):
super().__init__()
inner_dim = int(dim * mult)
project_in = GELU_approx(dim, inner_dim, dtype=dtype, device=device, operations=operations)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
operations.Linear(inner_dim, dim_out, dtype=dtype, device=device)
)
def forward(self, x):
return self.net(x)
def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
return out
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
inner_dim = dim_head * heads
context_dim = query_dim if context_dim is None else context_dim
self.attn_precision = attn_precision
self.heads = heads
self.dim_head = dim_head
self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None, pe=None):
q = self.to_q(x)
context = x if context is None else context
k = self.to_k(context)
v = self.to_v(context)
q = self.q_norm(q)
k = self.k_norm(k)
if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
self.attn_precision = attn_precision
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, context_dim=None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.ff = FeedForward(dim, dim_out=dim, glu=True, dtype=dtype, device=device, operations=operations)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
x += self.attn2(x, context=context, mask=attention_mask)
y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
return x
def get_fractional_positions(indices_grid, max_pos):
fractional_positions = torch.stack(
[
indices_grid[:, i] / max_pos[i]
for i in range(3)
],
dim=-1,
)
return fractional_positions
def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
fractional_positions = get_fractional_positions(indices_grid, max_pos)
start = 1
end = theta
device = fractional_positions.device
indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
indices = indices * math.pi / 2
freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
class LTXVModel(torch.nn.Module):
def __init__(self,
in_channels=128,
cross_attention_dim=2048,
attention_head_dim=64,
num_attention_heads=32,
caption_channels=4096,
num_layers=28,
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=operations
)
# self.adaln_single.linear = operations.Linear(self.inner_dim, 4 * self.inner_dim, bias=True, dtype=dtype, device=device)
self.caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
# attn_precision=attn_precision,
dtype=dtype, device=device, operations=operations
)
for d in range(num_layers)
]
)
self.scale_shift_table = nn.Parameter(torch.empty(2, self.inner_dim, dtype=dtype, device=device))
self.norm_out = operations.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.proj_out = operations.Linear(self.inner_dim, self.out_channels, dtype=dtype, device=device)
self.patchifier = SymmetricPatchifier(1)
def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
indices_grid = self.patchifier.get_grid(
orig_num_frames=x.shape[2],
orig_height=x.shape[3],
orig_width=x.shape[4],
batch_size=x.shape[0],
scale_grid=((1 / frame_rate) * 8, 32, 32),
device=x.device,
)
if guiding_latent is not None:
ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
ts *= input_ts
ts[:, :, 0] = 0.0
timestep = self.patchifier.patchify(ts)
input_x = x.clone()
x[:, :, 0] = guiding_latent[:, :, 0]
orig_shape = list(x.shape)
x = self.patchifier.patchify(x)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
attention_mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1]))
attention_mask = attention_mask.masked_fill(attention_mask.to(torch.bool), float("-inf")) # not sure about this
# attention_mask = (context != 0).any(dim=2).to(dtype=x.dtype)
pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=x.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
batch_size, -1, embedded_timestep.shape[-1]
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = x.shape[0]
context = self.caption_projection(context)
context = context.view(
batch_size, -1, x.shape[-1]
)
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
x,
context=context,
attention_mask=attention_mask,
timestep=timestep,
pe=pe
)
# 3. Output
scale_shift_values = (
self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = self.proj_out(x)
x = self.patchifier.unpatchify(
latents=x,
output_height=orig_shape[3],
output_width=orig_shape[4],
output_num_frames=orig_shape[2],
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
if guiding_latent is not None:
x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
# print("res", x)
return x

View File

@@ -0,0 +1,105 @@
from abc import ABC, abstractmethod
from typing import Tuple
import torch
from einops import rearrange
from torch import Tensor
def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
)
elif dims_to_append == 0:
return x
return x[(...,) + (None,) * dims_to_append]
class Patchifier(ABC):
def __init__(self, patch_size: int):
super().__init__()
self._patch_size = (1, patch_size, patch_size)
@abstractmethod
def patchify(
self, latents: Tensor, frame_rates: Tensor, scale_grid: bool
) -> Tuple[Tensor, Tensor]:
pass
@abstractmethod
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
pass
@property
def patch_size(self):
return self._patch_size
def get_grid(
self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
):
f = orig_num_frames // self._patch_size[0]
h = orig_height // self._patch_size[1]
w = orig_width // self._patch_size[2]
grid_h = torch.arange(h, dtype=torch.float32, device=device)
grid_w = torch.arange(w, dtype=torch.float32, device=device)
grid_f = torch.arange(f, dtype=torch.float32, device=device)
grid = torch.meshgrid(grid_f, grid_h, grid_w)
grid = torch.stack(grid, dim=0)
grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
if scale_grid is not None:
for i in range(3):
if isinstance(scale_grid[i], Tensor):
scale = append_dims(scale_grid[i], grid.ndim - 1)
else:
scale = scale_grid[i]
grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
return grid
class SymmetricPatchifier(Patchifier):
def patchify(
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
p1=self._patch_size[0],
p2=self._patch_size[1],
p3=self._patch_size[2],
)
return latents
def unpatchify(
self,
latents: Tensor,
output_height: int,
output_width: int,
output_num_frames: int,
out_channels: int,
) -> Tuple[Tensor, Tensor]:
output_height = output_height // self._patch_size[1]
output_width = output_width // self._patch_size[2]
latents = rearrange(
latents,
"b (f h w) (c p q) -> b c f (h p) (w q) ",
f=output_num_frames,
h=output_height,
w=output_width,
p=self._patch_size[1],
q=self._patch_size[2],
)
return latents

View File

@@ -0,0 +1,64 @@
from typing import Tuple, Union
import torch
import torch.nn as nn
import comfy.ops
ops = comfy.ops.disable_weight_init
class CausalConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size: int = 3,
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
**kwargs,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0]
dilation = (dilation, 1, 1)
height_pad = kernel_size[1] // 2
width_pad = kernel_size[2] // 2
padding = (0, height_pad, width_pad)
self.conv = ops.Conv3d(
in_channels,
out_channels,
kernel_size,
stride=stride,
dilation=dilation,
padding=padding,
padding_mode="zeros",
groups=groups,
)
def forward(self, x, causal: bool = True):
if causal:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, self.time_kernel_size - 1, 1, 1)
)
x = torch.concatenate((first_frame_pad, x), dim=2)
else:
first_frame_pad = x[:, :, :1, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
last_frame_pad = x[:, :, -1:, :, :].repeat(
(1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
)
x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
x = self.conv(x)
return x
@property
def weight(self):
return self.conv.weight

View File

@@ -0,0 +1,698 @@
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
from typing import Any, Mapping, Optional, Tuple, Union, List
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
class Encoder(nn.Module):
r"""
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
):
super().__init__()
self.patch_size = patch_size
self.norm_layer = norm_layer
self.latent_channels = out_channels
self.latent_log_var = latent_log_var
self.blocks_desc = blocks
in_channels = in_channels * patch_size**2
output_channel = base_channels
self.conv_in = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.down_blocks = nn.ModuleList([])
for block_name, block_params in blocks:
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 1, 1),
causal=True,
)
elif block_name == "compress_space":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(1, 2, 2),
causal=True,
)
elif block_name == "compress_all":
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
block = make_conv_nd(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
kernel_size=3,
stride=(2, 2, 2),
causal=True,
)
else:
raise ValueError(f"unknown block: {block_name}")
self.down_blocks.append(block)
# out
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
conv_out_channels = out_channels
if latent_log_var == "per_channel":
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
dims, output_channel, conv_out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample)
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...]
num_dims = sample.dim()
if num_dims == 4:
# For shape (B, C, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
elif num_dims == 5:
# For shape (B, C, F, H, W)
repeated_last_channel = last_channel.repeat(
1, sample.shape[1] - 2, 1, 1, 1
)
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
return sample
class Decoder(nn.Module):
r"""
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
Args:
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
The number of dimensions to use in convolutions.
in_channels (`int`, *optional*, defaults to 3):
The number of input channels.
out_channels (`int`, *optional*, defaults to 3):
The number of output channels.
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
The blocks to use. Each block is a tuple of the block name and the number of layers.
base_channels (`int`, *optional*, defaults to 128):
The number of output channels for the first convolutional layer.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups for normalization.
patch_size (`int`, *optional*, defaults to 1):
The patch size to use. Should be a power of 2.
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
causal (`bool`, *optional*, defaults to `True`):
Whether to use causal convolutions or not.
"""
def __init__(
self,
dims,
in_channels: int = 3,
out_channels: int = 3,
blocks=[("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
patch_size: int = 1,
norm_layer: str = "group_norm",
causal: bool = True,
):
super().__init__()
self.patch_size = patch_size
self.layers_per_block = layers_per_block
out_channels = out_channels * patch_size**2
self.causal = causal
self.blocks_desc = blocks
# Compute output channel to be product of all channel-multiplier blocks
output_channel = base_channels
for block_name, block_params in list(reversed(blocks)):
block_params = block_params if isinstance(block_params, dict) else {}
if block_name == "res_x_y":
output_channel = output_channel * block_params.get("multiplier", 2)
self.conv_in = make_conv_nd(
dims,
in_channels,
output_channel,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.up_blocks = nn.ModuleList([])
for block_name, block_params in list(reversed(blocks)):
input_channel = output_channel
if isinstance(block_params, int):
block_params = {"num_layers": block_params}
if block_name == "res_x":
block = UNetMidBlock3D(
dims=dims,
in_channels=input_channel,
num_layers=block_params["num_layers"],
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
block = ResnetBlock3D(
dims=dims,
in_channels=input_channel,
out_channels=output_channel,
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(2, 1, 1)
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
)
elif block_name == "compress_all":
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 2, 2),
residual=block_params.get("residual", False),
)
else:
raise ValueError(f"unknown layer: {block_name}")
self.up_blocks.append(block)
if norm_layer == "group_norm":
self.conv_norm_out = nn.GroupNorm(
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
)
elif norm_layer == "pixel_norm":
self.conv_norm_out = PixelNorm()
elif norm_layer == "layer_norm":
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
dims, output_channel, out_channels, 3, padding=1, causal=True
)
self.gradient_checkpointing = False
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class."""
# assert target_shape is not None, "target_shape must be provided"
sample = self.conv_in(sample, causal=self.causal)
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
checkpoint_fn = (
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
if self.gradient_checkpointing and self.training
else lambda x: x
)
sample = sample.to(upscale_dtype)
for up_block in self.up_blocks:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample, causal=self.causal)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
class UNetMidBlock3D(nn.Module):
"""
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
Args:
in_channels (`int`): The number of input channels.
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
in_channels, height, width)`.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_groups: int = 32,
norm_layer: str = "group_norm",
):
super().__init__()
resnet_groups = (
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
)
self.res_blocks = nn.ModuleList(
[
ResnetBlock3D(
dims=dims,
in_channels=in_channels,
out_channels=in_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
norm_layer=norm_layer,
)
for _ in range(num_layers)
]
)
def forward(
self, hidden_states: torch.FloatTensor, causal: bool = True
) -> torch.FloatTensor:
for resnet in self.res_blocks:
hidden_states = resnet(hidden_states, causal=causal)
return hidden_states
class DepthToSpaceUpsample(nn.Module):
def __init__(self, dims, in_channels, stride, residual=False):
super().__init__()
self.stride = stride
self.out_channels = math.prod(stride) * in_channels
self.conv = make_conv_nd(
dims=dims,
in_channels=in_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
causal=True,
)
self.residual = residual
def forward(self, x, causal: bool = True):
if self.residual:
# Reshape and duplicate the input to match the output shape
x_in = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
if self.stride[0] == 2:
x_in = x_in[:, :, 1:, :, :]
x = self.conv(x, causal=causal)
x = rearrange(
x,
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
if self.stride[0] == 2:
x = x[:, :, 1:, :, :]
if self.residual:
x = x + x_in
return x
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
x = self.norm(x)
x = rearrange(x, "b d h w c -> b c d h w")
return x
class ResnetBlock3D(nn.Module):
r"""
A Resnet block.
Parameters:
in_channels (`int`): The number of channels in the input.
out_channels (`int`, *optional*, default to be `None`):
The number of output channels for the first conv layer. If None, same as `in_channels`.
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
"""
def __init__(
self,
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: Optional[int] = None,
dropout: float = 0.0,
groups: int = 32,
eps: float = 1e-6,
norm_layer: str = "group_norm",
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
if norm_layer == "group_norm":
self.norm1 = nn.GroupNorm(
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm1 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
self.non_linearity = nn.SiLU()
self.conv1 = make_conv_nd(
dims,
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
if norm_layer == "group_norm":
self.norm2 = nn.GroupNorm(
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
)
elif norm_layer == "pixel_norm":
self.norm2 = PixelNorm()
elif norm_layer == "layer_norm":
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = make_conv_nd(
dims,
out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1,
causal=True,
)
self.conv_shortcut = (
make_linear_nd(
dims=dims, in_channels=in_channels, out_channels=out_channels
)
if in_channels != out_channels
else nn.Identity()
)
self.norm3 = (
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
if in_channels != out_channels
else nn.Identity()
)
def forward(
self,
input_tensor: torch.FloatTensor,
causal: bool = True,
) -> torch.FloatTensor:
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv1(hidden_states, causal=causal)
hidden_states = self.norm2(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, causal=causal)
input_tensor = self.norm3(input_tensor)
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = input_tensor + hidden_states
return output_tensor
def patchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
else:
raise ValueError(f"Invalid input shape: {x.shape}")
return x
def unpatchify(x, patch_size_hw, patch_size_t=1):
if patch_size_hw == 1 and patch_size_t == 1:
return x
if x.dim() == 4:
x = rearrange(
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
)
elif x.dim() == 5:
x = rearrange(
x,
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
p=patch_size_t,
q=patch_size_hw,
r=patch_size_hw,
)
return x
class processor(nn.Module):
def __init__(self):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
def normalize(self, x):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
def __init__(self):
super().__init__()
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"blocks": [
["res_x", 4],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x_y", 1],
["res_x", 3],
["compress_all", 1],
["res_x", 3],
["res_x", 4],
],
"scaling_factor": 1.0,
"norm_layer": "pixel_norm",
"patch_size": 4,
"latent_log_var": "uniform",
"use_quant_conv": False,
"causal_decoder": False,
}
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
)
self.encoder = Encoder(
dims=config["dims"],
in_channels=config.get("in_channels", 3),
out_channels=config["latent_channels"],
blocks=config.get("encoder_blocks", config.get("blocks")),
patch_size=config.get("patch_size", 1),
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
)
self.decoder = Decoder(
dims=config["dims"],
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("blocks")),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
)
self.per_channel_statistics = processor()
def encode(self, x):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x):
return self.decoder(self.per_channel_statistics.un_normalize(x))

View File

@@ -0,0 +1,83 @@
from typing import Tuple, Union
import torch
from .dual_conv3d import DualConv3d
from .causal_conv3d import CausalConv3d
import comfy.ops
ops = comfy.ops.disable_weight_init
def make_conv_nd(
dims: Union[int, Tuple[int, int]],
in_channels: int,
out_channels: int,
kernel_size: int,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
causal=False,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == 3:
if causal:
return CausalConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
return ops.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
)
elif dims == (2, 1):
return DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=bias,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
def make_linear_nd(
dims: int,
in_channels: int,
out_channels: int,
bias=True,
):
if dims == 2:
return ops.Conv2d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
elif dims == 3 or dims == (2, 1):
return ops.Conv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
)
else:
raise ValueError(f"unsupported dimensions: {dims}")

View File

@@ -0,0 +1,195 @@
import math
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class DualConv3d(nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if kernel_size == (1, 1, 1):
raise ValueError(
"kernel_size must be greater than 1. Use make_linear_nd instead."
)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
# Set parameters for convolutions
self.groups = groups
self.bias = bias
# Define the size of the channels after the first convolution
intermediate_channels = (
out_channels if in_channels < out_channels else in_channels
)
# Define parameters for the first convolution
self.weight1 = nn.Parameter(
torch.Tensor(
intermediate_channels,
in_channels // groups,
1,
kernel_size[1],
kernel_size[2],
)
)
self.stride1 = (1, stride[1], stride[2])
self.padding1 = (0, padding[1], padding[2])
self.dilation1 = (1, dilation[1], dilation[2])
if bias:
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
else:
self.register_parameter("bias1", None)
# Define parameters for the second convolution
self.weight2 = nn.Parameter(
torch.Tensor(
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
)
)
self.stride2 = (stride[0], 1, 1)
self.padding2 = (padding[0], 0, 0)
self.dilation2 = (dilation[0], 1, 1)
if bias:
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
else:
self.register_parameter("bias2", None)
# Initialize weights and biases
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
if self.bias:
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
bound1 = 1 / math.sqrt(fan_in1)
nn.init.uniform_(self.bias1, -bound1, bound1)
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
bound2 = 1 / math.sqrt(fan_in2)
nn.init.uniform_(self.bias2, -bound2, bound2)
def forward(self, x, use_conv3d=False, skip_time_conv=False):
if use_conv3d:
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
else:
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
def forward_with_3d(self, x, skip_time_conv):
# First convolution
x = F.conv3d(
x,
self.weight1,
self.bias1,
self.stride1,
self.padding1,
self.dilation1,
self.groups,
)
if skip_time_conv:
return x
# Second convolution
x = F.conv3d(
x,
self.weight2,
self.bias2,
self.stride2,
self.padding2,
self.dilation2,
self.groups,
)
return x
def forward_with_2d(self, x, skip_time_conv):
b, c, d, h, w = x.shape
# First 2D convolution
x = rearrange(x, "b c d h w -> (b d) c h w")
# Squeeze the depth dimension out of weight1 since it's 1
weight1 = self.weight1.squeeze(2)
# Select stride, padding, and dilation for the 2D convolution
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
_, _, h, w = x.shape
if skip_time_conv:
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
return x
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
# Reshape weight2 to match the expected dimensions for conv1d
weight2 = self.weight2.squeeze(-1).squeeze(-1)
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
@property
def weight(self):
return self.weight2
def test_dual_conv3d_consistency():
# Initialize parameters
in_channels = 3
out_channels = 5
kernel_size = (3, 3, 3)
stride = (2, 2, 2)
padding = (1, 1, 1)
# Create an instance of the DualConv3d class
dual_conv3d = DualConv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
bias=True,
)
# Example input tensor
test_input = torch.randn(1, 3, 10, 10, 10)
# Perform forward passes with both 3D and 2D settings
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
output_2d = dual_conv3d(test_input, use_conv3d=False)
# Assert that the outputs from both methods are sufficiently close
assert torch.allclose(
output_conv3d, output_2d, atol=1e-6
), "Outputs are not consistent between 3D and 2D convolutions."

View File

@@ -0,0 +1,12 @@
import torch
from torch import nn
class PixelNorm(nn.Module):
def __init__(self, dim=1, eps=1e-8):
super(PixelNorm, self).__init__()
self.dim = dim
self.eps = eps
def forward(self, x):
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)

View File

@@ -299,7 +299,10 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if len(mask.shape) == 2:
s1 += mask[i:end]
else:
s1 += mask[:, i:end]
if mask.shape[1] == 1:
s1 += mask
else:
s1 += mask[:, i:end]
s2 = s1.softmax(dim=-1).to(v.dtype)
del s1
@@ -372,10 +375,10 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
)
if mask is not None:
pad = 8 - q.shape[1] % 8
mask_out = torch.empty([q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device)
mask_out[:, :, :mask.shape[-1]] = mask
mask = mask_out[:, :, :mask.shape[-1]]
pad = 8 - mask.shape[-1] % 8
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
mask_out[..., :mask.shape[-1]] = mask
mask = mask_out[..., :mask.shape[-1]]
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)

View File

@@ -234,6 +234,8 @@ def efficient_dot_product_attention(
def get_mask_chunk(chunk_idx: int) -> Tensor:
if mask is None:
return None
if mask.shape[1] == 1:
return mask
chunk = min(query_chunk_size, q_tokens)
return mask[:,chunk_idx:chunk_idx + chunk]

View File

@@ -49,10 +49,20 @@ def load_lora(lora, to_load):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
reshape_name = "{}.reshape_weight".format(x)
reshape = None
if reshape_name in lora.keys():
try:
reshape = lora[reshape_name].tolist()
loaded_keys.add(reshape_name)
except:
pass
regular_lora = "{}.lora_up.weight".format(x)
diffusers_lora = "{}_lora.up.weight".format(x)
diffusers2_lora = "{}.lora_B.weight".format(x)
diffusers3_lora = "{}.lora.up.weight".format(x)
mochi_lora = "{}.lora_B".format(x)
transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
A_name = None
@@ -72,6 +82,10 @@ def load_lora(lora, to_load):
A_name = diffusers3_lora
B_name = "{}.lora.down.weight".format(x)
mid_name = None
elif mochi_lora in lora.keys():
A_name = mochi_lora
B_name = "{}.lora_A".format(x)
mid_name = None
elif transformers_lora in lora.keys():
A_name = transformers_lora
B_name ="{}.lora_linear_layer.down.weight".format(x)
@@ -82,7 +96,7 @@ def load_lora(lora, to_load):
if mid_name is not None and mid_name in lora.keys():
mid = lora[mid_name]
loaded_keys.add(mid_name)
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale))
patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
loaded_keys.add(A_name)
loaded_keys.add(B_name)
@@ -193,6 +207,12 @@ def load_lora(lora, to_load):
patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
loaded_keys.add(diff_bias_name)
set_weight_name = "{}.set_weight".format(x)
set_weight = lora.get(set_weight_name, None)
if set_weight is not None:
patch_dict[to_load[x]] = ("set", (set_weight,))
loaded_keys.add(set_weight_name)
for x in lora.keys():
if x not in loaded_keys:
logging.warning("lora key not loaded: {}".format(x))
@@ -282,11 +302,14 @@ def model_lora_keys_unet(model, key_map={}):
sdk = sd.keys()
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
if k.startswith("diffusion_model."):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
diffusers_keys = comfy.utils.unet_to_diffusers(model.model_config.unet_config)
for k in diffusers_keys:
@@ -344,6 +367,12 @@ def model_lora_keys_unet(model, key_map={}):
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
if isinstance(model, comfy.model_base.GenmoMochi):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map
@@ -440,10 +469,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, diff.shape, weight.shape))
else:
weight += function(strength * comfy.model_management.cast_to_device(diff, weight.device, weight.dtype))
elif patch_type == "set":
weight.copy_(v[0])
elif patch_type == "lora": #lora/locon
mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
reshape = v[5]
if reshape is not None:
weight = pad_tensor_to_shape(weight, reshape)
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:

17
comfy/lora_convert.py Normal file
View File

@@ -0,0 +1,17 @@
import torch
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
return sd_out
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
return sd

View File

@@ -30,6 +30,7 @@ import comfy.ldm.hydit.models
import comfy.ldm.audio.dit
import comfy.ldm.audio.embedders
import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.model_management
import comfy.conds
@@ -710,6 +711,38 @@ class Flux(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
def concat_cond(self, **kwargs):
num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
out_channels = self.model_config.unet_config["out_channels"]
if num_channels <= out_channels:
return None
image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
device = kwargs["device"]
if image is None:
image = torch.zeros_like(noise)
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
image = utils.resize_to_batch_size(image, noise.shape[0])
image = self.process_latent_in(image)
if num_channels <= out_channels * 2:
return image
#inpaint model
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
mask = torch.ones_like(noise)[:, :1]
mask = torch.mean(mask, dim=1, keepdim=True)
print(mask.shape)
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return torch.cat((image, mask), dim=1)
def encode_adm(self, **kwargs):
return kwargs["pooled_output"]
@@ -735,3 +768,23 @@ class GenmoMochi(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class LTXV(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lightricks.model.LTXVModel) #TODO
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
guiding_latent = kwargs.get("guiding_latent", None)
if guiding_latent is not None:
out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
return out

View File

@@ -137,6 +137,12 @@ def detect_unet_config(state_dict, key_prefix):
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
patch_size = 2
dit_config["patch_size"] = patch_size
in_key = "{}img_in.weight".format(key_prefix)
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
@@ -177,6 +183,10 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["rope_theta"] = 10000.0
return dit_config
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
return dit_config
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None

View File

@@ -628,6 +628,10 @@ def maximum_vram_for_weights(device=None):
def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
return torch.float32
if args.fp64_unet:
return torch.float64
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:
@@ -674,7 +678,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
# None means no manual cast
def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
if weight_dtype == torch.float32:
if weight_dtype == torch.float32 or weight_dtype == torch.float64:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)

View File

@@ -367,20 +367,33 @@ class ModelPatcher:
else:
set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
def _load_list(self):
loading = []
for n, m in self.model.named_modules():
params = []
skip = False
for name, param in m.named_parameters(recurse=False):
params.append(name)
for name, param in m.named_parameters(recurse=True):
if name not in params:
skip = True # skip random weights in non leaf modules
break
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
loading.append((comfy.model_management.module_size(m), n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
loading = []
for n, m in self.model.named_modules():
if hasattr(m, "comfy_cast_weights") or hasattr(m, "weight"):
loading.append((comfy.model_management.module_size(m), n, m))
loading = self._load_list()
load_completely = []
loading.sort(reverse=True)
for x in loading:
n = x[1]
m = x[2]
params = x[3]
module_mem = x[0]
lowvram_weight = False
@@ -416,22 +429,22 @@ class ModelPatcher:
if m.comfy_cast_weights:
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m))
load_completely.append((module_mem, n, m, params))
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
params = x[3]
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
for param in params:
self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
@@ -505,14 +518,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
unload_list = []
for n, m in self.model.named_modules():
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = comfy.model_management.module_size(m)
unload_list.append((module_mem, n, m))
unload_list = self._load_list()
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
@@ -520,32 +526,42 @@ class ModelPatcher:
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
params = unload[3]
lowvram_possible = hasattr(m, "comfy_cast_weights")
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
move_weight = True
for param in params:
key = "{}.{}".format(n, param)
bk = self.backup.get(key, None)
if bk is not None:
if not lowvram_possible:
move_weight = False
break
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
comfy.utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
m.to(device_to)
if lowvram_possible:
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
self.model.model_lowvram = True
self.model.lowvram_patch_counter += patch_counter

View File

@@ -8,6 +8,7 @@ from .ldm.cascade.stage_a import StageA
from .ldm.cascade.stage_c_coder import StageC_coder
from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import yaml
import comfy.utils
@@ -27,12 +28,16 @@ import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.long_clipl
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.model_patcher
import comfy.lora
import comfy.lora_convert
import comfy.t2i_adapter.adapter
import comfy.taesd.taesd
import comfy.ldm.flux.redux
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
key_map = {}
if model is not None:
@@ -40,6 +45,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
if clip is not None:
key_map = comfy.lora.model_lora_keys_clip(clip.cond_stage_model, key_map)
lora = comfy.lora_convert.convert_lora(lora)
loaded = comfy.lora.load_lora(lora, key_map)
if model is not None:
new_modelpatcher = model.clone()
@@ -257,6 +263,14 @@ class VAE:
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
self.working_dtypes = [torch.float16, torch.float32]
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: #lightricks ltxv
self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE()
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -356,7 +370,9 @@ class VAE:
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
pixel_samples = self.decode_tiled_3d(samples_in)
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
@@ -420,6 +436,12 @@ class VAE:
def get_sd(self):
return self.first_stage_model.state_dict()
def spacial_compression_decode(self):
try:
return self.upscale_ratio[-1]
except:
return self.upscale_ratio
class StyleModel:
def __init__(self, model, device="cpu"):
self.model = model
@@ -433,6 +455,8 @@ def load_style_model(ckpt_path):
keys = model_data.keys()
if "style_embedding" in keys:
model = comfy.t2i_adapter.adapter.StyleAdapter(width=1024, context_dim=768, num_head=8, n_layes=3, num_token=8)
elif "redux_down.weight" in keys:
model = comfy.ldm.flux.redux.ReduxImageEncoder()
else:
raise Exception("invalid style model {}".format(ckpt_path))
model.load_state_dict(model_data)
@@ -446,6 +470,7 @@ class CLIPType(Enum):
HUNYUAN_DIT = 5
FLUX = 6
MOCHI = 7
LTXV = 8
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
@@ -524,6 +549,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=False, t5=True, **t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer

View File

@@ -11,6 +11,7 @@ import comfy.text_encoders.aura_t5
import comfy.text_encoders.hydit
import comfy.text_encoders.flux
import comfy.text_encoders.genmo
import comfy.text_encoders.lt
from . import supported_models_base
from . import latent_formats
@@ -658,6 +659,15 @@ class Flux(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
class FluxInpaint(Flux):
unet_config = {
"image_model": "flux",
"guidance_embed": True,
"in_channels": 96,
}
supported_inference_dtypes = [torch.bfloat16, torch.float32]
class FluxSchnell(Flux):
unet_config = {
"image_model": "flux",
@@ -702,7 +712,34 @@ class GenmoMochi(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.genmo.MochiT5Tokenizer, comfy.text_encoders.genmo.mochi_te(**t5_detect))
class LTXV(supported_models_base.BASE):
unet_config = {
"image_model": "ltxv",
}
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi]
sampling_settings = {
"shift": 2.37,
}
unet_extra_config = {}
latent_format = latent_formats.LTXV
memory_usage_factor = 2.7
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
models += [SVD_img2vid]

18
comfy/text_encoders/lt.py Normal file
View File

@@ -0,0 +1,18 @@
from comfy import sd1_clip
import os
from transformers import T5TokenizerFast
import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)

181
comfy_extras/nodes_lt.py Normal file
View File

@@ -0,0 +1,181 @@
import nodes
import node_helpers
import torch
import comfy.model_management
import comfy.model_sampling
import math
class EmptyLTXVLatentVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("LATENT",)
FUNCTION = "generate"
CATEGORY = "latent/video/ltxv"
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
return ({"samples": latent}, )
class LTXVImgToVideo:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"vae": ("VAE",),
"image": ("IMAGE",),
"width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
"length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "LATENT")
RETURN_NAMES = ("positive", "negative", "latent")
CATEGORY = "conditioning/video_models"
FUNCTION = "generate"
def generate(self, positive, negative, image, vae, width, height, length, batch_size):
pixels = comfy.utils.common_upscale(image.movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
encode_pixels = pixels[:, :, :, :3]
t = vae.encode(encode_pixels)
positive = node_helpers.conditioning_set_values(positive, {"guiding_latent": t})
negative = node_helpers.conditioning_set_values(negative, {"guiding_latent": t})
latent = torch.zeros([batch_size, 128, ((length - 1) // 8) + 1, height // 32, width // 32], device=comfy.model_management.intermediate_device())
latent[:, :, :t.shape[2]] = t
return (positive, negative, {"samples": latent}, )
class LTXVConditioning:
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
"frame_rate": ("FLOAT", {"default": 25.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
}}
RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
RETURN_NAMES = ("positive", "negative")
FUNCTION = "append"
CATEGORY = "conditioning/video_models"
def append(self, positive, negative, frame_rate):
positive = node_helpers.conditioning_set_values(positive, {"frame_rate": frame_rate})
negative = node_helpers.conditioning_set_values(negative, {"frame_rate": frame_rate})
return (positive, negative)
class ModelSamplingLTXV:
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
},
"optional": {"latent": ("LATENT",), }
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def patch(self, model, max_shift, base_shift, latent=None):
m = model.clone()
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
shift = (tokens) * mm + b
sampling_base = comfy.model_sampling.ModelSamplingFlux
sampling_type = comfy.model_sampling.CONST
class ModelSamplingAdvanced(sampling_base, sampling_type):
pass
model_sampling = ModelSamplingAdvanced(model.model.model_config)
model_sampling.set_parameters(shift=shift)
m.add_object_patch("model_sampling", model_sampling)
return (m, )
class LTXVScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"max_shift": ("FLOAT", {"default": 2.05, "min": 0.0, "max": 100.0, "step":0.01}),
"base_shift": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 100.0, "step":0.01}),
"stretch": ("BOOLEAN", {
"default": True,
"tooltip": "Stretch the sigmas to be in the range [terminal, 1]."
}),
"terminal": (
"FLOAT",
{
"default": 0.1, "min": 0.0, "max": 0.99, "step": 0.01,
"tooltip": "The terminal value of the sigmas after stretching."
},
),
},
"optional": {"latent": ("LATENT",), }
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, max_shift, base_shift, stretch, terminal, latent=None):
if latent is None:
tokens = 4096
else:
tokens = math.prod(latent["samples"].shape[2:])
sigmas = torch.linspace(1.0, 0.0, steps + 1)
x1 = 1024
x2 = 4096
mm = (max_shift - base_shift) / (x2 - x1)
b = base_shift - mm * x1
sigma_shift = (tokens) * mm + b
power = 1
sigmas = torch.where(
sigmas != 0,
math.exp(sigma_shift) / (math.exp(sigma_shift) + (1 / sigmas - 1) ** power),
0,
)
# Stretch sigmas so that its final value matches the given terminal value.
if stretch:
non_zero_mask = sigmas != 0
non_zero_sigmas = sigmas[non_zero_mask]
one_minus_z = 1.0 - non_zero_sigmas
scale_factor = one_minus_z[-1] / (1.0 - terminal)
stretched = 1.0 - (one_minus_z / scale_factor)
sigmas[non_zero_mask] = stretched
return (sigmas,)
NODE_CLASS_MAPPINGS = {
"EmptyLTXVLatentVideo": EmptyLTXVLatentVideo,
"LTXVImgToVideo": LTXVImgToVideo,
"ModelSamplingLTXV": ModelSamplingLTXV,
"LTXVConditioning": LTXVConditioning,
"LTXVScheduler": LTXVScheduler,
}

View File

@@ -301,7 +301,8 @@ class VAEDecodeTiled:
def decode(self, vae, samples, tile_size, overlap=64):
if tile_size < overlap * 4:
overlap = tile_size // 4
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
compression = vae.spacial_compression_decode()
images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
return (images, )
@@ -897,7 +898,7 @@ class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi"], ),
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
}}
RETURN_TYPES = ("CLIP",)
FUNCTION = "load_clip"
@@ -915,6 +916,8 @@ class CLIPLoader:
clip_type = comfy.sd.CLIPType.STABLE_AUDIO
elif type == "mochi":
clip_type = comfy.sd.CLIPType.MOCHI
elif type == "ltxv":
clip_type = comfy.sd.CLIPType.LTXV
else:
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
@@ -2136,6 +2139,7 @@ def init_builtin_extra_nodes():
"nodes_torch_compile.py",
"nodes_mochi.py",
"nodes_slg.py",
"nodes_lt.py",
]
import_failed = []