Update ldm dir with latest upstream stable diffusion changes.
This commit is contained in:
@@ -11,16 +11,17 @@ MODEL_TYPES = {
|
||||
|
||||
|
||||
class DPMSolverSampler(object):
|
||||
def __init__(self, model, **kwargs):
|
||||
def __init__(self, model, device=torch.device("cuda"), **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.device = device
|
||||
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
|
||||
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device("cuda"):
|
||||
attr = attr.to(torch.device("cuda"))
|
||||
if attr.device != self.device:
|
||||
attr = attr.to(self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
Reference in New Issue
Block a user