StableCascade CLIP model support.
This commit is contained in:
14
comfy/sd.py
14
comfy/sd.py
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
from enum import Enum
|
||||
|
||||
from comfy import model_management
|
||||
from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine
|
||||
@@ -309,8 +310,11 @@ def load_style_model(ckpt_path):
|
||||
model.load_state_dict(model_data)
|
||||
return StyleModel(model)
|
||||
|
||||
class CLIPType(Enum):
|
||||
STABLE_DIFFUSION = 1
|
||||
STABLE_CASCADE = 2
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
@@ -326,8 +330,12 @@ def load_clip(ckpt_paths, embedding_directory=None):
|
||||
clip_target.params = {}
|
||||
if len(clip_data) == 1:
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
if clip_type == CLIPType.STABLE_CASCADE:
|
||||
clip_target.clip = sdxl_clip.StableCascadeClipModel
|
||||
clip_target.tokenizer = sdxl_clip.StableCascadeTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
elif "text_model.encoder.layers.22.mlp.fc1.weight" in clip_data[0]:
|
||||
clip_target.clip = sd2_clip.SD2ClipModel
|
||||
clip_target.tokenizer = sd2_clip.SD2Tokenizer
|
||||
|
||||
Reference in New Issue
Block a user