Compare commits

17 commits:

- b8ffb2937f
- ce37c11164
- b5c3906b38
- 5d43e75e5b
- 517f4a94e4
- 52a471c5c7
- ad76574cb8
- 9acfe4df41
- 9829b013ea
- 5c69cde037
- e9589d6d92
- 0d82a798a5
- 925fff26fd
- 75b9b55b22
- 1765f1c60c
- 1de69fe4d5
- ae197f651b
.github/workflows/pullrequest-ci-run.yml (vendored): 16 changes
```diff
@@ -35,3 +35,19 @@ jobs:
       torch_version: ${{ matrix.torch_version }}
       google_credentials: ${{ secrets.GCS_SERVICE_ACCOUNT_JSON }}
       comfyui_flags: ${{ matrix.flags }}
+      use_prior_commit: 'true'
+  comment:
+    if: ${{ github.event.label.name == 'Run-CI-Test' }}
+    runs-on: ubuntu-latest
+    permissions:
+      pull-requests: write
+    steps:
+      - uses: actions/github-script@v6
+        with:
+          script: |
+            github.rest.issues.createComment({
+              issue_number: context.issue.number,
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              body: '(Automated Bot Message) CI Tests are running, you can view the results at https://ci.comfy.org/?branch=${{ github.event.pull_request.number }}%2Fmerge'
+            })
```
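For context, the new `comment` job uses `actions/github-script` to post a status comment on the labeled pull request. The same call can be reproduced outside Actions with a plain GitHub REST request; a minimal sketch, where the token, owner, repo, and PR number are placeholders rather than values from this diff:

```python
import json
import urllib.request

# Placeholders; none of these values come from the diff above.
token = "ghp_example"
owner, repo, pr_number = "someowner", "somerepo", 1234

# PR comments go through the issues endpoint of the GitHub REST API.
req = urllib.request.Request(
    f"https://api.github.com/repos/{owner}/{repo}/issues/{pr_number}/comments",
    data=json.dumps({"body": "(Automated Bot Message) CI Tests are running"}).encode(),
    headers={"Authorization": f"Bearer {token}", "Accept": "application/vnd.github+json"},
    method="POST",
)
urllib.request.urlopen(req)
```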
.gitignore (vendored): 1 change
```diff
@@ -19,3 +19,4 @@ venv/
 /user/
 *.log
 web_custom_versions/
+.DS_Store
```
```diff
@@ -47,7 +47,7 @@ def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x
 
 
 def rotate_half(x):
-    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
 
 
@@ -78,10 +78,9 @@ def apply_rotary_emb(
     xk_out = None
    if isinstance(freqs_cis, tuple):
         cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
-        cos, sin = cos.to(xq.device), sin.to(xq.device)
-        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
+        xq_out = (xq * cos + rotate_half(xq) * sin)
         if xk is not None:
-            xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
+            xk_out = (xk * cos + rotate_half(xk) * sin)
     else:
         xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
         freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
```
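The two hunks above drop the `.float()` upcast and the `.to(xq.device)` move from the cos/sin path, so queries and keys stay in their original dtype and on their original device. A small self-contained check of what `rotate_half` computes, and that the cos/sin form matches the complex-multiplication form of the `else` branch; the shapes and values here are made up for illustration:

```python
import torch

def rotate_half(x):
    # Split the last dimension into (pairs, 2), then swap and negate:
    # (a, b) -> (-b, a), i.e. multiplication by i on interleaved pairs.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)

# Hypothetical shapes: batch=1, seq=4, heads=2, head_dim=6.
xq = torch.randn(1, 4, 2, 6)
theta = torch.randn(1, 4, 1, 3)
cos = torch.repeat_interleave(torch.cos(theta), 2, dim=-1)
sin = torch.repeat_interleave(torch.sin(theta), 2, dim=-1)

# cos/sin path, as in the new code (no .float() round trip):
out_a = xq * cos + rotate_half(xq) * sin

# complex path: the same rotation via view_as_complex.
xq_c = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2))
freqs = torch.polar(torch.ones_like(theta), theta)
out_b = torch.view_as_real(xq_c * freqs).flatten(3)
assert torch.allclose(out_a, out_b, atol=1e-5)
```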
```diff
@@ -21,6 +21,7 @@ def calc_rope(x, patch_size, head_size):
     sub_args = [start, stop, (th, tw)]
     # head_size = HUNYUAN_DIT_CONFIG['DiT-g/2']['hidden_size'] // HUNYUAN_DIT_CONFIG['DiT-g/2']['num_heads']
     rope = get_2d_rotary_pos_embed(head_size, *sub_args)
+    rope = (rope[0].to(x), rope[1].to(x))
     return rope
 
 
```
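The added line moves both rope tensors onto the sampling tensor's device and dtype; `Tensor.to(other)` with a tensor argument matches both at once. A minimal illustration with arbitrary dtypes:

```python
import torch

x = torch.randn(2, 2, dtype=torch.float16)    # stands in for the model input
cos = torch.randn(2, 2, dtype=torch.float32)  # stands in for rope[0]

# Tensor.to(other) converts to other's dtype and device in one call.
assert cos.to(x).dtype == torch.float16
```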
```diff
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 import logging
 from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
@@ -77,10 +95,13 @@ class BaseModel(torch.nn.Module):
         self.device = device
 
         if not unet_config.get("disable_unet_model_creation", False):
-            if self.manual_cast_dtype is not None:
-                operations = comfy.ops.manual_cast
+            if model_config.custom_operations is None:
+                if self.manual_cast_dtype is not None:
+                    operations = comfy.ops.manual_cast
+                else:
+                    operations = comfy.ops.disable_weight_init
             else:
-                operations = comfy.ops.disable_weight_init
+                operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
             if comfy.model_management.force_channels_last():
                 self.diffusion_model.to(memory_format=torch.channels_last)
```
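The effect of the second hunk is a three-way choice of weight operations. Restated as a standalone helper for readability; `pick_operations` is a name invented here, not part of the diff:

```python
import comfy.ops

def pick_operations(custom_operations, manual_cast_dtype):
    # Custom operations supplied via the model config win outright...
    if custom_operations is not None:
        return custom_operations
    # ...otherwise keep the previous behavior: manual casting when the
    # compute dtype differs from the storage dtype, plain ops otherwise.
    if manual_cast_dtype is not None:
        return comfy.ops.manual_cast
    return comfy.ops.disable_weight_init
```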
```diff
@@ -495,7 +495,12 @@ def model_config_from_diffusers_unet(state_dict):
 def convert_diffusers_mmdit(state_dict, output_prefix=""):
     out_sd = {}
 
-    if 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
+    if 'transformer_blocks.0.attn.norm_added_k.weight' in state_dict: #Flux
+        depth = count_blocks(state_dict, 'transformer_blocks.{}.')
+        depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
+        hidden_size = state_dict["x_embedder.bias"].shape[0]
+        sd_map = comfy.utils.flux_to_diffusers({"depth": depth, "depth_single_blocks": depth_single_blocks, "hidden_size": hidden_size}, output_prefix=output_prefix)
+    elif 'transformer_blocks.0.attn.add_q_proj.weight' in state_dict: #SD3
         num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
         depth = state_dict["pos_embed.proj.weight"].shape[0] // 64
         sd_map = comfy.utils.mmdit_to_diffusers({"depth": depth, "num_blocks": num_blocks}, output_prefix=output_prefix)
@@ -521,7 +526,12 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
             old_weight = out_sd.get(t[0], None)
             if old_weight is None:
                 old_weight = torch.empty_like(weight)
-                old_weight = old_weight.repeat([3] + [1] * (len(old_weight.shape) - 1))
+            if old_weight.shape[offset[0]] < offset[1] + offset[2]:
+                exp = list(weight.shape)
+                exp[offset[0]] = offset[1] + offset[2]
+                new = torch.empty(exp, device=weight.device, dtype=weight.dtype)
+                new[:old_weight.shape[0]] = old_weight
+                old_weight = new
 
             w = old_weight.narrow(offset[0], offset[1], offset[2])
         else:
```
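The new branch grows `old_weight` whenever an incoming slice (such as the wider `proj_mlp` destination added later in this compare) does not fit the previously allocated fused buffer. A toy version of the same grow-then-narrow pattern, with made-up sizes:

```python
import torch

hidden = 4
# A fused buffer sized for q, k, v...
fused = torch.zeros(3 * hidden, hidden)
# ...and an incoming slice that needs rows [3h, 3h + 4h), so the buffer must grow.
offset = (0, 3 * hidden, 4 * hidden)        # (dim, start, length)
if fused.shape[offset[0]] < offset[1] + offset[2]:
    exp = list(fused.shape)
    exp[offset[0]] = offset[1] + offset[2]
    grown = torch.zeros(exp, dtype=fused.dtype)
    grown[:fused.shape[0]] = fused          # keep what was already written
    fused = grown
# narrow() returns a view, so copy_ writes straight into the fused buffer.
fused.narrow(offset[0], offset[1], offset[2]).copy_(torch.randn(4 * hidden, hidden))
```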
```diff
@@ -296,7 +296,7 @@ class LoadedModel:
 
     def model_memory_required(self, device):
         if device == self.model.current_loaded_device():
-            return 0
+            return self.model_offloaded_memory()
         else:
             return self.model_memory()
 
@@ -308,15 +308,21 @@ class LoadedModel:
 
         load_weights = not self.weights_loaded
 
-        try:
-            if lowvram_model_memory > 0 and load_weights:
-                self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
-            else:
-                self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
-        except Exception as e:
-            self.model.unpatch_model(self.model.offload_device)
-            self.model_unload()
-            raise e
+        if self.model.loaded_size() > 0:
+            use_more_vram = lowvram_model_memory
+            if use_more_vram == 0:
+                use_more_vram = 1e32
+            self.model_use_more_vram(use_more_vram)
+        else:
+            try:
+                if lowvram_model_memory > 0 and load_weights:
+                    self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
+                else:
+                    self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
+            except Exception as e:
+                self.model.unpatch_model(self.model.offload_device)
+                self.model_unload()
+                raise e
 
         if is_intel_xpu() and not args.disable_ipex_optimize:
             self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
```
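Already-loaded models now take a fast path: instead of being re-patched, they are asked via `model_use_more_vram` to move additional weights onto the device, with a zero lowvram budget treated as unlimited. The sentinel logic in isolation, with a helper name invented for illustration:

```python
def vram_budget(lowvram_model_memory):
    # In this hunk a budget of 0 means "not lowvram-limited", so it is
    # replaced by an effectively infinite 1e32 before requesting more VRAM.
    return lowvram_model_memory if lowvram_model_memory > 0 else 1e32
```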
```diff
@@ -432,11 +438,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
     global vram_state
 
     inference_memory = minimum_inference_memory()
-    extra_mem = max(inference_memory, memory_required) + 100 * 1024 * 1024
+    extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
     if minimum_memory_required is None:
         minimum_memory_required = extra_mem
     else:
-        minimum_memory_required = max(inference_memory, minimum_memory_required) + 100 * 1024 * 1024
+        minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
 
     models = set(models)
 
```
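The reserve calculation changes from adding a flat 100 MB after the max to folding a 300 MB pad into the requested amount before the max. Worked through with illustrative numbers, assuming `minimum_inference_memory()` returns 1 GB (that value is not shown in this diff):

```python
MB = 1024 * 1024
GB = 1024 * MB

inference_memory = 1 * GB   # assumed return of minimum_inference_memory()
memory_required = 4 * GB    # hypothetical request

old_extra_mem = max(inference_memory, memory_required) + 100 * MB  # 4 GB + 100 MB
new_extra_mem = max(inference_memory, memory_required + 300 * MB)  # 4 GB + 300 MB

# For tiny requests the floor also changes: with memory_required = 0,
# the old formula gives 1 GB + 100 MB while the new one gives exactly 1 GB.
```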
```diff
@@ -484,18 +490,21 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
 
     total_memory_required = {}
     for loaded_model in models_to_load:
-        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
-            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
-    for device in total_memory_required:
-        if device != torch.device("cpu"):
-            free_memory(total_memory_required[device] * 1.3 + extra_mem, device, models_already_loaded)
+    for loaded_model in models_already_loaded:
+        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for loaded_model in models_to_load:
         weights_unloaded = unload_model_clones(loaded_model.model, unload_weights_only=False, force_unload=False) #unload the rest of the clones where the weights can stay loaded
         if weights_unloaded is not None:
             loaded_model.weights_loaded = not weights_unloaded
 
+    for device in total_memory_required:
+        if device != torch.device("cpu"):
+            free_memory(total_memory_required[device] * 1.1 + extra_mem, device, models_already_loaded)
+
     for loaded_model in models_to_load:
         model = loaded_model.model
         torch_dev = model.load_device
@@ -675,6 +684,20 @@ def text_encoder_device():
     else:
         return torch.device("cpu")
 
+def text_encoder_initial_device(load_device, offload_device, model_size=0):
+    if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
+        return offload_device
+
+    if is_device_mps(load_device):
+        return offload_device
+
+    mem_l = get_free_memory(load_device)
+    mem_o = get_free_memory(offload_device)
+    if mem_l > (mem_o * 0.5) and model_size * 1.2 < mem_l:
+        return load_device
+    else:
+        return offload_device
+
 def text_encoder_dtype(device=None):
     if args.fp8_e4m3fn_text_enc:
         return torch.float8_e4m3fn
```
|||||||
@@ -102,7 +102,7 @@ class ModelPatcher:
|
|||||||
self.size = size
|
self.size = size
|
||||||
self.model = model
|
self.model = model
|
||||||
if not hasattr(self.model, 'device'):
|
if not hasattr(self.model, 'device'):
|
||||||
logging.info("Model doesn't have a device attribute.")
|
logging.debug("Model doesn't have a device attribute.")
|
||||||
self.model.device = offload_device
|
self.model.device = offload_device
|
||||||
elif self.model.device is None:
|
elif self.model.device is None:
|
||||||
self.model.device = offload_device
|
self.model.device = offload_device
|
||||||
@@ -355,13 +355,14 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return self.model
|
return self.model
|
||||||
|
|
||||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||||
mem_counter = 0
|
mem_counter = 0
|
||||||
patch_counter = 0
|
patch_counter = 0
|
||||||
lowvram_counter = 0
|
lowvram_counter = 0
|
||||||
for n, m in self.model.named_modules():
|
for n, m in self.model.named_modules():
|
||||||
lowvram_weight = False
|
lowvram_weight = False
|
||||||
if hasattr(m, "comfy_cast_weights"):
|
|
||||||
|
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||||
module_mem = comfy.model_management.module_size(m)
|
module_mem = comfy.model_management.module_size(m)
|
||||||
if mem_counter + module_mem >= lowvram_model_memory:
|
if mem_counter + module_mem >= lowvram_model_memory:
|
||||||
lowvram_weight = True
|
lowvram_weight = True
|
||||||
@@ -401,13 +402,16 @@ class ModelPatcher:
|
|||||||
if weight.device == device_to:
|
if weight.device == device_to:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
self.patch_weight_to_device(weight_key) #TODO: speed this up without OOM
|
weight_to = None
|
||||||
self.patch_weight_to_device(bias_key)
|
if full_load:#TODO
|
||||||
|
weight_to = device_to
|
||||||
|
self.patch_weight_to_device(weight_key, device_to=weight_to) #TODO: speed this up without OOM
|
||||||
|
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
||||||
m.to(device_to)
|
m.to(device_to)
|
||||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||||
|
|
||||||
if lowvram_counter > 0:
|
if lowvram_counter > 0:
|
||||||
logging.info("loaded in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
logging.info("loaded partially {} {}".format(lowvram_model_memory / (1024 * 1024), patch_counter))
|
||||||
self.model.model_lowvram = True
|
self.model.model_lowvram = True
|
||||||
else:
|
else:
|
||||||
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
logging.info("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||||
@@ -665,12 +669,15 @@ class ModelPatcher:
|
|||||||
return memory_freed
|
return memory_freed
|
||||||
|
|
||||||
def partially_load(self, device_to, extra_memory=0):
|
def partially_load(self, device_to, extra_memory=0):
|
||||||
|
self.unpatch_model(unpatch_weights=False)
|
||||||
|
self.patch_model(patch_weights=False)
|
||||||
|
full_load = False
|
||||||
if self.model.model_lowvram == False:
|
if self.model.model_lowvram == False:
|
||||||
return 0
|
return 0
|
||||||
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
if self.model.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||||
pass #TODO: Full load
|
full_load = True
|
||||||
current_used = self.model.model_loaded_weight_memory
|
current_used = self.model.model_loaded_weight_memory
|
||||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory)
|
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
def current_loaded_device(self):
|
def current_loaded_device(self):
|
||||||
|
|||||||
comfy/sd.py: 40 changes
```diff
@@ -62,7 +62,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
 
 
 class CLIP:
-    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}):
+    def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0):
         if no_init:
             return
         params = target.params.copy()
@@ -71,20 +71,24 @@ class CLIP:
 
         load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
-        params['device'] = offload_device
         dtype = model_management.text_encoder_dtype(load_device)
         params['dtype'] = dtype
+        params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
         self.cond_stage_model = clip(**(params))
 
         for dt in self.cond_stage_model.dtypes:
             if not model_management.supports_cast(load_device, dt):
                 load_device = offload_device
+                if params['device'] != offload_device:
+                    self.cond_stage_model.to(offload_device)
+                    logging.warning("Had to shift TE back.")
 
         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        if params['device'] == load_device:
+            model_management.load_model_gpu(self.patcher)
         self.layer_idx = None
-        logging.debug("CLIP model load device: {}, offload device: {}".format(load_device, offload_device))
+        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
 
     def clone(self):
         n = CLIP(no_init=True)
```
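The new `parameters` argument feeds the device heuristic above: the encoder's footprint is estimated as parameter count times bytes per element of the chosen dtype (`dtype_size` is comfy's helper for the latter). Illustrative arithmetic with made-up counts:

```python
# A hypothetical 123M-parameter text encoder in fp16:
parameters = 123_000_000
bytes_per_param = 2                            # what dtype_size reports for fp16
approx_bytes = parameters * bytes_per_param    # ~246 MB

# Below the 1 GB cutoff in text_encoder_initial_device, so it is created on
# the offload device; a multi-billion-parameter encoder could start on the GPU.
print(approx_bytes <= 1024 * 1024 * 1024)      # True
```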
```diff
@@ -456,7 +460,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
         clip_target.clip = comfy.text_encoders.sd3_clip.SD3ClipModel
         clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
 
-    clip = CLIP(clip_target, embedding_directory=embedding_directory)
+    parameters = 0
+    for c in clip_data:
+        parameters += comfy.utils.calculate_parameters(c)
+
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@@ -498,15 +506,19 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
 
     return (model, clip, vae)
 
-def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True):
+def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
     sd = comfy.utils.load_torch_file(ckpt_path)
-    sd_keys = sd.keys()
+    out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options)
+    if out is None:
+        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+    return out
+
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
     clip = None
     clipvision = None
     vae = None
     model = None
     model_patcher = None
-    clip_target = None
 
     diffusion_model_prefix = model_detection.unet_prefix_from_state_dict(sd)
     parameters = comfy.utils.calculate_parameters(sd, diffusion_model_prefix)
```
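The refactor splits the old function so callers that already hold a state dict in memory can skip the file read. A sketch of calling the new entry point directly; the path is hypothetical, the `model_options` keys are the ones this compare introduces, and the 4-tuple unpacking mirrors the existing return of `load_checkpoint_guess_config` and is an assumption here:

```python
import comfy.sd
import comfy.utils

sd = comfy.utils.load_torch_file("model.safetensors")   # hypothetical path
out = comfy.sd.load_state_dict_guess_config(
    sd,
    model_options={"custom_operations": None, "weight_dtype": None},
)
if out is None:
    # load_checkpoint_guess_config raises here; direct callers decide themselves.
    raise RuntimeError("Could not detect model type")
model_patcher, clip, vae, clipvision = out  # assumed return shape
```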
```diff
@@ -515,13 +527,18 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
 
     model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
     if model_config is None:
-        raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
+        return None
 
     unet_weight_dtype = list(model_config.supported_inference_dtypes)
     if weight_dtype is not None:
         unet_weight_dtype.append(weight_dtype)
 
-    unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+    model_config.custom_operations = model_options.get("custom_operations", None)
+    unet_dtype = model_options.get("weight_dtype", None)
+
+    if unet_dtype is None:
+        unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
     model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
 
@@ -545,7 +562,8 @@ def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, o
     if clip_target is not None:
         clip_sd = model_config.process_clip_state_dict(sd)
         if len(clip_sd) > 0:
-            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd)
+            parameters = comfy.utils.calculate_parameters(clip_sd)
+            clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
             m, u = clip.load_sd(clip_sd, full_model=True)
             if len(m) > 0:
                 m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
```
```diff
@@ -1,3 +1,21 @@
+"""
+    This file is part of ComfyUI.
+    Copyright (C) 2024 Comfy
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+    GNU General Public License for more details.
+
+    You should have received a copy of the GNU General Public License
+    along with this program. If not, see <https://www.gnu.org/licenses/>.
+"""
+
 import torch
 from . import model_base
 from . import utils
@@ -30,6 +48,7 @@ class BASE:
     memory_usage_factor = 2.0
 
     manual_cast_dtype = None
+    custom_operations = None
 
     @classmethod
     def matches(s, unet_config, state_dict=None):
```
```diff
@@ -457,8 +457,27 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             key_map["{}add_k_proj.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
 
-        block_map = {"attn.to_out.0.weight": "img_attn.proj.weight",
-                     "attn.to_out.0.bias": "img_attn.proj.bias",
+        block_map = {
+                     "attn.to_out.0.weight": "img_attn.proj.weight",
+                     "attn.to_out.0.bias": "img_attn.proj.bias",
+                     "norm1.linear.weight": "img_mod.lin.weight",
+                     "norm1.linear.bias": "img_mod.lin.bias",
+                     "norm1_context.linear.weight": "txt_mod.lin.weight",
+                     "norm1_context.linear.bias": "txt_mod.lin.bias",
+                     "attn.to_add_out.weight": "txt_attn.proj.weight",
+                     "attn.to_add_out.bias": "txt_attn.proj.bias",
+                     "ff.net.0.proj.weight": "img_mlp.0.weight",
+                     "ff.net.0.proj.bias": "img_mlp.0.bias",
+                     "ff.net.2.weight": "img_mlp.2.weight",
+                     "ff.net.2.bias": "img_mlp.2.bias",
+                     "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
+                     "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
+                     "ff_context.net.2.weight": "txt_mlp.2.weight",
+                     "ff_context.net.2.bias": "txt_mlp.2.bias",
+                     "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
+                     "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
+                     "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
+                     "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
                      }
 
         for k in block_map:
```
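Each `block_map` entry is expanded per block index into a flat key rename. An illustrative expansion for one entry; `double_blocks` is how comfy names Flux's paired blocks, though the exact prefixes should be treated as an assumption:

```python
prefix_from = "transformer_blocks.0"   # diffusers-style prefix (illustrative)
prefix_to = "double_blocks.0"          # comfy/Flux-style prefix (assumed)

block_map = {"attn.norm_q.weight": "img_attn.norm.query_norm.scale"}
key_map = {}
for k in block_map:
    key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])

print(key_map)
# {'transformer_blocks.0.attn.norm_q.weight': 'double_blocks.0.img_attn.norm.query_norm.scale'}
```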
```diff
@@ -474,15 +493,41 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
             key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
             key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
             key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
-            key_map["{}proj_mlp.{}".format(k, end)] = (qkv, (0, hidden_size * 3, hidden_size))
+            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
 
-        block_map = {#TODO
+        block_map = {
+                     "norm.linear.weight": "modulation.lin.weight",
+                     "norm.linear.bias": "modulation.lin.bias",
+                     "proj_out.weight": "linear2.weight",
+                     "proj_out.bias": "linear2.bias",
+                     "attn.norm_q.weight": "norm.query_norm.scale",
+                     "attn.norm_k.weight": "norm.key_norm.scale",
                      }
 
         for k in block_map:
             key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
 
-    MAP_BASIC = { #TODO
+    MAP_BASIC = {
+        ("final_layer.linear.bias", "proj_out.bias"),
+        ("final_layer.linear.weight", "proj_out.weight"),
+        ("img_in.bias", "x_embedder.bias"),
+        ("img_in.weight", "x_embedder.weight"),
+        ("time_in.in_layer.bias", "time_text_embed.timestep_embedder.linear_1.bias"),
+        ("time_in.in_layer.weight", "time_text_embed.timestep_embedder.linear_1.weight"),
+        ("time_in.out_layer.bias", "time_text_embed.timestep_embedder.linear_2.bias"),
+        ("time_in.out_layer.weight", "time_text_embed.timestep_embedder.linear_2.weight"),
+        ("txt_in.bias", "context_embedder.bias"),
+        ("txt_in.weight", "context_embedder.weight"),
+        ("vector_in.in_layer.bias", "time_text_embed.text_embedder.linear_1.bias"),
+        ("vector_in.in_layer.weight", "time_text_embed.text_embedder.linear_1.weight"),
+        ("vector_in.out_layer.bias", "time_text_embed.text_embedder.linear_2.bias"),
+        ("vector_in.out_layer.weight", "time_text_embed.text_embedder.linear_2.weight"),
+        ("guidance_in.in_layer.bias", "time_text_embed.guidance_embedder.linear_1.bias"),
+        ("guidance_in.in_layer.weight", "time_text_embed.guidance_embedder.linear_1.weight"),
+        ("guidance_in.out_layer.bias", "time_text_embed.guidance_embedder.linear_2.bias"),
+        ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
+        ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
+        ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
     }
 
     for k in MAP_BASIC:
```
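The two three-element `MAP_BASIC` entries route the adaLN weights through `swap_scale_shift` rather than copying them directly. A sketch of what such a transform conventionally does; only the name appears in this diff, so the body is an assumption:

```python
import torch

def swap_scale_shift(weight):
    # diffusers stores the adaLN output as (shift, scale); comfy's Flux layer
    # expects (scale, shift), so the two halves are swapped along dim 0.
    shift, scale = weight.chunk(2, dim=0)
    return torch.cat([scale, shift], dim=0)
```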