# Source code for hfai.nn.modules.hf_norm

import torch
import torch.nn as nn
import time
from torch import Tensor, Size
from typing import Union, List, Tuple
import numbers

# Optional fused CUDA kernels. Each import is best-effort so this module still
# loads (and the non-CUDA fallbacks in LayerNorm / RMSNorm keep working) when a
# compiled extension is unavailable; the flags record whether the import failed.
no_layernorm = False
try:
    import hfai.hfcuda.layernorm as layernorm
except (ImportError, OSError):
    no_layernorm = True

# Initialised explicitly so the flag is defined even when the import succeeds.
no_rmsnorm = False
try:
    import hfai.hfcuda.rmsnorm as rmsnorm
except (ImportError, OSError):
    no_rmsnorm = True

class LayerNormFunc(torch.autograd.Function):
    """Autograd wrapper around the fused ``hfai.hfcuda.layernorm`` kernels.

    The input is flattened to ``(-1, hidden_size)``, where ``hidden_size`` is the
    product of ``normalized_shape``, and each row is normalized by the kernel.
    With ``elementwise_affine`` the kernel also applies ``gamma``/``beta``;
    otherwise the ``*_without_gammabeta`` kernels are used.
    """

    @staticmethod
    def forward(ctx, x, normalized_shape, gamma, beta, eps, elementwise_affine, training):
        """Normalize ``x`` over its trailing ``normalized_shape`` elements.

        Args:
            ctx: autograd context.
            x: float32 input tensor (enforced by the assert below). Its trailing
                dims are presumably ``normalized_shape`` -- TODO confirm; the
                ``view`` below only checks that the element count divides.
            normalized_shape: iterable of ints whose product is the row length.
            gamma: scale tensor, used only when ``elementwise_affine`` is true.
            beta: shift tensor, used only when ``elementwise_affine`` is true.
            eps: forwarded unchanged to the kernel.
            elementwise_affine: whether to apply ``gamma``/``beta``.
            training: if true, the tensors needed by ``backward`` are saved.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            AssertionError: if ``x`` is not float32.
        """
        assert x.dtype == torch.float32, \
            f"hfai LayerNorm kernel only supports float32 input, got {x.dtype}"

        shapex = x.shape
        index = x.device.index
        x = x.contiguous()

        if elementwise_affine:
            gamma = gamma.contiguous()
            beta = beta.contiguous()

        hidden_size = 1
        for size in normalized_shape:
            hidden_size *= size

        x2d = x.view((-1, hidden_size))
        if elementwise_affine:
            y, x_mean, x_var = layernorm.forward(x2d, gamma, beta, eps, index)
        else:
            y, x_mean, x_var = layernorm.forward_without_gammabeta(x2d, eps, index)

        # Besides the caller's `training` flag, also save whenever an input
        # requires grad: otherwise a backward through this node (e.g. module in
        # eval mode with grad enabled) would find no saved state and crash.
        if training or any(ctx.needs_input_grad):
            ctx.hidden_size = hidden_size
            ctx.elementwise_affine = elementwise_affine
            ctx.normalized_shape = normalized_shape
            # gamma may be None here; save_for_backward accepts None entries.
            ctx.save_for_backward(x, x_mean, x_var, gamma)

        return y.view(shapex)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, dy):
        """Return gradients for every ``forward`` argument (None for non-tensors).

        Marked ``once_differentiable`` because the kernels operate outside the
        autograd graph, so a double backward could not be correct.
        """
        hidden_size = ctx.hidden_size
        x, x_mean, x_var, gamma = ctx.saved_tensors

        index = dy.device.index
        dy2d = dy.contiguous().view((-1, hidden_size))
        x2d = x.view((-1, hidden_size))

        if ctx.elementwise_affine:
            dxmat, dgamma, dbeta = layernorm.backward(dy2d, x2d, x_mean, x_var, gamma, index)
            dgamma = dgamma.view(ctx.normalized_shape)
            dbeta = dbeta.view(ctx.normalized_shape)
        else:
            dxmat = layernorm.backward_without_gammabeta(dy2d, x2d, x_mean, x_var, index)
            dgamma = dbeta = None

        # Order mirrors forward(x, normalized_shape, gamma, beta, eps,
        # elementwise_affine, training).
        return dxmat.view(dy.shape), None, dgamma, dbeta, None, None, None

