Add a CONDConstant for passing non tensor conds to unet.
This commit is contained in:
@@ -62,3 +62,18 @@ class CONDCrossAttn(CONDRegular):
|
||||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
|
||||
out.append(c)
|
||||
return torch.cat(out)
|
||||
|
||||
class CONDConstant(CONDRegular):
|
||||
def __init__(self, cond):
|
||||
self.cond = cond
|
||||
|
||||
def process_cond(self, batch_size, device, **kwargs):
|
||||
return self._copy_with(self.cond)
|
||||
|
||||
def can_concat(self, other):
|
||||
if self.cond != other.cond:
|
||||
return False
|
||||
return True
|
||||
|
||||
def concat(self, others):
|
||||
return self.cond
|
||||
|
||||
Reference in New Issue
Block a user