Unify RMSNorm code.

This commit is contained in:
comfyanonymous
2024-08-28 16:18:39 -04:00
parent b79fd7d92c
commit d31e226650
3 changed files with 17 additions and 24 deletions

View File

@@ -6,6 +6,7 @@ from torch import Tensor, nn
from .math import attention, rope
import comfy.ops
import comfy.ldm.common_dit
class EmbedND(nn.Module):
@@ -63,8 +64,7 @@ class RMSNorm(torch.nn.Module):
self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
def forward(self, x: Tensor):
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
class QKNorm(torch.nn.Module):