Compare commits
5 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 39fb74c5bd | |
| | 74e124f4d7 | |
| | a562c17e8a | |
| | 5942c17d55 | |
| | c032b11e07 | |
comfy/controlnet.py

```diff
@@ -34,6 +34,8 @@ import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
+import comfy.ldm.flux.controlnet_xlabs
+
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]
@@ -416,6 +418,7 @@ def load_controlnet_mmdit(sd):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control
 
 
 def load_controlnet_hunyuandit(controlnet_data):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
@@ -427,6 +430,15 @@ def load_controlnet_hunyuandit(controlnet_data):
     control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds, strength_type=StrengthType.CONSTANT)
     return control
 
+
+def load_controlnet_flux_xlabs(sd):
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
+    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, sd)
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
     if 'after_proj_list.18.bias' in controlnet_data.keys(): #Hunyuan DiT
@@ -489,7 +501,10 @@ def load_controlnet(ckpt_path, model=None):
             logging.warning("leftover keys: {}".format(leftover_keys))
         controlnet_data = new_sd
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
-        return load_controlnet_mmdit(controlnet_data)
+        if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
+            return load_controlnet_flux_xlabs(controlnet_data)
+        else:
+            return load_controlnet_mmdit(controlnet_data)
 
     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
```
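For readers skimming the dispatch change above: XLabs Flux controlnets share the diffusers-style `controlnet_blocks.*` keys with SD3 MMDiT controlnets, so a key unique to Flux attention blocks is used to tell them apart. A minimal sketch of that detection order (the helper name `detect_controlnet_family` is hypothetical, not part of the diff):

```python
import comfy.utils

def detect_controlnet_family(ckpt_path):
    # Hypothetical helper; mirrors the key checks in load_controlnet above.
    sd = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
    if 'after_proj_list.18.bias' in sd:
        return "hunyuandit"                     # Hunyuan DiT controlnet
    if "controlnet_blocks.0.weight" in sd:
        # XLabs Flux controlnets also carry controlnet_blocks.*, so a key
        # unique to Flux double blocks disambiguates them from SD3 MMDiT.
        if "double_blocks.0.img_attn.norm.key_norm.scale" in sd:
            return "flux_xlabs"
        return "sd3_mmdit"
    return "other"
```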
comfy/diffusers_load.py

```diff
@@ -22,7 +22,7 @@ def load_diffusers(model_path, output_vae=True, output_clip=True, embedding_dire
     if text_encoder2_path is not None:
         text_encoder_paths.append(text_encoder2_path)
 
-    unet = comfy.sd.load_unet(unet_path)
+    unet = comfy.sd.load_diffusion_model(unet_path)
 
     clip = None
     if output_clip:
```
comfy/ldm/flux/controlnet_xlabs.py (new file, 104 lines)

```python
#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py

import torch
from torch import Tensor, nn
from einops import rearrange, repeat

from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
                     MLPEmbedder, SingleStreamBlock,
                     timestep_embedding)

from .model import Flux
import comfy.ldm.common_dit


class ControlNetFlux(Flux):
    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)

        # add ControlNet blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(self.params.depth):
            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
            # controlnet_block = zero_module(controlnet_block)
            self.controlnet_blocks.append(controlnet_block)
        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
        self.gradient_checkpointing = False
        self.input_hint_block = nn.Sequential(
            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
            nn.SiLU(),
            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
        )

    def forward_orig(
        self,
        img: Tensor,
        img_ids: Tensor,
        controlnet_cond: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
    ) -> Tensor:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        controlnet_cond = self.input_hint_block(controlnet_cond)
        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
        controlnet_cond = self.pos_embed_input(controlnet_cond)
        img = img + controlnet_cond
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        block_res_samples = ()

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
            block_res_samples = block_res_samples + (img,)

        controlnet_block_res_samples = ()
        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
            block_res_sample = controlnet_block(block_res_sample)
            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)

        return {"output": (controlnet_block_res_samples * 10)[:19]}

    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
        hint = hint * 2.0 - 1.0

        bs, c, h, w = x.shape
        patch_size = 2
        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))

        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

        h_len = ((h + (patch_size // 2)) // patch_size)
        w_len = ((w + (patch_size // 2)) // patch_size)
        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)
```
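One line worth pausing on is `return {"output": (controlnet_block_res_samples * 10)[:19]}`: the controlnet only runs `self.params.depth` double blocks, while the base Flux model has 19, so the few residuals are tiled round-robin to cover every base block. A minimal sketch of the tuple arithmetic (a depth of 2 is assumed here for illustration):

```python
# One residual per controlnet double block; depth=2 is an assumed example.
residuals = ("r0", "r1")

# Tuple repetition then truncation tiles them over Flux's 19 double blocks.
tiled = (residuals * 10)[:19]
assert len(tiled) == 19
assert tiled[:4] == ("r0", "r1", "r0", "r1")
```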
comfy/ldm/flux/layers.py

```diff
@@ -170,8 +170,8 @@ class DoubleStreamBlock(nn.Module):
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
 
         # calculate the img bloks
-        img += img_mod1.gate * self.img_attn.proj(img_attn)
-        img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
 
         # calculate the txt bloks
         txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
```
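This hunk swaps in-place adds for out-of-place ones. A likely motivation (an inference from the surrounding mixed-precision code, not stated in the diff): in-place ops cannot promote the destination dtype, so a residual in a wider dtype can raise where the out-of-place form promotes cleanly:

```python
import torch

a = torch.zeros(2, dtype=torch.float16)
b = torch.zeros(2, dtype=torch.float32)

c = a + b            # out-of-place: result promoted to float32
print(c.dtype)       # torch.float32

try:
    a += b           # in-place: float32 result cannot be stored in a float16 tensor
except RuntimeError as e:
    print("in-place add failed:", e)
```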
comfy/ldm/flux/model.py

```diff
@@ -38,7 +38,7 @@ class Flux(nn.Module):
     Transformer model for flow matching on sequences.
     """
 
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+    def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
         super().__init__()
         self.dtype = dtype
         params = FluxParams(**kwargs)
@@ -83,7 +83,8 @@ class Flux(nn.Module):
             ]
         )
 
-        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+        if final_layer:
+            self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
 
     def forward_orig(
         self,
@@ -94,6 +95,7 @@ class Flux(nn.Module):
         timesteps: Tensor,
         y: Tensor,
         guidance: Tensor = None,
+        control=None,
     ) -> Tensor:
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
@@ -112,8 +114,15 @@ class Flux(nn.Module):
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+
+            if control is not None: #Controlnet
+                control_o = control.get("output")
+                if i < len(control_o):
+                    add = control_o[i]
+                    if add is not None:
+                        img += add
 
         img = torch.cat((txt, img), 1)
         for block in self.single_blocks:
```
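The rewritten loop indexes the double blocks so each controlnet residual is added right after its block runs, with `None` entries skipped. A toy model of that injection flow:

```python
# Toy model of the per-block residual injection added above.
blocks = [lambda x: x + 1, lambda x: x + 1, lambda x: x + 1]
control = {"output": [10, None, 30]}    # None entries are skipped

x = 0
for i in range(len(blocks)):
    x = blocks[i](x)
    control_o = control.get("output")
    if i < len(control_o):
        add = control_o[i]
        if add is not None:
            x += add
print(x)  # 43 = three +1 steps, plus residuals 10 and 30
```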
```diff
@@ -123,7 +132,7 @@ class Flux(nn.Module):
         img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance, **kwargs):
+    def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -138,5 +147,5 @@ class Flux(nn.Module):
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
```
comfy/model_management.py

```diff
@@ -338,8 +338,9 @@ class LoadedModel:
     def model_unload(self, memory_to_free=None, unpatch_weights=True):
         if memory_to_free is not None:
             if memory_to_free < self.model.loaded_size():
-                self.model.partially_unload(self.model.offload_device, memory_to_free)
-                return False
+                freed = self.model.partially_unload(self.model.offload_device, memory_to_free)
+                if freed >= memory_to_free:
+                    return False
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
```
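The unload path now trusts `partially_unload` to report the bytes it actually freed, and only skips the full unload when the request was met. A stand-in sketch of that control flow (not the real class; `freeable` is a made-up placeholder for what a partial unload can release):

```python
def model_unload(loaded_size, freeable, memory_to_free=None):
    # Stand-in for LoadedModel.model_unload; returns True when a full unload ran.
    if memory_to_free is not None:
        if memory_to_free < loaded_size:
            freed = min(freeable, memory_to_free)   # pretend partial unload
            if freed >= memory_to_free:
                return False                        # partial unload sufficed
    return True                                     # fell through to full unload

print(model_unload(loaded_size=8_000, freeable=1_000, memory_to_free=2_000))  # True
print(model_unload(loaded_size=8_000, freeable=4_000, memory_to_free=2_000))  # False
```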
```diff
@@ -434,7 +435,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
         soft_empty_cache()
     return unloaded_models
 
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
     global vram_state
 
     inference_memory = minimum_inference_memory()
@@ -513,7 +514,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
         else:
             vram_set_state = vram_state
         lowvram_model_memory = 0
-        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM):
+        if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
             model_size = loaded_model.model_memory_required(torch_dev)
             current_free_mem = get_free_memory(torch_dev)
             lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
```
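The lowvram budget line is dense, so here is a worked example of the `max`/`min` arithmetic with round numbers in MB (the 1024 MB value for `minimum_inference_memory()` is an assumption for illustration, not read from the diff):

```python
# Worked example of the lowvram_model_memory formula above (values in MB).
current_free_mem = 6_000
minimum_memory_required = 1_500   # caller's reservation
minimum_inference = 1_024         # assumed minimum_inference_memory()

lowvram_model_memory = max(
    64,                                                        # hard floor
    current_free_mem - minimum_memory_required,                # 4500
    min(current_free_mem * 0.4,                                # 2400
        current_free_mem - minimum_inference),                 # 4976
)
print(lowvram_model_memory)  # 4500
```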
comfy/model_patcher.py

```diff
@@ -411,7 +411,7 @@ class ModelPatcher:
                 logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
 
         if lowvram_counter > 0:
-            logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter))
+            logging.info("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
             self.model.model_lowvram = True
         else:
             logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
```
comfy/sd.py
```diff
@@ -86,7 +86,7 @@ class CLIP:
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
         if params['device'] == load_device:
-            model_management.load_model_gpu(self.patcher)
+            model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
         logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
@@ -585,12 +585,13 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
     model_patcher = comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=model_management.unet_offload_device())
     if inital_load_device != torch.device("cpu"):
         logging.info("loaded straight to GPU")
-        model_management.load_model_gpu(model_patcher)
+        model_management.load_models_gpu([model_patcher], force_full_load=True)
 
     return (model_patcher, clip, vae, clipvision)
 
 
-def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular format
+def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffusers or regular format
+    dtype = model_options.get("dtype", None)
 
     #Allow loading unets from checkpoint files
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
@@ -632,6 +633,7 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
 
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
     model = model_config.get_model(new_sd, "")
     model = model.to(offload_device)
     model.load_model_weights(new_sd, "")
@@ -640,14 +642,23 @@ def load_unet_state_dict(sd, dtype=None): #load unet in diffusers or regular for
         logging.info("left over keys in unet: {}".format(left_over))
     return comfy.model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
 
-def load_unet(unet_path, dtype=None):
+
+def load_diffusion_model(unet_path, model_options={}):
     sd = comfy.utils.load_torch_file(unet_path)
-    model = load_unet_state_dict(sd, dtype=dtype)
+    model = load_diffusion_model_state_dict(sd, model_options=model_options)
     if model is None:
         logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
         raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))
     return model
 
+def load_unet(unet_path, dtype=None):
+    print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
+    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
+
+def load_unet_state_dict(sd, dtype=None):
+    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
+    return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
+
 def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
     clip_sd = None
     load_models = [model]
```
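For external callers, the deprecation shims above define the migration directly: `dtype=` becomes an entry in `model_options`. A usage sketch (the checkpoint path is a hypothetical placeholder):

```python
import torch
import comfy.sd

# Before (still works, but prints a deprecation warning and forwards):
# model = comfy.sd.load_unet("models/unet/model.safetensors", dtype=torch.float8_e4m3fn)

# After:
model = comfy.sd.load_diffusion_model(
    "models/unet/model.safetensors",             # hypothetical path
    model_options={"dtype": torch.float8_e4m3fn} # omit "dtype" for the default
)
```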
nodes.py
```diff
@@ -826,14 +826,14 @@ class UNETLoader:
     CATEGORY = "advanced/loaders"
 
     def load_unet(self, unet_name, weight_dtype):
-        dtype = None
+        model_options = {}
         if weight_dtype == "fp8_e4m3fn":
-            dtype = torch.float8_e4m3fn
+            model_options["dtype"] = torch.float8_e4m3fn
         elif weight_dtype == "fp8_e5m2":
-            dtype = torch.float8_e5m2
+            model_options["dtype"] = torch.float8_e5m2
 
         unet_path = folder_paths.get_full_path("unet", unet_name)
-        model = comfy.sd.load_unet(unet_path, dtype=dtype)
+        model = comfy.sd.load_diffusion_model(unet_path, model_options=model_options)
         return (model,)
 
 class CLIPLoader:
```