Add ControlNet support.

This commit is contained in:
comfyanonymous
2023-02-16 10:38:08 -05:00
parent bc69fb5245
commit 4efa67fa12
9 changed files with 580 additions and 63 deletions

View File

@@ -21,12 +21,13 @@ class CFGDenoiser(torch.nn.Module):
uncond = self.inner_model(x, sigma, cond=uncond)
return uncond + (cond - uncond) * cond_scale
def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_concat=None):
def get_area_and_mult(cond, x_in, cond_concat_in):
#The main sampling function shared by all the samplers
#Returns predicted noise
def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, cond_concat=None):
def get_area_and_mult(cond, x_in, cond_concat_in, timestep_in):
area = (x_in.shape[2], x_in.shape[3], 0, 0)
strength = 1.0
min_sigma = 0.0
max_sigma = 999.0
if 'area' in cond[1]:
area = cond[1]['area']
if 'strength' in cond[1]:
@@ -56,9 +57,15 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
cr = x[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
cropped.append(cr)
conditionning['c_concat'] = torch.cat(cropped, dim=1)
return (input_x, mult, conditionning, area)
control = None
if 'control' in cond[1]:
control = cond[1]['control']
return (input_x, mult, conditionning, area, control)
def cond_equal_size(c1, c2):
if c1 is c2:
return True
if c1.keys() != c2.keys():
return False
if 'c_crossattn' in c1:
@@ -69,6 +76,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
return False
return True
def can_concat_cond(c1, c2):
if c1[0].shape != c2[0].shape:
return False
if (c1[4] is None) != (c2[4] is None):
return False
if c1[4] is not None:
if c1[4] is not c2[4]:
return False
return cond_equal_size(c1[2], c2[2])
def cond_cat(c_list):
c_crossattn = []
c_concat = []
@@ -84,7 +102,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
out['c_concat'] = [torch.cat(c_concat)]
return out
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, sigma, max_total_area, cond_concat_in):
def calc_cond_uncond_batch(model_function, cond, uncond, x_in, timestep, max_total_area, cond_concat_in):
out_cond = torch.zeros_like(x_in)
out_count = torch.ones_like(x_in)/100000.0
@@ -96,13 +114,13 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
to_run = []
for x in cond:
p = get_area_and_mult(x, x_in, cond_concat_in)
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
to_run += [(p, COND)]
for x in uncond:
p = get_area_and_mult(x, x_in, cond_concat_in)
p = get_area_and_mult(x, x_in, cond_concat_in, timestep)
if p is None:
continue
@@ -113,9 +131,8 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
first_shape = first[0][0].shape
to_batch_temp = []
for x in range(len(to_run)):
if to_run[x][0][0].shape == first_shape:
if cond_equal_size(to_run[x][0][2], first[0][2]):
to_batch_temp += [x]
if can_concat_cond(to_run[x][0], first[0]):
to_batch_temp += [x]
to_batch_temp.reverse()
to_batch = to_batch_temp[:1]
@@ -131,6 +148,7 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
c = []
cond_or_uncond = []
area = []
control = None
for x in to_batch:
o = to_run.pop(x)
p = o[0]
@@ -139,13 +157,17 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
c += [p[2]]
area += [p[3]]
cond_or_uncond += [o[1]]
control = p[4]
batch_chunks = len(cond_or_uncond)
input_x = torch.cat(input_x)
c = cond_cat(c)
sigma_ = torch.cat([sigma] * batch_chunks)
timestep_ = torch.cat([timestep] * batch_chunks)
output = model_function(input_x, sigma_, cond=c).chunk(batch_chunks)
if control is not None:
c['control'] = control.get_control(input_x, timestep_, c['c_crossattn'])
output = model_function(input_x, timestep_, cond=c).chunk(batch_chunks)
del input_x
for o in range(batch_chunks):
@@ -166,10 +188,29 @@ def sampling_function(model_function, x, sigma, uncond, cond, cond_scale, cond_c
max_total_area = model_management.maximum_batch_area()
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, sigma, max_total_area, cond_concat)
cond, uncond = calc_cond_uncond_batch(model_function, cond, uncond, x, timestep, max_total_area, cond_concat)
return uncond + (cond - uncond) * cond_scale
class CFGDenoiserComplex(torch.nn.Module):
class CompVisVDenoiser(k_diffusion_external.DiscreteVDDPMDenoiser):
def __init__(self, model, quantize=False, device='cpu'):
super().__init__(model, model.alphas_cumprod, quantize=quantize)
def get_v(self, x, t, cond, **kwargs):
return self.inner_model.apply_model(x, t, cond, **kwargs)
class CFGNoisePredictor(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
self.alphas_cumprod = model.alphas_cumprod
def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None):
out = sampling_function(self.inner_model.apply_model, x, timestep, uncond, cond, cond_scale, cond_concat)
return out
class KSamplerX0Inpaint(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.inner_model = model
@@ -177,7 +218,7 @@ class CFGDenoiserComplex(torch.nn.Module):
if denoise_mask is not None:
latent_mask = 1. - denoise_mask
x = x * denoise_mask + (self.latent_image + self.noise * sigma) * latent_mask
out = sampling_function(self.inner_model, x, sigma, uncond, cond, cond_scale, cond_concat)
out = self.inner_model(x, sigma, cond=cond, uncond=uncond, cond_scale=cond_scale, cond_concat=cond_concat)
if denoise_mask is not None:
out *= denoise_mask
@@ -196,8 +237,6 @@ def simple_scheduler(model, steps):
def blank_inpaint_image_like(latent_image):
blank_image = torch.ones_like(latent_image)
# these are the values for "zero" in pixel space translated to latent space
# the proper way to do this is to apply the mask to the image in pixel space and then send it through the VAE
# unfortunately that gives zero flexibility so I did things like this instead which hopefully works
blank_image[:,0] *= 0.8223
blank_image[:,1] *= -0.6876
blank_image[:,2] *= 0.6364
@@ -234,6 +273,42 @@ def create_cond_with_same_area_if_none(conds, c):
n = c[1].copy()
conds += [[smallest[0], n]]
def apply_control_net_to_equal_area(conds, uncond):
cond_cnets = []
cond_other = []
uncond_cnets = []
uncond_other = []
for t in range(len(conds)):
x = conds[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
cond_cnets.append(x[1]['control'])
else:
cond_other.append((x, t))
for t in range(len(uncond)):
x = uncond[t]
if 'area' not in x[1]:
if 'control' in x[1] and x[1]['control'] is not None:
uncond_cnets.append(x[1]['control'])
else:
uncond_other.append((x, t))
if len(uncond_cnets) > 0:
return
for x in range(len(cond_cnets)):
temp = uncond_other[x % len(uncond_other)]
o = temp[0]
if 'control' in o[1] and o[1]['control'] is not None:
n = o[1].copy()
n['control'] = cond_cnets[x]
uncond += [[o[0], n]]
else:
n = o[1].copy()
n['control'] = cond_cnets[x]
uncond[temp[1]] = [o[0], n]
class KSampler:
SCHEDULERS = ["karras", "normal", "simple"]
SAMPLERS = ["sample_euler", "sample_euler_ancestral", "sample_heun", "sample_dpm_2", "sample_dpm_2_ancestral",
@@ -242,11 +317,13 @@ class KSampler:
def __init__(self, model, steps, device, sampler=None, scheduler=None, denoise=None):
self.model = model
self.model_denoise = CFGNoisePredictor(self.model)
if self.model.parameterization == "v":
self.model_wrap = k_diffusion_external.CompVisVDenoiser(self.model, quantize=True)
self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
else:
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model, quantize=True)
self.model_k = CFGDenoiserComplex(self.model_wrap)
self.model_wrap = k_diffusion_external.CompVisDenoiser(self.model_denoise, quantize=True)
self.model_wrap.parameterization = self.model.parameterization
self.model_k = KSamplerX0Inpaint(self.model_wrap)
self.device = device
if scheduler not in self.SCHEDULERS:
scheduler = self.SCHEDULERS[0]
@@ -316,6 +393,8 @@ class KSampler:
for c in negative:
create_cond_with_same_area_if_none(positive, c)
apply_control_net_to_equal_area(positive, negative)
if self.model.model.diffusion_model.dtype == torch.float16:
precision_scope = torch.autocast
else: