Refactor of sampler code to deal more easily with different model types.
This commit is contained in:
@@ -14,6 +14,7 @@ class DDIMSampler(object):
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
self.parameterization = kwargs.get("parameterization", "eps")
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
@@ -261,7 +262,7 @@ class DDIMSampler(object):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if denoise_function is not None:
|
||||
model_output = denoise_function(self.model.apply_model, x, t, **extra_args)
|
||||
model_output = denoise_function(x, t, **extra_args)
|
||||
elif unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
@@ -289,13 +290,13 @@ class DDIMSampler(object):
|
||||
model_uncond, model_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (model_t - model_uncond)
|
||||
|
||||
if self.model.parameterization == "v":
|
||||
if self.parameterization == "v":
|
||||
e_t = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * model_output + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == "eps", 'not implemented'
|
||||
assert self.parameterization == "eps", 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
@@ -309,7 +310,7 @@ class DDIMSampler(object):
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != "v":
|
||||
if self.parameterization != "v":
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * model_output
|
||||
|
||||
Reference in New Issue
Block a user