
Source code for hfai.nn.modules.hf_layernorm

import torch
import torch.nn as nn
import time
from torch import Tensor, Size
from typing import Union, List, Tuple
import numbers

# True when the fused CUDA kernel cannot be imported and the native
# PyTorch implementation must be used instead
no_layernorm = False

try:
    import hfai.hfcuda.layernorm as layernorm
except ImportError:
    no_layernorm = True


class LayerNormFunc(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x, normalized_shape, gamma, beta, eps, elementwise_affine, training):

        assert x.dtype == torch.float32, \
            'hfai.nn.LayerNorm currently only supports float32'

        shapex = x.shape
        index = x.device.index
        x = x.contiguous()

        if elementwise_affine:
            gamma = gamma.contiguous()
            beta = beta.contiguous()

        # total number of elements normalized per sample
        hidden_size = 1
        for size in normalized_shape:
            hidden_size *= size

        # the fused kernel returns the normalized output together with the
        # per-row mean and variance needed by the backward pass
        if elementwise_affine:
            y, x_mean, x_var = layernorm.forward(x.view((-1, hidden_size)), gamma, beta, eps, index)
        else:
            y, x_mean, x_var = layernorm.forward_without_gammabeta(x.view((-1, hidden_size)), eps, index)

        if training:
            # stash everything the backward pass needs
            ctx.hidden_size = hidden_size
            ctx.elementwise_affine = elementwise_affine
            ctx.normalized_shape = normalized_shape
            ctx.save_for_backward(x, x_mean, x_var, gamma)

        return y.view(shapex)
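
    # For reference, standard layer normalization (which the fused kernel above is
    # assumed to implement) computes, per row of the (-1, hidden_size) view:
    #     mean = x.mean(-1, keepdim=True)
    #     var  = x.var(-1, unbiased=False, keepdim=True)
    #     y    = (x - mean) / torch.sqrt(var + eps) * gamma + beta
    # where gamma and beta are applied only when elementwise_affine is True.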

    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch.autograd.function.once_differentiable
    def backward(ctx, dy):
        hidden_size = ctx.hidden_size
        elementwise_affine = ctx.elementwise_affine
        normalized_shape = ctx.normalized_shape
        x, x_mean, x_var, gamma = ctx.saved_tensors

        index = dy.device.index
        dy = dy.contiguous()
        if elementwise_affine:
            dxmat, dgamma, dbeta = layernorm.backward(dy.view((-1, hidden_size)), x.view((-1, hidden_size)),
                                                      x_mean, x_var, gamma, index)
            dgamma = dgamma.view(normalized_shape)
            dbeta = dbeta.view(normalized_shape)
        else:
            dxmat = layernorm.backward_without_gammabeta(dy.view((-1, hidden_size)), x.view((-1, hidden_size)),
                                                         x_mean, x_var, index)[0]

        dx = dxmat.view(dy.shape)

        # one gradient per forward() argument; the non-tensor arguments
        # (normalized_shape, eps, elementwise_affine, training) get None
        if elementwise_affine:
            return dx, None, dgamma, dbeta, None, None, None
        else:
            return dx, None, None, None, None, None, None


_shape_t = Union[int, List[int], Size]


class LayerNorm(nn.LayerNorm):
    """
    A more efficient LayerNorm operator.

    The interface is identical to `PyTorch's LayerNorm
    <https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?highlight=layernorm#torch.nn.LayerNorm>`_.
    """

    def forward(self, input):
        # fall back to the native implementation on CPU tensors or when the
        # fused CUDA kernel could not be imported
        if no_layernorm or not input.is_cuda:
            return super().forward(input)
        return LayerNormFunc.apply(input, self.normalized_shape, self.weight, self.bias, self.eps,
                                   self.elementwise_affine, torch.is_grad_enabled() and self.training)
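
A minimal usage sketch (not part of the module source above), assuming the class is exported as hfai.nn.LayerNorm as its docstring indicates; the shapes and tensor values are purely illustrative:

import torch
import hfai.nn as hfnn

# drop-in replacement for torch.nn.LayerNorm; falls back to the native
# implementation on CPU tensors
layer = hfnn.LayerNorm(512).cuda()
x = torch.randn(8, 16, 512, device='cuda', dtype=torch.float32, requires_grad=True)

y = layer(x)        # fused CUDA kernel (float32 only)
y.sum().backward()  # gradients flow through LayerNormFunc.backward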