Compare commits

11 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 4c82741b54 | |
| | 15c39ea757 | |
| | b7143b74ce | |
| | 61196d8857 | |
| | b4526d3fc3 | |
| | 3d802710e7 | |
| | 7126ecffde | |
| | ab885b33ba | |
| | 839ed3368e | |
| | 6e8cdcd3cb | |
| | e5c3f4b87f | |
README.md (58 lines changed)
```diff
@@ -75,37 +75,37 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
 | Keybind | Explanation |
 |------------------------------------|--------------------------------------------------------------------------------------------------------------------|
-| Ctrl + Enter | Queue up current graph for generation |
+| `Ctrl` + `Enter` | Queue up current graph for generation |
-| Ctrl + Shift + Enter | Queue up current graph as first for generation |
+| `Ctrl` + `Shift` + `Enter` | Queue up current graph as first for generation |
-| Ctrl + Alt + Enter | Cancel current generation |
+| `Ctrl` + `Alt` + `Enter` | Cancel current generation |
-| Ctrl + Z/Ctrl + Y | Undo/Redo |
+| `Ctrl` + `Z`/`Ctrl` + `Y` | Undo/Redo |
-| Ctrl + S | Save workflow |
+| `Ctrl` + `S` | Save workflow |
-| Ctrl + O | Load workflow |
+| `Ctrl` + `O` | Load workflow |
-| Ctrl + A | Select all nodes |
+| `Ctrl` + `A` | Select all nodes |
-| Alt + C | Collapse/uncollapse selected nodes |
+| `Alt` + `C` | Collapse/uncollapse selected nodes |
-| Ctrl + M | Mute/unmute selected nodes |
+| `Ctrl` + `M` | Mute/unmute selected nodes |
-| Ctrl + B | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
+| `Ctrl` + `B` | Bypass selected nodes (acts like the node was removed from the graph and the wires reconnected through) |
-| Delete/Backspace | Delete selected nodes |
+| `Delete`/`Backspace` | Delete selected nodes |
-| Ctrl + Backspace | Delete the current graph |
+| `Ctrl` + `Backspace` | Delete the current graph |
-| Space | Move the canvas around when held and moving the cursor |
+| `Space` | Move the canvas around when held and moving the cursor |
-| Ctrl/Shift + Click | Add clicked node to selection |
+| `Ctrl`/`Shift` + `Click` | Add clicked node to selection |
-| Ctrl + C/Ctrl + V | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
+| `Ctrl` + `C`/`Ctrl` + `V` | Copy and paste selected nodes (without maintaining connections to outputs of unselected nodes) |
-| Ctrl + C/Ctrl + Shift + V | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
+| `Ctrl` + `C`/`Ctrl` + `Shift` + `V` | Copy and paste selected nodes (maintaining connections from outputs of unselected nodes to inputs of pasted nodes) |
-| Shift + Drag | Move multiple selected nodes at the same time |
+| `Shift` + `Drag` | Move multiple selected nodes at the same time |
-| Ctrl + D | Load default graph |
+| `Ctrl` + `D` | Load default graph |
-| Alt + `+` | Canvas Zoom in |
+| `Alt` + `+` | Canvas Zoom in |
-| Alt + `-` | Canvas Zoom out |
+| `Alt` + `-` | Canvas Zoom out |
-| Ctrl + Shift + LMB + Vertical drag | Canvas Zoom in/out |
+| `Ctrl` + `Shift` + LMB + Vertical drag | Canvas Zoom in/out |
-| P | Pin/Unpin selected nodes |
+| `P` | Pin/Unpin selected nodes |
-| Ctrl + G | Group selected nodes |
+| `Ctrl` + `G` | Group selected nodes |
-| Q | Toggle visibility of the queue |
+| `Q` | Toggle visibility of the queue |
-| H | Toggle visibility of history |
+| `H` | Toggle visibility of history |
-| R | Refresh graph |
+| `R` | Refresh graph |
 | Double-Click LMB | Open node quick search palette |
-| Shift + Drag | Move multiple wires at once |
+| `Shift` + Drag | Move multiple wires at once |
-| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
+| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
 
-Ctrl can also be replaced with Cmd instead for macOS users
+`Ctrl` can also be replaced with `Cmd` instead for macOS users
 
 # Installing
```
comfy/cldm/dit_embedder.py (new file, 122 lines)

@@ -0,0 +1,122 @@
```python
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch


class ControlNetEmbedder(nn.Module):

    def __init__(
        self,
        img_size: int,
        patch_size: int,
        in_chans: int,
        attention_head_dim: int,
        num_attention_heads: int,
        adm_in_channels: int,
        num_layers: int,
        main_model_double: int,
        double_y_emb: bool,
        device: torch.device,
        dtype: torch.dtype,
        pos_embed_max_size: Optional[int] = None,
        operations = None,
    ):
        super().__init__()
        self.main_model_double = main_model_double
        self.dtype = dtype
        self.hidden_size = num_attention_heads * attention_head_dim
        self.patch_size = patch_size
        self.x_embedder = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=pos_embed_max_size is None,
            device=device,
            dtype=dtype,
            operations=operations,
        )

        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)

        self.double_y_emb = double_y_emb
        if self.double_y_emb:
            self.orig_y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )
            self.y_embedder = VectorEmbedder(
                self.hidden_size, self.hidden_size, dtype, device, operations=operations
            )
        else:
            self.y_embedder = VectorEmbedder(
                adm_in_channels, self.hidden_size, dtype, device, operations=operations
            )

        self.transformer_blocks = nn.ModuleList(
            DismantledBlock(
                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
                dtype=dtype, device=device, operations=operations
            )
            for _ in range(num_layers)
        )

        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
        # TODO double check this logic when 8b
        self.use_y_embedder = True

        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            self.controlnet_blocks.append(controlnet_block)

        self.pos_embed_input = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=self.hidden_size,
            strict_img_size=False,
            device=device,
            dtype=dtype,
            operations=operations,
        )

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        hint = None,
    ) -> Tuple[Tensor, List[Tensor]]:
        x_shape = list(x.shape)
        x = self.x_embedder(x)
        if not self.double_y_emb:
            h = (x_shape[-2] + 1) // self.patch_size
            w = (x_shape[-1] + 1) // self.patch_size
            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
        c = self.t_embedder(timesteps, dtype=x.dtype)
        if y is not None and self.y_embedder is not None:
            if self.double_y_emb:
                y = self.orig_y_embedder(y)
            y = self.y_embedder(y)
            c = c + y

        x = x + self.pos_embed_input(hint)

        block_out = ()

        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
        for i in range(len(self.transformer_blocks)):
            out = self.transformer_blocks[i](x, c)
            if not self.double_y_emb:
                x = out
            block_out += (self.controlnet_blocks[i](out),) * repeat

        return {"output": block_out}
```
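Two details of the new embedder are easy to miss: `forward()` actually returns `{"output": block_out}` rather than the `Tuple[Tensor, List[Tensor]]` its annotation declares, and `block_out` is stretched so a shallow ControlNet can drive a deeper main model. A runnable sketch of that stretching, with hypothetical block counts:

```python
import math

# Assumed (hypothetical) sizes: a main MMDiT with 38 double blocks driven
# by a 19-block ControlNet. Each controlnet residual is emitted `repeat`
# times so every double block of the main model receives an entry.
main_model_double = 38          # assumed depth of the main model
num_controlnet_blocks = 19      # assumed depth of the controlnet

repeat = math.ceil(main_model_double / num_controlnet_blocks)  # -> 2
block_out = ()
for i in range(num_controlnet_blocks):
    block_out += (f"residual_{i}",) * repeat  # stand-ins for tensors

assert len(block_out) == 38  # one residual per main-model double block
```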
comfy/cli_args.py

```diff
@@ -60,8 +60,10 @@ fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
 fpunet_group = parser.add_mutually_exclusive_group()
-fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
-fpunet_group.add_argument("--fp16-unet", action="store_true", help="Store unet weights in fp16.")
+fpunet_group.add_argument("--fp32-unet", action="store_true", help="Run the diffusion model in fp32.")
+fpunet_group.add_argument("--fp64-unet", action="store_true", help="Run the diffusion model in fp64.")
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diffusion model in bf16.")
+fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16.")
 fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
 fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
```
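Because these flags live in a mutually exclusive group, at most one unet dtype override can be passed per invocation. A minimal standalone sketch of the pattern (flag names from the diff, everything else assumed):

```python
import argparse

# Flags in a mutually exclusive group cannot be combined: passing both
# --fp32-unet and --fp64-unet makes argparse exit with a usage error.
parser = argparse.ArgumentParser()
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--fp32-unet", action="store_true")
fpunet_group.add_argument("--fp64-unet", action="store_true")
fpunet_group.add_argument("--bf16-unet", action="store_true")

args = parser.parse_args(["--fp64-unet"])
assert args.fp64_unet and not args.fp32_unet
```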
comfy/controlnet.py

```diff
@@ -35,7 +35,7 @@ import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
+import comfy.cldm.dit_embedder
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -78,6 +78,7 @@ class ControlBase:
         self.concat_mask = False
         self.extra_concat_orig = []
         self.extra_concat = None
+        self.preprocess_image = lambda a: a
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
@@ -129,6 +130,7 @@ class ControlBase:
         c.strength_type = self.strength_type
         c.concat_mask = self.concat_mask
         c.extra_concat_orig = self.extra_concat_orig.copy()
+        c.preprocess_image = self.preprocess_image
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -181,7 +183,7 @@ class ControlBase:
 
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
         super().__init__()
         self.control_model = control_model
         self.load_device = load_device
@@ -196,6 +198,7 @@ class ControlNet(ControlBase):
         self.extra_conds += extra_conds
         self.strength_type = strength_type
         self.concat_mask = concat_mask
+        self.preprocess_image = preprocess_image
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -224,6 +227,7 @@ class ControlNet(ControlBase):
                 if self.latent_format is not None:
                     raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            self.cond_hint = self.preprocess_image(self.cond_hint)
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                 self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))
@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
     logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model
 
+
 def load_controlnet_mmdit(sd, model_options={}):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
@@ -448,6 +453,83 @@ def load_controlnet_mmdit(sd, model_options={}):
     return control
 
 
+class ControlNetSD35(ControlNet):
+    def pre_run(self, model, percent_to_timestep_function):
+        if self.control_model.double_y_emb:
+            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+        else:
+            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+        super().pre_run(model, percent_to_timestep_function)
+
+    def copy(self):
+        c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+def load_controlnet_sd35(sd, model_options={}):
+    control_type = -1
+    if "control_type" in sd:
+        control_type = round(sd.pop("control_type").item())
+
+    # blur_cnet = control_type == 0
+    canny_cnet = control_type == 1
+    depth_cnet = control_type == 2
+
+    print(control_type, canny_cnet, depth_cnet)
+    new_sd = {}
+    for k in comfy.utils.MMDIT_MAP_BASIC:
+        if k[1] in sd:
+            new_sd[k[0]] = sd.pop(k[1])
+    for k in sd:
+        new_sd[k] = sd[k]
+    sd = new_sd
+
+    y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
+    depth = y_emb_shape[0] // 64
+    hidden_size = 64 * depth
+    num_heads = depth
+    head_dim = hidden_size // num_heads
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
+
+    load_device = comfy.model_management.get_torch_device()
+    offload_device = comfy.model_management.unet_offload_device()
+    unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
+
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
+                                                               patch_size=2,
+                                                               in_chans=16,
+                                                               num_layers=num_blocks,
+                                                               main_model_double=depth,
+                                                               double_y_emb=y_emb_shape[0] == y_emb_shape[1],
+                                                               attention_head_dim=head_dim,
+                                                               num_attention_heads=num_heads,
+                                                               adm_in_channels=2048,
+                                                               device=offload_device,
+                                                               dtype=unet_dtype,
+                                                               operations=operations)
+
+    control_model = controlnet_load_state_dict(control_model, sd)
+
+    latent_format = comfy.latent_formats.SD3()
+    preprocess_image = lambda a: a
+    if canny_cnet:
+        preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
+    elif depth_cnet:
+        preprocess_image = lambda a: 1.0 - a
+
+    control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
+    return control
+
+
 def load_controlnet_hunyuandit(controlnet_data, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
 
@@ -560,7 +642,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
         return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
     elif "pos_embed_input.proj.weight" in controlnet_data:
-        return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+        if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
+            return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
+        else:
+            return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
     elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
```
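The new dispatch hinges on a single extra state-dict probe. A self-contained sketch of the decision order (strings stand in for the real loader calls):

```python
# The SD3.5 format is distinguished from diffusers SD3 controlnets by the
# presence of an adaLN modulation key, per the last hunk above.
def pick_loader(controlnet_data):
    if "pos_embed_input.proj.weight" in controlnet_data:
        if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
            return "load_controlnet_sd35"   # Stability SD3.5 format
        return "load_controlnet_mmdit"      # SD3 diffusers format
    return "other"

assert pick_loader({"pos_embed_input.proj.weight": 0,
                    "transformer_blocks.0.adaLN_modulation.1.bias": 0}) == "load_controlnet_sd35"
```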
comfy/ldm/hydit/models.py

```diff
@@ -287,7 +287,7 @@ class HunYuanDiT(nn.Module):
         style=None,
         return_dict=False,
         control=None,
-        transformer_options=None,
+        transformer_options={},
     ):
         """
         Forward pass of the encoder.
@@ -315,8 +315,7 @@ class HunYuanDiT(nn.Module):
         return_dict: bool
             Whether to return a dictionary.
         """
-        #import pdb
-        #pdb.set_trace()
+        patches_replace = transformer_options.get("patches_replace", {})
         encoder_hidden_states = context
         text_states = encoder_hidden_states # 2,77,1024
         text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@ class HunYuanDiT(nn.Module):
         # Concatenate all extra vectors
         c = t + self.extra_embedder(extra_vec)  # [B, D]
 
+        blocks_replace = patches_replace.get("dit", {})
+
         controls = None
         if control:
             controls = control.get("output", None)
@@ -375,9 +376,20 @@ class HunYuanDiT(nn.Module):
                     skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
-                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
             else:
-                x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
+                skip = None
+
+            if ("double_block", layer) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
+                    return out
+
+                out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
+
             if layer < (self.depth // 2 - 1):
                 skips.append(x)
```
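This diff and the LTXV one below wire the same "patches_replace" contract into their block loops: a patch registered under `("double_block", i)` receives the block's tensors as a dict plus an `original_block` callback. A minimal sketch of what a patch looks like from the caller's side (dict keys match the diffs; the patch itself is hypothetical):

```python
# A patch may run logic around the wrapped block and must return the dict
# produced by the `original_block` callback (optionally modified).
def my_patch(args, extra):
    out = extra["original_block"](args)  # run the original double block
    # e.g. inspect or rescale out["img"] here before returning it
    return out

transformer_options = {
    "patches_replace": {"dit": {("double_block", 0): my_patch}},
}
# A sampler or custom node threads transformer_options into forward();
# blocks without a registered patch run unchanged.
```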
comfy/ldm/lightricks/model.py

```diff
@@ -304,7 +304,7 @@ class BasicTransformerBlock(nn.Module):
         self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
 
     def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None):
-        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None] + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
 
         x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe) * gate_msa
@@ -415,13 +415,15 @@ class LTXVModel(torch.nn.Module):
 
         self.patchifier = SymmetricPatchifier(1)
 
-    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, **kwargs):
+    def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, transformer_options={}, **kwargs):
+        patches_replace = transformer_options.get("patches_replace", {})
+
         indices_grid = self.patchifier.get_grid(
             orig_num_frames=x.shape[2],
             orig_height=x.shape[3],
             orig_width=x.shape[4],
             batch_size=x.shape[0],
-            scale_grid=((1 / frame_rate) * 8, 32, 32), #TODO: controlable frame rate
+            scale_grid=((1 / frame_rate) * 8, 32, 32),
             device=x.device,
         )
 
@@ -468,18 +470,28 @@ class LTXVModel(torch.nn.Module):
             batch_size, -1, x.shape[-1]
         )
 
-        for block in self.transformer_blocks:
-            x = block(
-                x,
-                context=context,
-                attention_mask=attention_mask,
-                timestep=timestep,
-                pe=pe
-            )
+        blocks_replace = patches_replace.get("dit", {})
+        for i, block in enumerate(self.transformer_blocks):
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe}, {"original_block": block_wrap})
+                x = out["img"]
+            else:
+                x = block(
+                    x,
+                    context=context,
+                    attention_mask=attention_mask,
+                    timestep=timestep,
+                    pe=pe
+                )
 
         # 3. Output
         scale_shift_values = (
-            self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
+            self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + embedded_timestep[:, :, None]
         )
         shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
         x = self.norm_out(x)
```
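The two `.to(device=x.device, dtype=x.dtype)` casts above are not cosmetic: under manual-cast operation modes the `scale_shift_table` can sit in fp32 while activations run in fp16/bf16, and without the cast the addition silently promotes the result to fp32 (or raises outright on a device mismatch). A small demonstration:

```python
import torch

table = torch.zeros(6, 8, dtype=torch.float32)      # parameter kept in fp32
x = torch.zeros(2, 6, 8, dtype=torch.float16)       # half-precision activations

assert (table + x).dtype == torch.float32            # unintended promotion
assert (table.to(x.dtype) + x).dtype == torch.float16  # after the cast
```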
comfy/ldm/lightricks/vae/causal_conv3d.py

```diff
@@ -2,6 +2,8 @@ from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 
 class CausalConv3d(nn.Module):
@@ -29,7 +31,7 @@ class CausalConv3d(nn.Module):
         width_pad = kernel_size[2] // 2
         padding = (0, height_pad, width_pad)
 
-        self.conv = nn.Conv3d(
+        self.conv = ops.Conv3d(
             in_channels,
             out_channels,
             kernel_size,
```
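`comfy.ops.disable_weight_init` supplies drop-in replacements for the stock layers; a sketch of the idea, assumed from the pattern in comfy/ops.py:

```python
import torch.nn as nn

# Layer subclasses whose reset_parameters() is a no-op: building the module
# skips random weight initialization that a checkpoint load would overwrite
# anyway, and comfy can attach its weight-casting hooks to these layers.
class Conv3d(nn.Conv3d):
    def reset_parameters(self):
        return None  # weights are expected to come from a state dict
```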
comfy/ldm/lightricks/vae/causal_video_autoencoder.py

```diff
@@ -628,10 +628,10 @@ class processor(nn.Module):
         self.register_buffer("channel", torch.empty(128))
 
     def un_normalize(self, x):
-        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)
+        return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
 
     def normalize(self, x):
-        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1)
+        return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
 
 class VideoVAE(nn.Module):
     def __init__(self):
```
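The `.to(x)` shorthand here is the tensor overload of `Tensor.to`: passing a tensor matches both its dtype and device in one call, which is what these fp32 normalization buffers need when decoding runs in bf16/fp16:

```python
import torch

buf = torch.zeros(128, dtype=torch.float32)       # registered buffer
x = torch.zeros(2, 128, dtype=torch.bfloat16)     # decode activations

assert buf.to(x).dtype == torch.bfloat16          # dtype matched
assert buf.to(x).device == x.device               # device matched
```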
comfy/ldm/lightricks/vae/conv_nd_factory.py

```diff
@@ -4,7 +4,8 @@ import torch
 
 from .dual_conv3d import DualConv3d
 from .causal_conv3d import CausalConv3d
+import comfy.ops
+ops = comfy.ops.disable_weight_init
 
 def make_conv_nd(
     dims: Union[int, Tuple[int, int]],
@@ -19,7 +20,7 @@ def make_conv_nd(
     causal=False,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -41,7 +42,7 @@ def make_conv_nd(
             groups=groups,
             bias=bias,
         )
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
@@ -71,11 +72,11 @@ def make_linear_nd(
     bias=True,
 ):
     if dims == 2:
-        return torch.nn.Conv2d(
+        return ops.Conv2d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     elif dims == 3 or dims == (2, 1):
-        return torch.nn.Conv3d(
+        return ops.Conv3d(
             in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
         )
     else:
```
comfy/lora.py

```diff
@@ -62,6 +62,7 @@ def load_lora(lora, to_load):
         diffusers_lora = "{}_lora.up.weight".format(x)
         diffusers2_lora = "{}.lora_B.weight".format(x)
         diffusers3_lora = "{}.lora.up.weight".format(x)
+        mochi_lora = "{}.lora_B".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
         A_name = None
 
@@ -81,6 +82,10 @@ def load_lora(lora, to_load):
             A_name = diffusers3_lora
             B_name = "{}.lora.down.weight".format(x)
             mid_name = None
+        elif mochi_lora in lora.keys():
+            A_name = mochi_lora
+            B_name = "{}.lora_A".format(x)
+            mid_name = None
        elif transformers_lora in lora.keys():
             A_name = transformers_lora
             B_name ="{}.lora_linear_layer.down.weight".format(x)
@@ -362,6 +367,12 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
             key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
 
+    if isinstance(model, comfy.model_base.GenmoMochi):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["{}".format(key_lora)] = k
+
     return key_map
```
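Note the naming quirk: the Mochi format's `lora_B` (the up-projection) is assigned to `A_name`, matching comfy's internal (up, down) ordering rather than peft's A/B convention. A sketch of the key derivation with a hypothetical key:

```python
# Official Mochi loras store "<key>.lora_A"/"<key>.lora_B" with no
# ".weight" suffix; the unet mapping strips the "diffusion_model." prefix
# so those names resolve directly against the lora's keys.
k = "diffusion_model.blocks.0.attn.qkv.weight"       # assumed model key
key_lora = k[len("diffusion_model."):-len(".weight")]
assert key_lora == "blocks.0.attn.qkv"
# load_lora() then looks for "blocks.0.attn.qkv.lora_B" / ".lora_A".
```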
comfy/model_management.py

```diff
@@ -628,6 +628,10 @@ def maximum_vram_for_weights(device=None):
 def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
     if model_params < 0:
         model_params = 1000000000000000000000
+    if args.fp32_unet:
+        return torch.float32
+    if args.fp64_unet:
+        return torch.float64
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:
@@ -674,7 +678,7 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
 
 # None means no manual cast
 def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
-    if weight_dtype == torch.float32:
+    if weight_dtype == torch.float32 or weight_dtype == torch.float64:
         return None
 
     fp16_supported = should_use_fp16(inference_device, prioritize_performance=False)
```
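A hedged sketch of the override precedence these two hunks establish (fp32 and fp64 weights now both bypass manual casting):

```python
import torch

# Explicit CLI overrides are checked in order: fp32, fp64, bf16, fp16;
# with none set, unet_dtype() falls through to automatic selection.
def forced_unet_dtype(args):
    order = [("fp32_unet", torch.float32), ("fp64_unet", torch.float64),
             ("bf16_unet", torch.bfloat16), ("fp16_unet", torch.float16)]
    for flag, dtype in order:
        if getattr(args, flag, False):
            return dtype
    return None  # no override: automatic selection

class _Args: fp64_unet = True  # hypothetical parsed args
assert forced_unet_dtype(_Args()) == torch.float64
```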
comfy/model_patcher.py

```diff
@@ -367,10 +367,7 @@ class ModelPatcher:
         else:
             set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key))
 
-    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
-        mem_counter = 0
-        patch_counter = 0
-        lowvram_counter = 0
+    def _load_list(self):
         loading = []
         for n, m in self.model.named_modules():
             params = []
@@ -383,6 +380,13 @@ class ModelPatcher:
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
                 loading.append((comfy.model_management.module_size(m), n, m, params))
+        return loading
+
+    def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
+        mem_counter = 0
+        patch_counter = 0
+        lowvram_counter = 0
+        loading = self._load_list()
 
         load_completely = []
         loading.sort(reverse=True)
@@ -514,14 +518,7 @@ class ModelPatcher:
     def partially_unload(self, device_to, memory_to_free=0):
         memory_freed = 0
         patch_counter = 0
-        unload_list = []
-
-        for n, m in self.model.named_modules():
-            shift_lowvram = False
-            if hasattr(m, "comfy_cast_weights"):
-                module_mem = comfy.model_management.module_size(m)
-                unload_list.append((module_mem, n, m))
+        unload_list = self._load_list()
 
         unload_list.sort()
         for unload in unload_list:
             if memory_to_free < memory_freed:
@@ -529,32 +526,42 @@ class ModelPatcher:
             module_mem = unload[0]
             n = unload[1]
             m = unload[2]
-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
+            params = unload[3]
 
+            lowvram_possible = hasattr(m, "comfy_cast_weights")
             if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
-                for key in [weight_key, bias_key]:
+                move_weight = True
+                for param in params:
+                    key = "{}.{}".format(n, param)
                     bk = self.backup.get(key, None)
                     if bk is not None:
+                        if not lowvram_possible:
+                            move_weight = False
+                            break
+
                         if bk.inplace_update:
                             comfy.utils.copy_to_param(self.model, key, bk.weight)
                         else:
                             comfy.utils.set_attr_param(self.model, key, bk.weight)
                         self.backup.pop(key)
 
-                m.to(device_to)
-                if weight_key in self.patches:
-                    m.weight_function = LowVramPatch(weight_key, self.patches)
-                    patch_counter += 1
-                if bias_key in self.patches:
-                    m.bias_function = LowVramPatch(bias_key, self.patches)
-                    patch_counter += 1
+                weight_key = "{}.weight".format(n)
+                bias_key = "{}.bias".format(n)
+                if move_weight:
+                    m.to(device_to)
+                    if lowvram_possible:
+                        if weight_key in self.patches:
+                            m.weight_function = LowVramPatch(weight_key, self.patches)
+                            patch_counter += 1
+                        if bias_key in self.patches:
+                            m.bias_function = LowVramPatch(bias_key, self.patches)
+                            patch_counter += 1
 
-                m.prev_comfy_cast_weights = m.comfy_cast_weights
-                m.comfy_cast_weights = True
-                m.comfy_patched_weights = False
-                memory_freed += module_mem
-                logging.debug("freed {}".format(n))
+                    m.prev_comfy_cast_weights = m.comfy_cast_weights
+                    m.comfy_cast_weights = True
+                    m.comfy_patched_weights = False
+                    memory_freed += module_mem
+                    logging.debug("freed {}".format(n))
 
         self.model.model_lowvram = True
         self.model.lowvram_patch_counter += patch_counter
```
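The refactor extracts one module enumeration shared by `load()` and `partially_unload()`, and the tuples now carry parameter names so unloading can restore per-parameter backups precisely. A sketch of the shared list with hypothetical entries:

```python
# Tuples of (module_size, name, module, param_names). load() sorts
# descending to place the biggest modules first; partially_unload()
# sorts ascending, freeing the smallest modules first until enough
# memory has been released.
loading = [
    (300_000, "blocks.0", "module0", ["weight", "bias"]),
    (120_000, "blocks.1", "module1", ["weight"]),
]
loading.sort(reverse=True)      # load order: largest first
unload_list = sorted(loading)   # unload order: smallest first
```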
comfy/sd.py (12 lines changed)

```diff
@@ -269,7 +269,7 @@ class VAE:
             self.latent_dim = 3
             self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-            self.upscale_ratio = 8
+            self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         else:
             logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
@@ -370,7 +370,9 @@ class VAE:
             elif dims == 2:
                 pixel_samples = self.decode_tiled_(samples_in)
             elif dims == 3:
-                pixel_samples = self.decode_tiled_3d(samples_in)
+                tile = 256 // self.spacial_compression_decode()
+                overlap = tile // 4
+                pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
 
         pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
         return pixel_samples
@@ -434,6 +436,12 @@ class VAE:
     def get_sd(self):
         return self.first_stage_model.state_dict()
 
+    def spacial_compression_decode(self):
+        try:
+            return self.upscale_ratio[-1]
+        except:
+            return self.upscale_ratio
+
 class StyleModel:
     def __init__(self, model, device="cpu"):
         self.model = model
```
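The tuple assigned to `upscale_ratio` is (temporal, height, width): `t` latent frames decode to `max(0, t*8 - 7)` video frames and each spatial latent cell to 32 pixels, and `spacial_compression_decode()` just reports that trailing 32 (falling back to the scalar ratio for image VAEs). The arithmetic:

```python
temporal = lambda a: max(0, a * 8 - 7)
assert temporal(13) == 97   # the 97-frame default of EmptyLTXVLatentVideo

tile = 256 // 32            # decode_tiled_3d tiles are 8 latent cells wide
overlap = tile // 4         # with 2 cells of overlap
assert (tile, overlap) == (8, 2)
```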
comfy/supported_models.py

```diff
@@ -659,6 +659,15 @@ class Flux(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
 
+class FluxInpaint(Flux):
+    unet_config = {
+        "image_model": "flux",
+        "guidance_embed": True,
+        "in_channels": 96,
+    }
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
 class FluxSchnell(Flux):
     unet_config = {
         "image_model": "flux",
@@ -731,6 +740,6 @@ class LTXV(supported_models_base.BASE):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
 
 models += [SVD_img2vid]
```
comfy_extras/nodes_lt.py

```diff
@@ -10,7 +10,7 @@ class EmptyLTXVLatentVideo:
     def INPUT_TYPES(s):
         return {"required": { "width": ("INT", {"default": 768, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
                               "height": ("INT", {"default": 512, "min": 64, "max": nodes.MAX_RESOLUTION, "step": 32}),
-                              "length": ("INT", {"default": 97, "min": 9, "max": nodes.MAX_RESOLUTION, "step": 8}),
+                              "length": ("INT", {"default": 97, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
                               "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})}}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "generate"
```
nodes.py (3 lines changed)

```diff
@@ -301,7 +301,8 @@ class VAEDecodeTiled:
     def decode(self, vae, samples, tile_size, overlap=64):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // 8, tile_y=tile_size // 8, overlap=overlap // 8)
+        compression = vae.spacial_compression_decode()
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
```
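With the compression queried from the VAE instead of hardcoded to 8, the same node serves 8x image VAEs and the 32x LTXV video VAE. Assumed example values:

```python
tile_size, overlap = 512, 64
for compression in (8, 32):                # SD-style vs LTXV video VAE
    tile_x = tile_size // compression      # 64 latent cells vs 16
    print(compression, tile_x, overlap // compression)
```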