from typing import Tuple, Union, Optional
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_, kaiming_uniform_, _calculate_fan_in_and_fan_out, \
uniform_
import math
from .context import context
import torch.nn.functional as F
no_attention = False
no_attention_half = False
no_attention_compact = False
no_flash_attn_fp32 = False
no_flash_attn_tf32 = False
no_fused_mha = False
no_qkv_inference = False
try:
import hfai.hfcuda.nn_attention as attention
except:
no_attention = True
try:
import hfai.hfcuda.nn_attention_half as attention_half
except:
no_attention_half = True
try:
import hfai.hfcuda.nn_attention_compact as attention_compact
except:
no_attention_compact = True
try:
import hfai.hfcuda.flash_attention_fp32 as flash_attn_fp32
except:
no_flash_attn_fp32 = True
try:
import hfai.hfcuda.flash_attention_tf32 as flash_attn_tf32
except:
no_flash_attn_tf32 = True
try:
import hfai.hfcuda.fused_mha as fused_mha
except:
no_fused_mha = True
try:
import hfai.hfcuda.qkv_inference as qkv_infer
except:
no_qkv_inference = True
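# The hfai CUDA extensions above are optional; each no_* flag records that the corresponding
# kernel failed to import. no_attention and no_attention_half are checked when constructing
# hfai.nn.MultiheadAttention below, the remaining flags guard the optional fast paths.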
def get_attention():
if context.GetAttnAllowConversion():
return attention_half
else:
return attention
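# MultiheadAttentionFunc picks one of several hfai CUDA kernels at runtime and records the
# choice in ctx.attn_method so that backward() can call the matching backward kernel:
#   0 / 1: nn_attention (or nn_attention_half via get_attention()) -- same-QKV / separate QKV
#   2:     nn_attention_compact -- low-memory mode (compact=True)
#   3 / 4: flash_attention_fp32 / flash_attention_tf32 -- same-QKV / separate QKV
#   5 / 6: fused_mha -- same-QKV / separate QKV fast path (seq_len == 64)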
class MultiheadAttentionFunc(torch.autograd.Function):
@staticmethod
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def forward(ctx, query, key, value, key_padding_mask1, need_weights, attn_mask1,
embed_dim, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
add_zero_attn, dropout, out_proj_weight, out_proj_bias, _qkv_same_embed_dim, batch_first,
need_softmax_temperature,
scaleT, compact,
training):
        assert bias_k is None, \
            'hfai.nn.MultiheadAttention does not support bias_k'
        assert bias_v is None, \
            'hfai.nn.MultiheadAttention does not support bias_v'
        assert add_zero_attn == False, \
            'hfai.nn.MultiheadAttention does not support add_zero_attn'
        assert query.dtype == torch.float32, \
            'hfai.nn.MultiheadAttention currently only supports float32'
same_qkv = False
if batch_first:
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
same_qkv = True
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
else:
if key is value and query is key:
same_qkv = True
seq_len, batch_size, embed_dim = query.size()
seq_len_k, _, _ = key.size()
attn_mask = attn_mask1
key_padding_mask = key_padding_mask1
if attn_mask1 is None:
attn_mask1 = torch.tensor([])
if key_padding_mask1 is None:
key_padding_mask1 = torch.tensor([])
if _qkv_same_embed_dim:
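            # Fused-MHA fast path: only when the kernel is available, seq_len == 64, head_dim is
            # 64/96/128, TF32 is enabled, attention weights are not requested, and any masks are float32.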
can_use_fused = True
            if no_fused_mha or (seq_len != 64) or (embed_dim // num_heads not in (64, 96, 128)) \
                    or need_weights or not torch.backends.cuda.matmul.allow_tf32:
                can_use_fused = False
if attn_mask is not None:
if attn_mask.dtype != torch.float32:
can_use_fused = False
if key_padding_mask is not None:
                if key_padding_mask.dtype != torch.float32:
can_use_fused = False
if can_use_fused:
index = query.device.index
attn_mask_need_grad = False
key_mask_need_grad = False
if attn_mask1.numel() > 0:
if attn_mask1.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
attn_mask_need_grad = attn_mask1.requires_grad
if key_padding_mask1.numel() > 0:
key_mask_need_grad = key_padding_mask1.requires_grad
if attn_mask is None:
attn_mask = torch.tensor([])
if key_padding_mask is None:
key_padding_mask = torch.tensor([])
if training:
ctx.num_heads = num_heads
ctx.dropout = dropout
ctx.embed_dim = embed_dim
ctx.index = index
ctx.batch_first = batch_first
ctx.batch_size = batch_size
ctx.seq_len = seq_len
ctx.seq_len_k = seq_len_k
ctx.need_softmax_temperature = need_softmax_temperature
ctx.attn_mask_need_grad = attn_mask_need_grad
ctx.key_mask_need_grad = key_mask_need_grad
if same_qkv:
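                    # Reorder the packed in-projection from [3, heads, head_dim, embed_dim] to a
                    # head-major layout so each head's Q/K/V rows are contiguous, as expected by
                    # the fused kernel; the matching backward undoes this permutation.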
if num_heads != 1:
in_proj_weight = in_proj_weight.reshape(3, num_heads, embed_dim // num_heads, embed_dim) \
.transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
in_proj_bias = in_proj_bias.reshape(3, num_heads, embed_dim // num_heads) \
.transpose(0, 1).contiguous().reshape(3 * embed_dim)
out_tensors, seeds = fused_mha.mha_forward_same_qkv(query.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(),
out_proj_weight.contiguous(),
out_proj_bias.contiguous(),
key_padding_mask.contiguous(),
attn_mask.contiguous(),
num_heads, dropout, need_softmax_temperature,
scaleT, False, training)
output, qkv, output2, softmax_sum, softmax_max = out_tensors
ctx.attn_method = 5
ctx.seeds = seeds
ctx.save_for_backward(output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight,
key_padding_mask1, attn_mask1, scaleT)
else:
out_tensors, seeds = fused_mha.mha_forward(query.contiguous(), key.contiguous(), value.contiguous(),
in_proj_weight.contiguous(), in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(),
key_padding_mask.contiguous(), attn_mask.contiguous(),
num_heads, dropout, need_softmax_temperature, scaleT,
False, training)
output, q1, k1, v1, output2, softmax_sum, softmax_max = out_tensors
ctx.attn_method = 6
ctx.seeds = seeds
ctx.save_for_backward(output, query, key, value, q1, k1, v1, softmax_sum, softmax_max,
in_proj_weight, out_proj_weight, key_padding_mask1, attn_mask1, scaleT)
return output2, torch.ones(1, device=query.device)
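            # FlashAttention path: head_dim must be <= 128 with TF32 (<= 64 without TF32),
            # attention weights and softmax temperature are unsupported, and any masks must be
            # float32 tensors that do not require gradients.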
            can_use_flash = True
            if torch.backends.cuda.matmul.allow_tf32:
                if no_flash_attn_tf32 or embed_dim // num_heads > 128:
                    can_use_flash = False
            else:
                if no_flash_attn_fp32 or embed_dim // num_heads > 64:
                    can_use_flash = False
if need_weights or need_softmax_temperature:
can_use_flash = False
if key_padding_mask1.numel() > 0:
if (key_padding_mask1.shape != (batch_size, seq_len_k)) or (key_padding_mask1.requires_grad == True) \
or (key_padding_mask1.dtype != torch.float32):
can_use_flash = False
if attn_mask1.numel() > 0:
if (attn_mask1.dtype != torch.float32) or (attn_mask1.requires_grad == True):
can_use_flash = False
if can_use_flash:
run_attn = None
if torch.backends.cuda.matmul.allow_tf32:
run_attn = flash_attn_tf32
else:
run_attn = flash_attn_fp32
index = query.device.index
if attn_mask1.numel() > 0:
if attn_mask1.dim() == 2:
attn_mask1 = attn_mask1.unsqueeze(0)
if training:
ctx.num_heads = num_heads
ctx.dropout = dropout
ctx.embed_dim = embed_dim
ctx.index = index
ctx.batch_first = batch_first
ctx.batch_size = batch_size
ctx.seq_len = seq_len
ctx.seq_len_k = seq_len_k
ctx.need_softmax_temperature = need_softmax_temperature
ctx.run_attn = run_attn
ctx.attn_mask_need_grad = False
ctx.key_mask_need_grad = False
if same_qkv:
in_proj_weight = in_proj_weight.reshape(3, num_heads, embed_dim // num_heads, embed_dim) \
.transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
in_proj_bias = in_proj_bias.reshape(3, num_heads, embed_dim // num_heads) \
.transpose(0, 1).contiguous().reshape(3 * embed_dim)
out_tensors, seeds = run_attn.mha_forward_same_qkv(query.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(),
out_proj_weight.contiguous(),
out_proj_bias.contiguous(),
key_padding_mask1.contiguous(),
attn_mask1.contiguous(),
num_heads, dropout, False, training)
output, qkv, output2, softmax_sum, softmax_max = out_tensors
ctx.attn_method = 3
ctx.seeds = seeds
ctx.save_for_backward(output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight,
key_padding_mask1, attn_mask1)
else:
out_tensors, seeds = run_attn.mha_forward(query.contiguous(), key.contiguous(), value.contiguous(),
in_proj_weight.contiguous(), in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(),
key_padding_mask1.contiguous(), attn_mask1.contiguous(),
num_heads, dropout, False, training)
output, q1, k1, v1, output2, softmax_sum, softmax_max = out_tensors
ctx.attn_method = 4
ctx.seeds = seeds
ctx.save_for_backward(output, query, key, value, q1, k1, v1, softmax_sum, softmax_max,
in_proj_weight, out_proj_weight, key_padding_mask1, attn_mask1)
return output2, torch.ones(1, device=query.device)
attn_mask_need_grad = False
key_mask_need_grad = False
if attn_mask is not None:
            assert attn_mask.dtype in (torch.uint8, torch.bool, torch.float32), \
                'attn_mask of hfai.nn.MultiheadAttention only supports uint8, bool and float32, not {}'.format(attn_mask.dtype)
if attn_mask.dtype == torch.uint8:
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dtype == torch.float32:
attn_mask_need_grad = attn_mask.requires_grad
ctx.attn_mask_shape = attn_mask.shape
if attn_mask.dim() == 2:
correct_2d_size = (seq_len, seq_len_k)
if attn_mask.shape != correct_2d_size:
raise RuntimeError(
f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size = (batch_size * num_heads, seq_len, seq_len_k)
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
if attn_mask.shape != correct_3d_size and attn_mask.shape != correct_3d_size1:
raise RuntimeError(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size} or {correct_3d_size1}.")
if attn_mask.shape == correct_3d_size1:
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
.reshape(batch_size * num_heads, seq_len, seq_len_k)
else:
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
if key_padding_mask is not None:
            assert key_padding_mask.dtype in (torch.uint8, torch.bool, torch.float32), \
                'key_padding_mask of hfai.nn.MultiheadAttention only supports uint8, bool and float32, not {}'.format(
                    key_padding_mask.dtype)
assert key_padding_mask.shape == (batch_size, seq_len_k), \
f"expecting key_padding_mask shape of {(batch_size, seq_len_k)}, but got {key_padding_mask.shape}"
if key_padding_mask.dtype == torch.uint8:
key_padding_mask = key_padding_mask.to(torch.bool)
if key_padding_mask.dtype == torch.float32:
key_mask_need_grad = key_padding_mask.requires_grad
            key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
                .reshape(-1, seq_len_k).unsqueeze(1)
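            # After expansion key_padding_mask has shape [batch_size * num_heads, 1, seq_len_k],
            # so it broadcasts over the query dimension of the attention scores.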
if attn_mask is not None and key_padding_mask is not None:
            assert attn_mask.dtype == key_padding_mask.dtype, "attn_mask and key_padding_mask must have the same dtype"
if compact:
if training:
index = query.device.index
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
else:
attn_mask = torch.tensor([])
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.contiguous()
else:
key_padding_mask = torch.tensor([])
output, attnPre, dropout_mask, attn_softmax, output2 = attention_compact.forward(
query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
ctx.num_heads = num_heads
ctx.dropout = dropout
ctx.embed_dim = embed_dim
ctx.index = index
ctx.batch_first = batch_first
ctx.batch_size = batch_size
ctx.seq_len = seq_len
ctx.seq_len_k = seq_len_k
ctx.need_softmax_temperature = need_softmax_temperature
ctx.attn_mask_need_grad = attn_mask_need_grad
ctx.key_mask_need_grad = key_mask_need_grad
ctx.attn_method = 2
ctx.save_for_backward(output, out_proj_weight, out_proj_bias, attnPre, dropout_mask, attn_softmax,
query,
key, value, in_proj_weight, in_proj_bias, attn_mask1, key_padding_mask1,
scaleT)
else:
index = query.device.index
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
else:
attn_mask = torch.tensor([])
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.contiguous()
else:
key_padding_mask = torch.tensor([])
output2, attn = attention_compact.inference(
query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
else:
if same_qkv:
if training:
index = query.device.index
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
else:
attn_mask = torch.tensor([])
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.contiguous()
else:
key_padding_mask = torch.tensor([])
                    in_proj_weight = in_proj_weight.reshape(3, num_heads, embed_dim // num_heads, embed_dim) \
                        .transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
                    in_proj_bias = in_proj_bias.reshape(3, num_heads, embed_dim // num_heads) \
                        .transpose(0, 1).contiguous().reshape(3 * embed_dim)
run_attn = get_attention()
output, qkv, attnPre, dropout_mask, attn_softmax, attn, output2 = run_attn.forward(
query.contiguous(), in_proj_weight.contiguous(), in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(),
attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index
)
ctx.run_attn = run_attn
ctx.num_heads = num_heads
ctx.embed_dim = embed_dim
ctx.dropout = dropout
ctx.batch_first = batch_first
ctx.batch_size = batch_size
ctx.seq_len = seq_len
ctx.seq_len_k = seq_len_k
ctx.index = index
ctx.need_softmax_temperature = need_softmax_temperature
ctx.attn_mask_need_grad = attn_mask_need_grad
ctx.key_mask_need_grad = key_mask_need_grad
ctx.attn_method = 0
ctx.save_for_backward(output, out_proj_weight, qkv, attnPre, dropout_mask, attn_softmax, attn,
query,
in_proj_weight, attn_mask1, key_padding_mask1, scaleT)
else:
index = query.device.index
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
else:
attn_mask = torch.tensor([])
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.contiguous()
else:
key_padding_mask = torch.tensor([])
                    in_proj_weight = in_proj_weight.reshape(3, num_heads, embed_dim // num_heads, embed_dim) \
                        .transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
                    in_proj_bias = in_proj_bias.reshape(3, num_heads, embed_dim // num_heads) \
                        .transpose(0, 1).contiguous().reshape(3 * embed_dim)
run_attn = get_attention()
output2, attn = run_attn.inference(query.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(), out_proj_weight.contiguous(),
out_proj_bias.contiguous(), attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature,
scaleT.contiguous(), index)
else:
if training:
index = query.device.index
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
else:
attn_mask = torch.tensor([])
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.contiguous()
else:
key_padding_mask = torch.tensor([])
# in_proj_weight = in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim).transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
# in_proj_bias = in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0, 1).contiguous().reshape(3 * embed_dim)
run_attn = get_attention()
output, q1, k1, v1, attnPre, dropout_mask, attn_softmax, attn, output2 = run_attn.forward_diff_qkv(
query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index
)
ctx.run_attn = run_attn
ctx.num_heads = num_heads
ctx.dropout = dropout
ctx.embed_dim = embed_dim
ctx.index = index
ctx.batch_first = batch_first
ctx.batch_size = batch_size
ctx.seq_len = seq_len
ctx.seq_len_k = seq_len_k
ctx.need_softmax_temperature = need_softmax_temperature
ctx.attn_mask_need_grad = attn_mask_need_grad
ctx.key_mask_need_grad = key_mask_need_grad
ctx.attn_method = 1
ctx.save_for_backward(output, out_proj_weight, q1, k1, v1, attnPre, dropout_mask, attn_softmax,
attn,
query, key, value, in_proj_weight, attn_mask1, key_padding_mask1, scaleT)
else:
index = query.device.index
if attn_mask is not None:
attn_mask = attn_mask.contiguous()
else:
attn_mask = torch.tensor([])
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.contiguous()
else:
key_padding_mask = torch.tensor([])
# in_proj_weight = in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim).transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
# in_proj_bias = in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0, 1).contiguous().reshape(3 * embed_dim)
run_attn = get_attention()
output2, attn = run_attn.inference_diff_qkv(
query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(),
out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index
)
if batch_first:
output2 = output2.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
if need_weights:
if compact and training:
# attn, _ = torch._fused_dropout(attn_softmax, p=(1.-dropout))
if dropout > 0.0001:
attn = torch._masked_scale(attn_softmax, dropout_mask, 1.0 / (1.0 - dropout))
else:
attn = attn_softmax
attn = attn.view(batch_size, num_heads, seq_len, seq_len_k)
return output2, attn.sum(dim=1) / num_heads
else:
            return output2, torch.ones(1, device=query.device)
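    # backward() returns one gradient slot per argument of forward() (22 in total, excluding ctx);
    # arguments that are not differentiable tensors get None.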
@staticmethod
@torch.cuda.amp.custom_bwd
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output2, grad_weight):
num_heads = ctx.num_heads
embed_dim = ctx.embed_dim
dropout = ctx.dropout
index = ctx.index
batch_first = ctx.batch_first
batch_size = ctx.batch_size
seq_len = ctx.seq_len
seq_len_k = ctx.seq_len_k
need_softmax_temperature = ctx.need_softmax_temperature
attn_mask_need_grad = ctx.attn_mask_need_grad
key_mask_need_grad = ctx.key_mask_need_grad
attn_method = ctx.attn_method
if batch_first:
grad_output2 = grad_output2.transpose(0, 1)
if attn_method == 0:
output, out_proj_weight, qkv, attnPre, dropout_mask, attn_softmax, attn, inputs, in_proj_weight, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
if attn_mask.numel() > 0:
if attn_mask.dtype == torch.uint8:
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dtype == torch.bool:
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
if attn_mask.shape == correct_3d_size1:
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
.reshape(batch_size * num_heads, seq_len, seq_len_k)
attn_mask = attn_mask.contiguous()
if key_padding_mask.numel() > 0:
if key_padding_mask.dtype == torch.uint8:
key_padding_mask = key_padding_mask.to(torch.bool)
if key_padding_mask.dtype == torch.bool:
key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
.reshape(-1, seq_len_k).unsqueeze(1).contiguous()
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention.backward(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
dropout_mask.contiguous(), attn_softmax.contiguous(), attn.contiguous(), inputs.contiguous(),
in_proj_weight.contiguous(),
attn_mask, key_padding_mask, num_heads, dropout, need_softmax_temperature, scaleT.contiguous(),
attnPre.contiguous(), attn_mask_need_grad or key_mask_need_grad, index
)
            grad_in_proj_weight = grad_in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim) \
                .transpose(0, 1).contiguous().view(3 * embed_dim, embed_dim)
            grad_in_proj_bias = grad_in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads) \
                .transpose(0, 1).contiguous().view(3 * embed_dim)
grad_attn_mask = None
if attn_mask_need_grad:
attn_mask_shape = ctx.attn_mask_shape
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
grad_attn_mask = grad_mask
if attn_mask_shape == correct_3d_size1:
grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
grad_key_mask = None
if key_mask_need_grad:
grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
if batch_first:
grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
if need_softmax_temperature:
return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
else:
return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None
elif attn_method == 1:
output, out_proj_weight, q1, k1, v1, attnPre, dropout_mask, attn_softmax, attn, query, key, value, in_proj_weight, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
if attn_mask.numel() > 0:
if attn_mask.dtype == torch.uint8:
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dtype == torch.bool:
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
if attn_mask.shape == correct_3d_size1:
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
.reshape(batch_size * num_heads, seq_len, seq_len_k)
attn_mask = attn_mask.contiguous()
if key_padding_mask.numel() > 0:
if key_padding_mask.dtype == torch.uint8:
key_padding_mask = key_padding_mask.to(torch.bool)
if key_padding_mask.dtype == torch.bool:
key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
.reshape(-1, seq_len_k).unsqueeze(1).contiguous()
grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention.backward_diff_qkv(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
k1.contiguous(), v1.contiguous(),
dropout_mask.contiguous(), attn_softmax.contiguous(), attn.contiguous(), query.contiguous(),
key.contiguous(), value.contiguous(),
in_proj_weight.contiguous(), attn_mask, key_padding_mask, num_heads, dropout, need_softmax_temperature,
scaleT.contiguous(), attnPre.contiguous(), attn_mask_need_grad or key_mask_need_grad, index
)
# grad_in_proj_weight = grad_in_proj_weight.reshape(3, num_heads, embed_dim//num_heads, embed_dim).permute(1, 0, 2, 3).contiguous().view(3 * embed_dim, embed_dim)
# grad_in_proj_bias = grad_in_proj_bias.reshape(3, num_heads, embed_dim//num_heads).permute(1, 0, 2).contiguous().view(3 * embed_dim)
grad_attn_mask = None
if attn_mask_need_grad:
attn_mask_shape = ctx.attn_mask_shape
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
grad_attn_mask = grad_mask
if attn_mask_shape == correct_3d_size1:
grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
grad_key_mask = None
if key_mask_need_grad:
grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
if batch_first:
grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
if need_softmax_temperature:
return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
else:
return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None
elif attn_method == 2:
output, out_proj_weight, out_proj_bias, attnPre, dropout_mask, attn_softmax, query, key, value, in_proj_weight, in_proj_bias, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
if attn_mask.numel() > 0:
if attn_mask.dtype == torch.uint8:
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dtype == torch.bool:
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
elif attn_mask.dim() == 3:
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
if attn_mask.shape == correct_3d_size1:
attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
.reshape(batch_size * num_heads, seq_len, seq_len_k)
attn_mask = attn_mask.contiguous()
if key_padding_mask.numel() > 0:
if key_padding_mask.dtype == torch.uint8:
key_padding_mask = key_padding_mask.to(torch.bool)
if key_padding_mask.dtype == torch.bool:
key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
.reshape(-1, seq_len_k).unsqueeze(1).contiguous()
grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention_compact.backward(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(),
out_proj_bias.contiguous(), dropout_mask.contiguous(), attn_softmax.contiguous(),
query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
in_proj_bias.contiguous(), attn_mask, key_padding_mask,
num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), attnPre.contiguous(),
attn_mask_need_grad or key_mask_need_grad, index
)
grad_attn_mask = None
if attn_mask_need_grad:
attn_mask_shape = ctx.attn_mask_shape
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
grad_attn_mask = grad_mask
if attn_mask_shape == correct_3d_size1:
grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
grad_key_mask = None
if key_mask_need_grad:
grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
if batch_first:
grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
if need_softmax_temperature:
return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
else:
return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None
elif attn_method == 3:
output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask = ctx.saved_tensors
run_attn = ctx.run_attn
seeds = ctx.seeds
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias = run_attn.mha_backward_same_qkv(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
query.contiguous(), in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout, seeds)
            grad_in_proj_weight = grad_in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim) \
                .transpose(0, 1).contiguous().view(3 * embed_dim, embed_dim)
            grad_in_proj_bias = grad_in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads) \
                .transpose(0, 1).contiguous().view(3 * embed_dim)
if batch_first:
grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
return grad_input, None, None, None, None, None, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None
elif attn_method == 4:
output, query, key, value, q1, k1, v1, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask = ctx.saved_tensors
run_attn = ctx.run_attn
seeds = ctx.seeds
grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias = run_attn.mha_backward(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
k1.contiguous(), v1.contiguous(), query.contiguous(), key.contiguous(), value.contiguous(),
in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout, seeds)
if batch_first:
grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
return grad_query, grad_key, grad_value, None, None, None, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None
elif attn_method == 5:
output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask, scaleT = ctx.saved_tensors
seeds = ctx.seeds
if attn_mask.numel() > 0:
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_mask, grad_scaleT = fused_mha.mha_backward_same_qkv(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
query.contiguous(), in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout,
attn_mask_need_grad or key_mask_need_grad,
need_softmax_temperature, scaleT, seeds)
if num_heads != 1:
                grad_in_proj_weight = grad_in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim) \
                    .transpose(0, 1).contiguous().view(3 * embed_dim, embed_dim)
                grad_in_proj_bias = grad_in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads) \
                    .transpose(0, 1).contiguous().view(3 * embed_dim)
grad_attn_mask = None
if attn_mask_need_grad:
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
grad_attn_mask = grad_mask
if attn_mask.shape == correct_3d_size1:
grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
grad_key_mask = None
if key_mask_need_grad:
grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
grad_t = None
if need_softmax_temperature:
grad_t = grad_scaleT.reshape(num_heads, 1, 1)
if batch_first:
grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_t, None, None
elif attn_method == 6:
output, query, key, value, q1, k1, v1, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask, scaleT = ctx.saved_tensors
seeds = ctx.seeds
if attn_mask.numel() > 0:
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_mask, grad_scaleT = fused_mha.mha_backward(
grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
k1.contiguous(), v1.contiguous(), query.contiguous(), key.contiguous(), value.contiguous(),
in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout,
attn_mask_need_grad or key_mask_need_grad,
need_softmax_temperature, scaleT, seeds)
grad_attn_mask = None
if attn_mask_need_grad:
correct_3d_size1 = (num_heads, seq_len, seq_len_k)
grad_attn_mask = grad_mask
if attn_mask.shape == correct_3d_size1:
grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
grad_key_mask = None
if key_mask_need_grad:
grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
grad_t = None
if need_softmax_temperature:
grad_t = grad_scaleT.reshape(num_heads, 1, 1)
if batch_first:
grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_t, None, None
class MyLinear(nn.Module):
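    # Thin parameter container for the output projection. Unlike torch.nn.Linear it performs no
    # initialization of its own; MultiheadAttention._reset_parameters() below initializes it.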
def __init__(self, in_features, out_features, device=None, dtype=None):
super(MyLinear, self).__init__()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))
class MultiheadAttention(nn.MultiheadAttention):
"""
更高效的MultiheadAttention算子
目前支持 ``qkv`` 的 ``embed_dim`` 相同且 ``add_zero_attn=False``,其余和 `PyTorch的MultiheadAttention <https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=multihead#torch.nn.MultiheadAttention>`_ 算子一致
.. note::
支持两种额外扩展模式
1) 低精度模式,进一步提升性能
指定 ``hfai.nn.context.SetAttnAllowConversion(True)`` 时, MultiheadAttention 在保证误差的情况下对于部分矩阵乘法降低精度进行运算, 进一步提升性能
2) 低显存模式,适应 batch_size 或者参数更大的训练
设置 MultiheadAttention 中 compact 参数为True (default: False)
Examples:
.. code-block:: python
attn = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1).cuda()
# 低显存模式
attn_compact = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1, compact=True).cuda()
query0 = torch.randn(5, 4, 10).cuda()
key0 = torch.randn(3, 4, 10).cuda()
value0 = torch.randn(3, 4, 10).cuda()
output0 = attn(query0, key0, value0)[0]
# 低精度模式
hfai.nn.context.SetAttnAllowConversion(True)
query1 = torch.randn(5, 4, 10).cuda()
key1 = torch.randn(3, 4, 10).cuda()
value1 = torch.randn(3, 4, 10).cuda()
output1 = attn(query1, key1, value1)[0]
hfai.nn.context.SetRnnAllowConversion(False)
"""
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None,
vdim=None,
batch_first=False, need_softmax_temperature=False, compact=False, device=None, dtype=None):
super(nn.MultiheadAttention, self).__init__()
        assert not no_attention and not no_attention_half, 'hfai attention kernels not found, please contact HFAI'
factory_kwargs = {'device': device, 'dtype': dtype}
self.embed_dim = embed_dim
self.kdim = kdim if kdim is not None else embed_dim
self.vdim = vdim if vdim is not None else embed_dim
self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
self.batch_first = batch_first
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
        self.compact = compact
        if compact:
            assert not no_attention_compact, 'hfai attention_compact kernel not found, please contact HFAI'
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if self._qkv_same_embed_dim is False:
self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
self.register_parameter('in_proj_weight', None)
else:
self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
self.register_parameter('q_proj_weight', None)
self.register_parameter('k_proj_weight', None)
self.register_parameter('v_proj_weight', None)
if bias:
self.in_proj_bias = nn.Parameter(torch.empty((3 * embed_dim), **factory_kwargs))
else:
self.register_parameter('in_proj_bias', None)
self.out_proj = MyLinear(embed_dim, embed_dim, **factory_kwargs)
if add_bias_kv:
self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
else:
self.bias_k = self.bias_v = None
        # scaleT = 1.0 / temperature: a learnable per-head scale applied to the attention logits
        # when need_softmax_temperature is True, otherwise a constant 1.
self.need_softmax_temperature = need_softmax_temperature
if need_softmax_temperature:
self.scaleT = nn.Parameter(torch.ones((num_heads, 1, 1), **factory_kwargs))
else:
self.scaleT = torch.ones(1, **factory_kwargs)
self.add_zero_attn = add_zero_attn
self._reset_parameters()
def _reset_parameters(self):
kaiming_uniform_(self.out_proj.weight, a=math.sqrt(5))
if self.out_proj.bias is not None:
fan_in, _ = _calculate_fan_in_and_fan_out(self.out_proj.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
uniform_(self.out_proj.bias, -bound, bound)
if self._qkv_same_embed_dim:
xavier_uniform_(self.in_proj_weight)
# self.in_proj_weight.data = self.in_proj_weight.data.reshape(3, self.num_heads, self.embed_dim//self.num_heads, self.embed_dim).permute(1, 0, 2, 3).contiguous().view(3 * self.embed_dim, self.embed_dim)
else:
xavier_uniform_(self.q_proj_weight)
xavier_uniform_(self.k_proj_weight)
xavier_uniform_(self.v_proj_weight)
if self.in_proj_bias is not None:
constant_(self.in_proj_bias, 0.)
constant_(self.out_proj.bias, 0.)
if self.bias_k is not None:
xavier_normal_(self.bias_k)
if self.bias_v is not None:
xavier_normal_(self.bias_v)
def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
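        # Fall back to the stock torch.nn.MultiheadAttention for non-CUDA or non-float32 inputs.
        # Dropout is disabled in eval mode, and the custom Function only does the bookkeeping for
        # backward when gradients are enabled and the module is training.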
if not (query.is_cuda and query.dtype == torch.float32):
return super().forward(query, key, value, key_padding_mask, need_weights, attn_mask)
return MultiheadAttentionFunc.apply(query, key, value, key_padding_mask, need_weights, attn_mask,
self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias,
self.bias_k, self.bias_v, self.add_zero_attn,
self.dropout if self.training else 0.0,
self.out_proj.weight, self.out_proj.bias, self._qkv_same_embed_dim,
self.batch_first, self.need_softmax_temperature, self.scaleT, self.compact,
torch.is_grad_enabled() and self.training)
def inference_init(self, max_seq_len):
...
def inference_increment(self,
qkv: Tensor,
max_seq_len: int,
now_seq_len: int,
cache: Optional[Tensor] = None) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""
高效的 attention 增量 inference 函数
Examples:
.. code-block:: python
mha = hfai.nn.MultiheadAttention(128, 4).cuda()
input = torch.randn(64, 100, 128).cuda()
output, cache = mha.inference_increment(input, max_seq_len=128, now_seq_len=64)
for _ in range(64):
input = torch.randn(1, 100, 128).cuda()
output, cache = mha.inference_increment(input, max_seq_len=128, now_seq_len=64, cache=cache)
"""
if not self.batch_first:
qkv = qkv.transpose(1, 0)
batch_size, seq_len, embed_dim = qkv.shape
head_dim = embed_dim // self.num_heads
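        # `cache` holds the raw input embeddings seen so far, shaped [batch_size, max_seq_len, embed_dim].
        # K and V are re-projected from the cache on every call, while only the newest position is
        # projected to a query.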
if cache is None:
# the first inference
cache = torch.empty(batch_size, max_seq_len, embed_dim, dtype=qkv.dtype, device=qkv.device)
cache[:, now_seq_len - seq_len:now_seq_len, :] = qkv
wq, wk, wv = self.in_proj_weight.chunk(3, dim=0)
bq, bk, bv = self.in_proj_bias.chunk(3, dim=0)
q = F.linear(qkv[:, -1:, :], wq, bq)
q = q.view(batch_size, self.num_heads, 1, head_dim)
q = q.contiguous()
k = F.linear(cache[:, :now_seq_len, :], wk, bk)
k = k.view(batch_size, now_seq_len, self.num_heads, head_dim).transpose(1, 2)
attn = q @ k.transpose(-2, -1)
del q, k
attn.mul_(1.0 / math.sqrt(head_dim))
if self.need_softmax_temperature:
attn.mul_(self.scaleT)
attn = F.softmax(attn, dim=-1)
v = F.linear(cache[:, :now_seq_len, :], wv, bv)
v = v.view(batch_size, now_seq_len, self.num_heads, head_dim).transpose(1, 2)
out = attn @ v
del attn, v
out = out.reshape(batch_size, 1, embed_dim) # [batch_size, 1, embed_dim]
out = F.linear(out, self.out_proj.weight, self.out_proj.bias) # [batch_size, 1, embed_dim]
if not self.batch_first:
out = out.transpose(0, 1)
return out, cache