UNET weights can now be stored in fp8.
--fp8_e4m3fn-unet and --fp8_e5m2-unet are the two different formats supported by pytorch.
This commit is contained in:
@@ -5,6 +5,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
|
||||
import comfy.model_management
|
||||
import comfy.conds
|
||||
from enum import Enum
|
||||
import contextlib
|
||||
from . import utils
|
||||
|
||||
class ModelType(Enum):
|
||||
@@ -61,6 +62,13 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
context = c_crossattn
|
||||
dtype = self.get_dtype()
|
||||
|
||||
if comfy.model_management.supports_dtype(xc.device, dtype):
|
||||
precision_scope = lambda a: contextlib.nullcontext(a)
|
||||
else:
|
||||
precision_scope = torch.autocast
|
||||
dtype = torch.float32
|
||||
|
||||
xc = xc.to(dtype)
|
||||
t = self.model_sampling.timestep(t).float()
|
||||
context = context.to(dtype)
|
||||
@@ -70,7 +78,10 @@ class BaseModel(torch.nn.Module):
|
||||
if hasattr(extra, "to"):
|
||||
extra = extra.to(dtype)
|
||||
extra_conds[o] = extra
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
|
||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
|
||||
|
||||
return self.model_sampling.calculate_denoised(sigma, model_output, x)
|
||||
|
||||
def get_dtype(self):
|
||||
|
||||
Reference in New Issue
Block a user