import torch
import torch.nn as nn
from torch import Tensor, Size
from typing import Union, List, Tuple
import numbers
no_layernorm = False
no_rmsnorm = False
try:
    import hfai.hfcuda.layernorm as layernorm
except ImportError:
    no_layernorm = True
try:
    import hfai.hfcuda.rmsnorm as rmsnorm
except ImportError:
    no_rmsnorm = True
class LayerNormFunc(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, x, normalized_shape, gamma, beta, eps, elementwise_affine, training):
        assert x.dtype == torch.float32, \
            'hfai.nn.LayerNorm currently only supports float32'
shapex = x.shape
index = x.device.index
x = x.contiguous()
        if elementwise_affine:
gamma = gamma.contiguous()
beta = beta.contiguous()
hidden_size = 1
for size in normalized_shape:
hidden_size = hidden_size * size
        if elementwise_affine:
y, x_mean, x_var = layernorm.forward(x.view((-1, hidden_size)), gamma, beta, eps, index)
else:
y, x_mean, x_var = layernorm.forward_without_gammabeta(x.view((-1, hidden_size)), eps, index)
if training:
ctx.hidden_size = hidden_size
ctx.elementwise_affine = elementwise_affine
ctx.normalized_shape = normalized_shape
ctx.save_for_backward(x, x_mean, x_var, gamma)
return y.view(shapex)
@staticmethod
@torch.cuda.amp.custom_bwd
@torch.autograd.function.once_differentiable
def backward(ctx, dy):
hidden_size = ctx.hidden_size
elementwise_affine = ctx.elementwise_affine
normalized_shape = ctx.normalized_shape
x, x_mean, x_var, gamma = ctx.saved_tensors
index = dy.device.index
dy = dy.contiguous()
        if elementwise_affine:
dxmat, dgamma, dbeta = layernorm.backward(dy.view((-1, hidden_size)), x.view((-1, hidden_size)), x_mean,
x_var, gamma, index)
dgamma = dgamma.view(normalized_shape)
dbeta = dbeta.view(normalized_shape)
else:
            dxmat = layernorm.backward_without_gammabeta(
                dy.view((-1, hidden_size)), x.view((-1, hidden_size)), x_mean, x_var, index)[0]
dx = dxmat.view(dy.shape)
        if elementwise_affine:
return dx, None, dgamma, dbeta, None, None, None
else:
return dx, None, None, None, None, None, None
_shape_t = Union[int, List[int], Size]
class LayerNorm(nn.LayerNorm):
    """
    A more efficient LayerNorm operator.

    The interface is identical to `PyTorch's LayerNorm operator <https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html?highlight=layernorm#torch.nn.LayerNorm>`_.
    """
    def forward(self, input):
        # fall back to the native PyTorch implementation on CPU tensors
        # or when the hfai CUDA extension is unavailable
        if no_layernorm or not input.is_cuda:
            return super().forward(input)
return LayerNormFunc.apply(input, self.normalized_shape, self.weight, self.bias, self.eps,
self.elementwise_affine,
torch.is_grad_enabled() and self.training)
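
# Illustrative usage sketch (not part of the original module): compare the fused
# LayerNorm against torch.nn.LayerNorm on a random CUDA tensor, forward and backward.
# The helper name `_check_layernorm_against_torch` and the tolerances are assumptions
# for demonstration; it requires a GPU and the hfai.hfcuda.layernorm extension.
def _check_layernorm_against_torch(batch=8, hidden=1024, eps=1e-5):
    x = torch.randn(batch, hidden, device='cuda', requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)

    fused = LayerNorm(hidden, eps=eps).cuda()
    ref = nn.LayerNorm(hidden, eps=eps).cuda()
    ref.load_state_dict(fused.state_dict())   # same gamma/beta in both modules

    y = fused(x)
    y_ref = ref(x_ref)
    assert torch.allclose(y, y_ref, rtol=1e-4, atol=1e-4)

    y.sum().backward()
    y_ref.sum().backward()
    assert torch.allclose(x.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
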
# Reference implementation from Huggingface
def manual_rms_norm(input, normalized_shape, weight, eps):
    # the norm statistics should always be calculated in float32
dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
input = input * torch.rsqrt(variance + eps)
if weight is None:
return input
# convert into half-precision if necessary
if weight.dtype in [torch.float16, torch.bfloat16]:
input = input.to(weight.dtype)
return weight * input
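
# A minimal CPU-only sketch of what manual_rms_norm computes over the last dimension:
# y = weight * x / sqrt(mean(x ** 2) + eps). The helper name `_rms_norm_reference_check`
# is an assumption used only for demonstration.
def _rms_norm_reference_check(hidden=16, eps=1e-5):
    x = torch.randn(4, hidden)
    weight = torch.ones(hidden)
    rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    expected = weight * (x * rms)
    assert torch.allclose(manual_rms_norm(x, (hidden,), weight, eps), expected)
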
class RMSNormFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, normalized_shape, weight, eps, elementwise_affine, is_training):
hidden_size = 1
for size in normalized_shape:
hidden_size *= size
        if elementwise_affine:
            weight = weight.contiguous()
        else:
            # pass an empty placeholder tensor to the kernel when there is no affine weight
            weight = torch.tensor([])
input = input.contiguous()
output, invvar = rmsnorm.forward(input.view(-1, hidden_size), weight, eps, hidden_size)
if is_training:
ctx.save_for_backward(input, weight, invvar)
ctx.hidden_size = hidden_size
ctx.eps = eps
return output.view(input.shape)
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
hidden_size = ctx.hidden_size
eps = ctx.eps
input, weight, invvar = ctx.saved_tensors
grad_output = grad_output.contiguous()
grad_input, grad_weight = rmsnorm.backward(grad_output.view(-1, hidden_size), input.view(-1, hidden_size),
weight, invvar, eps, hidden_size)
grad_input = grad_input.view(grad_output.shape)
return grad_input, None, grad_weight, None, None, None
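
# Illustrative sketch (requires a GPU and the hfai.hfcuda.rmsnorm extension): the fused
# RMSNormFunc can be cross-checked against the autograd gradients of manual_rms_norm.
# The helper name `_check_rmsnorm_grad` and the tolerances are assumptions for demonstration.
def _check_rmsnorm_grad(hidden=256, eps=1e-5):
    x = torch.randn(4, hidden, device='cuda', requires_grad=True)
    x_ref = x.detach().clone().requires_grad_(True)
    w = torch.ones(hidden, device='cuda', requires_grad=True)
    w_ref = w.detach().clone().requires_grad_(True)

    # forward args: input, normalized_shape, weight, eps, elementwise_affine, is_training
    y = RMSNormFunc.apply(x, (hidden,), w, eps, True, True)
    y_ref = manual_rms_norm(x_ref, (hidden,), w_ref, eps)
    assert torch.allclose(y, y_ref, rtol=1e-4, atol=1e-4)

    y.sum().backward()
    y_ref.sum().backward()
    assert torch.allclose(x.grad, x_ref.grad, rtol=1e-4, atol=1e-4)
    assert torch.allclose(w.grad, w_ref.grad, rtol=1e-4, atol=1e-4)
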
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Uses the fused hfai.hfcuda.rmsnorm kernel on CUDA tensors and falls back to
    the pure-PyTorch reference implementation ``manual_rms_norm`` otherwise.
    """
def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.ones(*normalized_shape))
else:
self.register_parameter("weight", None)
    def forward(self, input):
        # fall back to the reference implementation on CPU tensors
        # or when the hfai CUDA extension is unavailable
        if no_rmsnorm or not input.is_cuda:
            return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
return RMSNormFunc.apply(input, self.normalized_shape, self.weight, self.eps, self.elementwise_affine,
torch.is_grad_enabled() and self.training)
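
# Illustrative usage sketch (not part of the original module): on CPU tensors RMSNorm
# dispatches to the manual_rms_norm reference path, while CUDA tensors go through the
# fused hfai.hfcuda.rmsnorm kernel when it is available. The helper name
# `_example_rmsnorm` is an assumption used only for demonstration.
def _example_rmsnorm(hidden=512):
    norm = RMSNorm(hidden)
    x = torch.randn(2, 10, hidden)      # CPU tensor -> reference implementation
    y = norm(x)
    if torch.cuda.is_available():
        y = norm.cuda()(x.cuda())       # CUDA tensor -> fused kernel (if compiled)
    return y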