Add support for unCLIP SD2.x models.
See _for_testing/unclip in the UI for the new nodes. unCLIPCheckpointLoader is used to load them. unCLIPConditioning is used to add the image cond and takes as input a CLIPVisionEncode output which has been moved to the conditioning section.
This commit is contained in:
@@ -1801,3 +1801,75 @@ class LatentUpscaleFinetuneDiffusion(LatentFinetuneDiffusion):
|
||||
log = super().log_images(*args, **kwargs)
|
||||
log["lr"] = rearrange(args[0]["lr"], 'b h w c -> b c h w')
|
||||
return log
|
||||
|
||||
|
||||
class ImageEmbeddingConditionedLatentDiffusion(LatentDiffusion):
|
||||
def __init__(self, embedder_config=None, embedding_key="jpg", embedding_dropout=0.5,
|
||||
freeze_embedder=True, noise_aug_config=None, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.embed_key = embedding_key
|
||||
self.embedding_dropout = embedding_dropout
|
||||
# self._init_embedder(embedder_config, freeze_embedder)
|
||||
self._init_noise_aug(noise_aug_config)
|
||||
|
||||
def _init_embedder(self, config, freeze=True):
|
||||
embedder = instantiate_from_config(config)
|
||||
if freeze:
|
||||
self.embedder = embedder.eval()
|
||||
self.embedder.train = disabled_train
|
||||
for param in self.embedder.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def _init_noise_aug(self, config):
|
||||
if config is not None:
|
||||
# use the KARLO schedule for noise augmentation on CLIP image embeddings
|
||||
noise_augmentor = instantiate_from_config(config)
|
||||
assert isinstance(noise_augmentor, nn.Module)
|
||||
noise_augmentor = noise_augmentor.eval()
|
||||
noise_augmentor.train = disabled_train
|
||||
self.noise_augmentor = noise_augmentor
|
||||
else:
|
||||
self.noise_augmentor = None
|
||||
|
||||
def get_input(self, batch, k, cond_key=None, bs=None, **kwargs):
|
||||
outputs = LatentDiffusion.get_input(self, batch, k, bs=bs, **kwargs)
|
||||
z, c = outputs[0], outputs[1]
|
||||
img = batch[self.embed_key][:bs]
|
||||
img = rearrange(img, 'b h w c -> b c h w')
|
||||
c_adm = self.embedder(img)
|
||||
if self.noise_augmentor is not None:
|
||||
c_adm, noise_level_emb = self.noise_augmentor(c_adm)
|
||||
# assume this gives embeddings of noise levels
|
||||
c_adm = torch.cat((c_adm, noise_level_emb), 1)
|
||||
if self.training:
|
||||
c_adm = torch.bernoulli((1. - self.embedding_dropout) * torch.ones(c_adm.shape[0],
|
||||
device=c_adm.device)[:, None]) * c_adm
|
||||
all_conds = {"c_crossattn": [c], "c_adm": c_adm}
|
||||
noutputs = [z, all_conds]
|
||||
noutputs.extend(outputs[2:])
|
||||
return noutputs
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, N=8, n_row=4, **kwargs):
|
||||
log = dict()
|
||||
z, c, x, xrec, xc = self.get_input(batch, self.first_stage_key, bs=N, return_first_stage_outputs=True,
|
||||
return_original_cond=True)
|
||||
log["inputs"] = x
|
||||
log["reconstruction"] = xrec
|
||||
assert self.model.conditioning_key is not None
|
||||
assert self.cond_stage_key in ["caption", "txt"]
|
||||
xc = log_txt_as_img((x.shape[2], x.shape[3]), batch[self.cond_stage_key], size=x.shape[2] // 25)
|
||||
log["conditioning"] = xc
|
||||
uc = self.get_unconditional_conditioning(N, kwargs.get('unconditional_guidance_label', ''))
|
||||
unconditional_guidance_scale = kwargs.get('unconditional_guidance_scale', 5.)
|
||||
|
||||
uc_ = {"c_crossattn": [uc], "c_adm": c["c_adm"]}
|
||||
ema_scope = self.ema_scope if kwargs.get('use_ema_scope', True) else nullcontext
|
||||
with ema_scope(f"Sampling"):
|
||||
samples_cfg, _ = self.sample_log(cond=c, batch_size=N, ddim=True,
|
||||
ddim_steps=kwargs.get('ddim_steps', 50), eta=kwargs.get('ddim_eta', 0.),
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc_, )
|
||||
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||
log[f"samplescfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
||||
return log
|
||||
|
||||
Reference in New Issue
Block a user