StableCascade CLIP model support.

This commit is contained in:
comfyanonymous
2024-02-16 13:29:04 -05:00
parent 667c92814e
commit 97d03ae04a
5 changed files with 43 additions and 8 deletions

View File

@@ -1,4 +1,5 @@
import torch
from enum import Enum
from comfy import model_management
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
model.load_state_dict(model_data)
return StyleModel(model)
class CLIPType(Enum):
STABLE_DIFFUSION = 1
STABLE_CASCADE = 2
def load_clip(ckpt_paths, embedding_directory=None):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
clip_target.params = {}
if len(clip_data) == 1:
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
if clip_type == CLIPType.STABLE_CASCADE:
clip_target.clip = sdxl_clip.StableCascadeClipModel
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
clip_target.clip = sd2_clip.SD2ClipModel
clip_target.tokenizer = sd2_clip.SD2Tokenizer