SD3 Support.
This commit is contained in:
@@ -33,6 +33,19 @@ class EDM(V_PREDICTION):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
|
||||
|
||||
class CONST:
|
||||
def calculate_input(self, sigma, noise):
|
||||
return noise
|
||||
|
||||
def calculate_denoised(self, sigma, model_output, model_input):
|
||||
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
|
||||
return model_input - model_output * sigma
|
||||
|
||||
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
|
||||
return sigma * noise + (1.0 - sigma) * latent_image
|
||||
|
||||
def inverse_noise_scaling(self, sigma, latent):
|
||||
return latent / (1.0 - sigma)
|
||||
|
||||
class ModelSamplingDiscrete(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
@@ -104,6 +117,12 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
||||
percent = 1.0 - percent
|
||||
return self.sigma(torch.tensor(percent * 999.0)).item()
|
||||
|
||||
class ModelSamplingDiscreteEDM(ModelSamplingDiscrete):
|
||||
def timestep(self, sigma):
|
||||
return 0.25 * sigma.log()
|
||||
|
||||
def sigma(self, timestep):
|
||||
return (timestep / 0.25).exp()
|
||||
|
||||
class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
@@ -149,6 +168,48 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
|
||||
log_sigma_min = math.log(self.sigma_min)
|
||||
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
|
||||
|
||||
|
||||
def time_snr_shift(alpha, t):
|
||||
if alpha == 1.0:
|
||||
return t
|
||||
return alpha * t / (1 + (alpha - 1) * t)
|
||||
|
||||
class ModelSamplingDiscreteFlow(torch.nn.Module):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
if model_config is not None:
|
||||
sampling_settings = model_config.sampling_settings
|
||||
else:
|
||||
sampling_settings = {}
|
||||
|
||||
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
|
||||
|
||||
def set_parameters(self, shift=1.0, timesteps=1000):
|
||||
self.shift = shift
|
||||
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
|
||||
self.register_buffer('sigmas', ts)
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
||||
@property
|
||||
def sigma_max(self):
|
||||
return self.sigmas[-1]
|
||||
|
||||
def timestep(self, sigma):
|
||||
return sigma * 1000
|
||||
|
||||
def sigma(self, timestep):
|
||||
return time_snr_shift(self.shift, timestep / 1000)
|
||||
|
||||
def percent_to_sigma(self, percent):
|
||||
if percent <= 0.0:
|
||||
return 1.0
|
||||
if percent >= 1.0:
|
||||
return 0.0
|
||||
return 1.0 - percent
|
||||
|
||||
class StableCascadeSampling(ModelSamplingDiscrete):
|
||||
def __init__(self, model_config=None):
|
||||
super().__init__()
|
||||
|
||||
Reference in New Issue
Block a user