Update ldm dir with latest upstream stable diffusion changes.

This commit is contained in:
comfyanonymous
2023-02-09 13:47:36 -05:00
parent 642516a3a6
commit 1f6a467e92
5 changed files with 21 additions and 10 deletions

View File

@@ -11,16 +11,17 @@ MODEL_TYPES = {
class DPMSolverSampler(object):
def __init__(self, model, **kwargs):
def __init__(self, model, device=torch.device("cuda"), **kwargs):
super().__init__()
self.model = model
self.device = device
to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device("cuda"):
attr = attr.to(torch.device("cuda"))
if attr.device != self.device:
attr = attr.to(self.device)
setattr(self, name, attr)
@torch.no_grad()