Support loading diffusers SD3 model format with UNETLoader node.

This commit is contained in:
comfyanonymous
2024-06-19 21:46:37 -04:00
parent b08a9dd04b
commit 0d6a57938e
4 changed files with 84 additions and 5 deletions

View File

@@ -568,7 +568,14 @@ def load_unet_state_dict(sd): #load unet in diffusers format
unet_dtype = model_management.unet_dtype(model_params=parameters)
load_device = model_management.get_torch_device()
if "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
if 'transformer_blocks.0.attn.add_q_proj.weight' in sd: #MMDIT SD3
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
if new_sd is None:
return None
model_config = model_detection.model_config_from_unet(new_sd, "")
if model_config is None:
return None
elif "input_blocks.0.0.weight" in sd or 'clf.1.weight' in sd: #ldm or stable cascade
model_config = model_detection.model_config_from_unet(sd, "")
if model_config is None:
return None