UNET weights can now be stored in fp8.

--fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats
supported by pytorch.
This commit is contained in:
comfyanonymous
2023-12-04 11:10:00 -05:00
parent af365e4dd1
commit 31b0f6f3d8
6 changed files with 47 additions and 10 deletions

View File

@@ -5,6 +5,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
import comfy.model_management
import comfy.conds
from enum import Enum
import contextlib
from . import utils
class ModelType(Enum):
@@ -61,6 +62,13 @@ class BaseModel(torch.nn.Module):
context = c_crossattn
dtype = self.get_dtype()
if comfy.model_management.supports_dtype(xc.device, dtype):
precision_scope = lambda a: contextlib.nullcontext(a)
else:
precision_scope = torch.autocast
dtype = torch.float32
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
context = context.to(dtype)
@@ -70,7 +78,10 @@ class BaseModel(torch.nn.Module):
if hasattr(extra, "to"):
extra = extra.to(dtype)
extra_conds[o] = extra
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
def get_dtype(self):