Add a TomePatchModel node to the _for_testing section.

Tome (ToMe, token merging) increases sampling speed at the expense of quality.
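For context: ToMe speeds up the UNet's transformer blocks by merging similar tokens before attention runs and broadcasting the results back afterward, so fewer tokens pass through the expensive layers. A minimal conceptual sketch of the idea, assuming position-based pairing for simplicity (real ToMe picks merge partners by bipartite soft matching on token similarity; this is not the tomesd code the commit vendors):

    import torch

    def merge_adjacent(x):
        # x: (batch, tokens, channels) with an even token count.
        # Halve the token count by averaging adjacent pairs.
        b, n, c = x.shape
        return x.reshape(b, n // 2, 2, c).mean(dim=2)

    def unmerge_adjacent(merged, n):
        # Broadcast each merged token back to the two positions it replaced.
        return merged.repeat_interleave(2, dim=1)[:, :n]

    x = torch.randn(2, 64, 320)
    merged = merge_adjacent(x)          # attention would run on 32 tokens, not 64
    restored = unmerge_adjacent(merged, x.shape[1])
    assert restored.shape == x.shape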
comfyanonymous
2023-03-31 17:19:58 -04:00
parent 7e682784d7
commit 18a6c1db33
5 changed files with 166 additions and 11 deletions

comfy/samplers.py

@@ -104,7 +104,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
             out['c_concat'] = [torch.cat(c_concat)]
         return out
 
-    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
+    def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in, model_options):
         out_cond = torch.zeros_like(x_in)
         out_count = torch.ones_like(x_in)/100000.0
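The point of this change is visible in the signature: patch data now travels down the call chain explicitly instead of living in module-level globals. A toy version of the threading pattern, with a hypothetical key name for illustration:

    # Options flow from the outermost call to the innermost helper untouched.
    def inner(x, model_options):
        return x, model_options.get("transformer_options", {})

    def outer(x, model_options={}):
        return inner(x, model_options)

    print(outer(1.0, {"transformer_options": {"tomesd": {"ratio": 0.5}}}))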
@@ -195,7 +195,7 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
     max_total_area = model_management.maximum_batch_area()
-    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
+    cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat, model_options)
     return uncond + (cond - uncond) * cond_scale
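The return line is classifier-free guidance: the final prediction extrapolates from the unconditional output toward the conditional one. In isolation:

    def cfg_combine(cond, uncond, cond_scale):
        # cond_scale = 1.0 reproduces the conditional prediction exactly;
        # larger values extrapolate further along (cond - uncond).
        return uncond + (cond - uncond) * cond_scale

    print(cfg_combine(1.0, 0.2, 8.0))  # 0.2 + 0.8 * 8.0 = 6.6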
@@ -212,8 +212,8 @@ class CFGNoisePredictor(torch.nn.Module):
         super().__init__()
         self.inner_model = model
         self.alphas_cumprod = model.alphas_cumprod
-    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
-        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
+    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options={}):
+        out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat, model_options=model_options)
         return out
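One thing to note about the new signatures: model_options={} is a mutable default argument, shared across every call that omits the parameter. That is harmless here because the dict is only read and forwarded, never mutated, but the usual defensive idiom would look like this (a sketch, not a suggested change to the diff):

    def apply_model(x, model_options=None):
        # Create a fresh dict per call instead of sharing one default.
        model_options = {} if model_options is None else model_options
        return x, model_options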
@@ -221,11 +221,11 @@ class KSamplerX0Inpaint(torch.nn.Module):
     def __init__(self, model):
         super().__init__()
         self.inner_model = model
-    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None):
+    def forward(self, x, sigma, uncond, cond, cond_scale, denoise_mask, cond_concat=None, model_options={}):
         if denoise_mask is not None:
             latent_mask = 1. - denoise_mask
             x = x * denoise_mask + (self.latent_image + self.noise * sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1))) * latent_mask
-        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
+        out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat, model_options=model_options)
         if denoise_mask is not None:
             out *= denoise_mask
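The masking lines are the inpainting mechanism: where denoise_mask is 1 the sample keeps being denoised, and where it is 0 the original latent, re-noised to the current sigma, is written back on every step. Reduced to its essentials:

    import torch

    x = torch.randn(1, 4, 8, 8)             # latent currently being denoised
    frozen = torch.zeros(1, 4, 8, 8)         # stand-in for latent_image + noise * sigma
    mask = torch.ones(1, 4, 8, 8)
    mask[..., :4] = 0.0                      # freeze the left half
    x = x * mask + frozen * (1.0 - mask)     # denoise only where mask == 1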
@@ -333,7 +333,7 @@ class KSampler:
                 "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde",
                 "dpmpp_2m", "ddim", "uni_pc", "uni_pc_bh2"]
 
-    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
+    def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None, model_options={}):
         self.model = model
         self.model_denoise = CFGNoisePredictor(self.model)
         if self.model.parameterization == "v":
@@ -353,6 +353,7 @@ class KSampler:
         self.sigma_max=float(self.model_wrap.sigma_max)
         self.set_steps(steps, denoise)
         self.denoise = denoise
+        self.model_options = model_options
 
     def _calculate_sigmas(self, steps):
         sigmas = None
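With the options stored on the sampler, a patch configured at model-load time rides along into every denoising call. A hypothetical construction; the exact dict layout the Tome patch writes is an assumption here:

    # Illustrative only -- key names are assumed, not taken from this commit.
    sampler = KSampler(model, steps=20, device="cuda",
                       sampler="dpmpp_2m", scheduler="normal", denoise=1.0,
                       model_options={"transformer_options": {"tomesd": {"ratio": 0.5}}})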
@@ -421,7 +422,7 @@ class KSampler:
         else:
             precision_scope = contextlib.nullcontext
 
-        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg}
+        extra_args = {"cond":positive, "uncond":negative, "cond_scale": cfg, "model_options": self.model_options}
         cond_concat = None
         if hasattr(self.model, 'concat_keys'):
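extra_args is how the value finally reaches the model: k-diffusion samplers forward it as keyword arguments on every model call, which is what lands in KSamplerX0Inpaint.forward above. A simplified Euler-style loop showing the mechanism (illustrative, not the k-diffusion source):

    import torch

    def euler_sample(model_wrap, x, sigmas, extra_args):
        for i in range(len(sigmas) - 1):
            sigma = sigmas[i] * x.new_ones([x.shape[0]])
            denoised = model_wrap(x, sigma, **extra_args)    # model_options arrives here
            d = (x - denoised) / sigmas[i]                   # derivative estimate
            x = x + d * (sigmas[i + 1] - sigmas[i])          # step toward the next sigma
        return x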