Add a CONDConstant for passing non tensor conds to unet.

This commit is contained in:
comfyanonymous
2023-11-08 01:59:09 -05:00
parent 794dd2064d
commit 064d7583eb
2 changed files with 19 additions and 1 deletions

View File

@@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
out.append(c)
return torch.cat(out)
class CONDConstant(CONDRegular):
def __init__(self, cond):
self.cond = cond
def process_cond(self, batch_size, device, **kwargs):
return self._copy_with(self.cond)
def can_concat(self, other):
if self.cond != other.cond:
return False
return True
def concat(self, others):
return self.cond