It doesn't make sense for c_crossattn and c_concat to be lists.
This commit is contained in:
@@ -165,9 +165,9 @@ def sampling_function(model_function, x, timestep, uncond, cond, cond_scale, con
|
||||
c_crossattn_out.append(c)
|
||||
|
||||
if len(c_crossattn_out) > 0:
|
||||
out['c_crossattn'] = [torch.cat(c_crossattn_out)]
|
||||
out['c_crossattn'] = torch.cat(c_crossattn_out)
|
||||
if len(c_concat) > 0:
|
||||
out['c_concat'] = [torch.cat(c_concat)]
|
||||
out['c_concat'] = torch.cat(c_concat)
|
||||
if len(c_adm) > 0:
|
||||
out['c_adm'] = torch.cat(c_adm)
|
||||
return out
|
||||
|
||||
Reference in New Issue
Block a user