SD3 Support.

This commit is contained in:
comfyanonymous
2024-06-10 13:26:25 -04:00
parent a5e6a632f9
commit 8c4a9befa7
17 changed files with 132182 additions and 5 deletions

View File

@@ -33,6 +33,19 @@ class EDM(V_PREDICTION):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5
class CONST:
def calculate_input(self, sigma, noise):
return noise
def calculate_denoised(self, sigma, model_output, model_input):
sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1))
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
return latent / (1.0 - sigma)
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None):
@@ -104,6 +117,12 @@ class ModelSamplingDiscrete(torch.nn.Module):
percent = 1.0 - percent
return self.sigma(torch.tensor(percent * 999.0)).item()
class ModelSamplingDiscreteEDM(ModelSamplingDiscrete):
def timestep(self, sigma):
return 0.25 * sigma.log()
def sigma(self, timestep):
return (timestep / 0.25).exp()
class ModelSamplingContinuousEDM(torch.nn.Module):
def __init__(self, model_config=None):
@@ -149,6 +168,48 @@ class ModelSamplingContinuousEDM(torch.nn.Module):
log_sigma_min = math.log(self.sigma_min)
return math.exp((math.log(self.sigma_max) - log_sigma_min) * percent + log_sigma_min)
def time_snr_shift(alpha, t):
if alpha == 1.0:
return t
return alpha * t / (1 + (alpha - 1) * t)
class ModelSamplingDiscreteFlow(torch.nn.Module):
def __init__(self, model_config=None):
super().__init__()
if model_config is not None:
sampling_settings = model_config.sampling_settings
else:
sampling_settings = {}
self.set_parameters(shift=sampling_settings.get("shift", 1.0))
def set_parameters(self, shift=1.0, timesteps=1000):
self.shift = shift
ts = self.sigma(torch.arange(1, timesteps + 1, 1))
self.register_buffer('sigmas', ts)
@property
def sigma_min(self):
return self.sigmas[0]
@property
def sigma_max(self):
return self.sigmas[-1]
def timestep(self, sigma):
return sigma * 1000
def sigma(self, timestep):
return time_snr_shift(self.shift, timestep / 1000)
def percent_to_sigma(self, percent):
if percent <= 0.0:
return 1.0
if percent >= 1.0:
return 0.0
return 1.0 - percent
class StableCascadeSampling(ModelSamplingDiscrete):
def __init__(self, model_config=None):
super().__init__()