Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image cond and takes as input a CLIPVisionEncode output which has been moved to the conditioning section.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from . import kornia_functions
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
@@ -37,7 +38,7 @@ class ClassEmbedder(nn.Module):
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
@@ -57,18 +58,20 @@ def disabled_train(self, mode=True):
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77,
|
||||
freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@@ -92,6 +95,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"pooled",
|
||||
"hidden"
|
||||
]
|
||||
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
@@ -110,7 +114,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
@@ -118,7 +122,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden")
|
||||
outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
|
||||
if self.layer == "last":
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == "pooled":
|
||||
@@ -131,15 +135,55 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
return self(text)
|
||||
|
||||
|
||||
class ClipImageEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
jit=False,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
antialias=True,
|
||||
ucg_rate=0.
|
||||
):
|
||||
super().__init__()
|
||||
from clip import load as load_clip
|
||||
self.model, _ = load_clip(name=model, device=device, jit=jit)
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
# x = kornia_functions.geometry_resize(x, (224, 224),
|
||||
# interpolation='bicubic', align_corners=True,
|
||||
# antialias=self.antialias)
|
||||
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
|
||||
x = (x + 1.) / 2.
|
||||
# re-normalize according to clip
|
||||
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x, no_dropout=False):
|
||||
# x is assumed to be in range [-1,1]
|
||||
out = self.model.encode_image(self.preprocess(x))
|
||||
out = out.to(x.dtype)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
out = torch.bernoulli((1. - self.ucg_rate) * torch.ones(out.shape[0], device=out.device))[:, None] * out
|
||||
return out
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
#"pooled",
|
||||
# "pooled",
|
||||
"last",
|
||||
"penultimate"
|
||||
]
|
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="last"):
|
||||
super().__init__()
|
||||
@@ -179,7 +223,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask = None):
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
@@ -193,14 +237,73 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP vision transformer encoder for images
|
||||
"""
|
||||
|
||||
def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
|
||||
freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
|
||||
super().__init__()
|
||||
model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
|
||||
pretrained=version, )
|
||||
del model.transformer
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == "penultimate":
|
||||
raise NotImplementedError()
|
||||
self.layer_idx = 1
|
||||
|
||||
self.antialias = antialias
|
||||
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def preprocess(self, x):
|
||||
# normalize to [0,1]
|
||||
# x = kornia.geometry.resize(x, (224, 224),
|
||||
# interpolation='bicubic', align_corners=True,
|
||||
# antialias=self.antialias)
|
||||
x = torch.nn.functional.interpolate(x, size=(224, 224), mode='bicubic', align_corners=True, antialias=True)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia_functions.enhance_normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image, no_dropout=False):
|
||||
z = self.encode_with_vision_transformer(image)
|
||||
if self.ucg_rate > 0. and not no_dropout:
|
||||
z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
|
||||
return z
|
||||
|
||||
def encode_with_vision_transformer(self, img):
|
||||
img = self.preprocess(img)
|
||||
x = self.model.visual(img)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
|
||||
clip_max_length=77, t5_max_length=77):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.")
|
||||
print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
@@ -209,5 +312,3 @@ class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
clip_z = self.clip_encoder.encode(text)
|
||||
t5_z = self.t5_encoder.encode(text)
|
||||
return [clip_z, t5_z]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user