_shape_t = Union[int, List[int], Size]

class LayerNorm(nn.LayerNorm):
    """A more efficient LayerNorm operator.

    Its interface is identical to `PyTorch's LayerNorm
    <https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html>`_.
    """

    def forward(self, input):
        # The fused kernel (LayerNormFunc) asserts float32 input and calls
        # ``beta.contiguous()`` when affine, so anything it cannot handle --
        # non-CUDA tensors, half/bfloat16 (e.g. under AMP), or an affine layer
        # built without a bias -- goes through the native implementation.
        use_native = (
            not input.is_cuda
            or input.dtype != torch.float32
            or (self.elementwise_affine and self.bias is None)
        )
        if use_native:
            return super().forward(input)
        return LayerNormFunc.apply(
            input,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
            self.elementwise_affine,
            torch.is_grad_enabled() and self.training,
        )
# Reference implementation from Huggingface
def manual_rms_norm(input, normalized_shape, weight, eps):
    """Pure-PyTorch RMSNorm, used as the fallback for non-CUDA tensors.

    Divides ``input`` by the root mean square over its last
    ``len(normalized_shape)`` dims (no mean subtraction), then scales by
    ``weight`` if given.
    """
    # layer norm should always be calculated in float32
    dims = tuple(range(-1, -len(normalized_shape) - 1, -1))
    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
    input = input * torch.rsqrt(variance + eps)

    if weight is None:
        return input

    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(weight.dtype)

    return weight * input


class RMSNormFunc(torch.autograd.Function):
    """Autograd wrapper around the fused ``hfai.hfcuda.rmsnorm`` kernels."""

    @staticmethod
    def forward(ctx, input, normalized_shape, weight, eps, elementwise_affine, is_training):
        hidden_size = 1
        for size in normalized_shape:
            hidden_size *= size

        # Without affine parameters the kernel is still handed an (empty) tensor.
        weight = weight.contiguous() if elementwise_affine else torch.tensor([])
        input = input.contiguous()
        output, invvar = rmsnorm.forward(input.view(-1, hidden_size), weight, eps, hidden_size)

        # Also save when an input requires grad, so backward never runs without
        # saved state (e.g. module in eval mode with grad enabled).
        if is_training or any(ctx.needs_input_grad):
            ctx.save_for_backward(input, weight, invvar)
            ctx.hidden_size = hidden_size
            ctx.eps = eps
            ctx.elementwise_affine = elementwise_affine
        return output.view(input.shape)

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        input, weight, invvar = ctx.saved_tensors
        hidden_size = ctx.hidden_size
        grad_output = grad_output.contiguous()
        grad_input, grad_weight = rmsnorm.backward(
            grad_output.view(-1, hidden_size),
            input.view(-1, hidden_size),
            weight,
            invvar,
            ctx.eps,
            hidden_size,
        )
        # Without affine parameters the forward weight argument is not a
        # parameter (None from RMSNorm); autograd expects None for it.
        if not ctx.elementwise_affine:
            grad_weight = None
        return grad_input.view(grad_output.shape), None, grad_weight, None, None, None


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the trailing ``normalized_shape`` dims.

    CUDA inputs use the fused kernel; other inputs use ``manual_rms_norm``.
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.ones(*normalized_shape))
        else:
            self.register_parameter("weight", None)

    def forward(self, input):
        if not input.is_cuda:
            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
        return RMSNormFunc.apply(
            input,
            self.normalized_shape,
            self.weight,
            self.eps,
            self.elementwise_affine,
            torch.is_grad_enabled() and self.training,
        )

    def extra_repr(self):
        # Same format as torch.nn.LayerNorm.extra_repr.
        return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)