Source code for hfai.nn.modules.hf_layernorm
import torch
import torch.nn as nn
import time
from torch import Tensor, Size
from typing import Union, List, Tuple
import numbers

no_layernorm = False
try:
    # Fused CUDA layernorm extension; fall back gracefully if it is not built.
    import hfai.hfcuda.layernorm as layernorm
except Exception:
    no_layernorm = True

class LayerNormFunc(torch.autograd.Function):

    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x, normalized_shape, gamma, beta, eps, elementwise_affine, training):
        assert x.dtype == torch.float32, \
            'hfai.nn.LayerNorm currently only supports float32'
        shapex = x.shape
        index = x.device.index
        x = x.contiguous()
        if elementwise_affine is True:
            gamma = gamma.contiguous()
            beta = beta.contiguous()

        # Collapse all normalized dimensions into a single hidden dimension.
        hidden_size = 1
        for size in normalized_shape:
            hidden_size = hidden_size * size

        if elementwise_affine is True:
            y, x_mean, x_var = layernorm.forward(x.view((-1, hidden_size)), gamma, beta, eps, index)
        else:
            y, x_mean, x_var = layernorm.forward_without_gammabeta(x.view((-1, hidden_size)), eps, index)

        if training:
            # Save what backward() needs to compute the gradients.
            ctx.hidden_size = hidden_size
            ctx.elementwise_affine = elementwise_affine
            ctx.normalized_shape = normalized_shape
            ctx.save_for_backward(x, x_mean, x_var, gamma)
        return y.view(shapex)
    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch.autograd.function.once_differentiable
    def backward(ctx, dy):
        hidden_size = ctx.hidden_size
        elementwise_affine = ctx.elementwise_affine
        normalized_shape = ctx.normalized_shape
        x, x_mean, x_var, gamma = ctx.saved_tensors
        index = dy.device.index
        dy = dy.contiguous()

        if elementwise_affine is True:
            dxmat, dgamma, dbeta = layernorm.backward(dy.view((-1, hidden_size)), x.view((-1, hidden_size)),
                                                      x_mean, x_var, gamma, index)
            dgamma = dgamma.view(normalized_shape)
            dbeta = dbeta.view(normalized_shape)
        else:
            dxmat = layernorm.backward_without_gammabeta(dy.view((-1, hidden_size)), x.view((-1, hidden_size)),
                                                         x_mean, x_var, index)[0]

        dx = dxmat.view(dy.shape)
        # One gradient per forward() argument; non-tensor arguments receive None.
        if elementwise_affine is True:
            return dx, None, dgamma, dbeta, None, None, None
        else:
            return dx, None, None, None, None, None, None
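
# Illustrative sketch: a minimal way LayerNormFunc could be validated against
# PyTorch's reference torch.nn.functional.layer_norm. It assumes the
# hfai.hfcuda.layernorm extension is built and a CUDA device is available;
# the helper name `_check_layernorm_against_reference` is hypothetical and
# the shapes/eps below are arbitrary.
def _check_layernorm_against_reference():
    import torch.nn.functional as F

    normalized_shape = (768,)
    x = torch.randn(8, 128, 768, device='cuda')
    gamma = torch.ones(normalized_shape, device='cuda')
    beta = torch.zeros(normalized_shape, device='cuda')

    # Forward-only check (training=False), with elementwise affine parameters.
    y = LayerNormFunc.apply(x, normalized_shape, gamma, beta, 1e-5, True, False)
    y_ref = F.layer_norm(x, normalized_shape, gamma, beta, eps=1e-5)
    return torch.allclose(y, y_ref, atol=1e-5)
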
_shape_t = Union[int, List[int], Size]
class LayerNorm(nn.LayerNorm):
    """
    A more efficient LayerNorm operator.

    The interface is identical to `PyTorch's LayerNorm operator
    <https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?highlight=layernorm#torch.nn.LayerNorm>`_.
    """

    def forward(self, input):
        if not input.is_cuda:
            # CPU tensors fall back to the native torch.nn.LayerNorm implementation.
            return super().forward(input)
        return LayerNormFunc.apply(input, self.normalized_shape, self.weight, self.bias, self.eps,
                                   self.elementwise_affine,
                                   torch.is_grad_enabled() and self.training)
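
# Illustrative usage sketch: per the class docstring, hfai.nn.LayerNorm is a
# drop-in replacement for torch.nn.LayerNorm (same constructor arguments, same
# forward() behaviour). The example below assumes a CUDA device and the
# hfai.hfcuda.layernorm extension are available; CPU inputs simply use
# nn.LayerNorm's own forward.
if __name__ == '__main__':
    layer = LayerNorm(768).cuda()
    x = torch.randn(4, 128, 768, device='cuda')
    y = layer(x)                        # fused CUDA kernel path
    y_cpu = LayerNorm(768)(x.cpu())     # CPU path: falls back to nn.LayerNorm
    print(y.shape, y_cpu.shape)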