
Source code for hfai.nn.modules.fast_multihead_attention

from typing import Tuple, Optional
import torch
import torch.nn as nn
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_, kaiming_uniform_, _calculate_fan_in_and_fan_out, \
    uniform_
import math
import sys
from .context import context

no_attention = False
no_attention_half = False
no_attention_compact = False

try:
    import hfai.hfcuda.nn_attention as attention
except ImportError:
    no_attention = True

try:
    import hfai.hfcuda.nn_attention_half as attention_half
except ImportError:
    no_attention_half = True

try:
    import hfai.hfcuda.nn_attention_compact as attention_compact
except ImportError:
    no_attention_compact = True
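
# If any of the optional CUDA extensions above fails to import, the corresponding flag is set;
# MultiheadAttention.__init__ below asserts that at least the float32 and half kernels are available.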


def get_attention():
    if context.GetAttnAllowConversion():
        return attention_half
    else:
        return attention
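

# Usage sketch (illustrative, mirroring the MultiheadAttention docstring below): the kernel
# chosen by get_attention() is controlled through the hfai.nn.context flag.
#
#     hfai.nn.context.SetAttnAllowConversion(True)   # get_attention() returns attention_half
#     out = attn(query, key, value)[0]               # some matmuls run at reduced precision
#     hfai.nn.context.SetAttnAllowConversion(False)  # back to the float32 kernel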


class MultiheadAttentionFunc(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, query, key, value, key_padding_mask, need_weights, attn_mask,
                embed_dim, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
                add_zero_attn, dropout, out_proj_weight, out_proj_bias, _qkv_same_embed_dim, batch_first,
                need_softmax_temperature,
                scaleT, compact,
                training):

        assert bias_k is None, \
            'hfai.nn.MultiheadAttention does not support bias_k'
        assert bias_v is None, \
            'hfai.nn.MultiheadAttention does not support bias_v'
        assert add_zero_attn == False, \
            'hfai.nn.MultiheadAttention does not support add_zero_attn'
        assert query.dtype == torch.float32, \
            'hfai.nn.MultiheadAttention currently only supports float32'

        if batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        seq_len, batch_size, embed_dim = query.size()
        seq_len_k, _, _ = key.size()

        if _qkv_same_embed_dim:
            attn_mask_need_grad = False
            key_mask_need_grad = False
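            # Normalize the masks for the CUDA kernels below: uint8 masks are converted to bool,
            # float32 masks keep requires_grad so their gradient can be returned, and both masks
            # are reshaped so they broadcast against attention scores of shape
            # (batch_size * num_heads, seq_len, seq_len_k).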
            if attn_mask is not None:
                assert attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool or attn_mask.dtype == torch.float32, \
                    'attn_mask of hfai.nn.MultiheadAttention only supports uint8, bool and float32, not {}'.format(attn_mask.dtype)
                if attn_mask.dtype == torch.uint8:
                    attn_mask = attn_mask.to(torch.bool)
                if attn_mask.dtype == torch.float32:
                    attn_mask_need_grad = attn_mask.requires_grad
                if attn_mask.dim() == 2:
                    correct_2d_size = (seq_len, seq_len_k)
                    if attn_mask.shape != correct_2d_size:
                        raise RuntimeError(
                            f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                    attn_mask = attn_mask.unsqueeze(0)
                elif attn_mask.dim() == 3:
                    correct_3d_size = (batch_size * num_heads, seq_len, seq_len_k)
                    correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                    if attn_mask.shape != correct_3d_size and attn_mask.shape != correct_3d_size1:
                        raise RuntimeError(
                            f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size} or {correct_3d_size1}.")
                    if attn_mask.shape == correct_3d_size1:
                        attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k).reshape(batch_size * num_heads, seq_len, seq_len_k)
                else:
                    raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

            if key_padding_mask is not None:
                assert key_padding_mask.dtype == torch.uint8 or key_padding_mask.dtype == torch.bool or key_padding_mask.dtype == torch.float32, \
                    'key_padding_mask of hfai.nn.MultiheadAttention only supports uint8, bool and float32, not {}'.format(
                        key_padding_mask.dtype)
                assert key_padding_mask.shape == (batch_size, seq_len_k), \
                    f"expecting key_padding_mask shape of {(batch_size, seq_len_k)}, but got {key_padding_mask.shape}"
                if key_padding_mask.dtype == torch.uint8:
                    key_padding_mask = key_padding_mask.to(torch.bool)
                if key_padding_mask.dtype == torch.float32:
                    key_mask_need_grad = key_padding_mask.requires_grad
                key_padding_mask = key_padding_mask.unsqueeze(1).expand(
                    batch_size, num_heads, seq_len_k).reshape(-1, seq_len_k).unsqueeze(1)

            if attn_mask is not None and key_padding_mask is not None:
                assert attn_mask.dtype == key_padding_mask.dtype, "attn_mask and key_padding_mask must have the same dtype"

            if compact:
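                # Compact (low-memory) path: unlike the branches below, save_for_backward keeps the
                # raw query/key/value instead of the projected qkv and the post-dropout attention,
                # so attention_compact.backward works from the saved inputs at the cost of extra compute.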

                if training:
                    index = query.device.index
                    if attn_mask is not None:
                        attn_mask = attn_mask.contiguous()
                    else:
                        attn_mask = torch.tensor([])
                    if key_padding_mask is not None:
                        key_padding_mask = key_padding_mask.contiguous()
                    else:
                        key_padding_mask = torch.tensor([])

                    output, attnPre, dropout_mask, attn_softmax, output2 = attention_compact.forward(
                        query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                        in_proj_bias.contiguous(),
                        out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                        num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)

                    ctx.num_heads = num_heads
                    ctx.dropout = dropout
                    ctx.embed_dim = embed_dim
                    ctx.index = index
                    ctx.batch_first = batch_first
                    ctx.batch_size = batch_size
                    ctx.seq_len = seq_len
                    ctx.seq_len_k = seq_len_k
                    ctx.need_softmax_temperature = need_softmax_temperature
                    ctx.attn_mask_need_grad = attn_mask_need_grad
                    ctx.key_mask_need_grad = key_mask_need_grad
                    ctx.attn_method = 2
                    ctx.save_for_backward(output, out_proj_weight, out_proj_bias, attnPre, dropout_mask, attn_softmax,
                                          query,
                                          key, value, in_proj_weight, in_proj_bias, attn_mask, key_padding_mask, scaleT)
                else:
                    index = query.device.index
                    if attn_mask is not None:
                        attn_mask = attn_mask.contiguous()
                    else:
                        attn_mask = torch.tensor([])
                    if key_padding_mask is not None:
                        key_padding_mask = key_padding_mask.contiguous()
                    else:
                        key_padding_mask = torch.tensor([])
                    output2, attn = attention_compact.inference(
                        query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                        in_proj_bias.contiguous(),
                        out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                        num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
            else:
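                # Self-attention path: query, key and value are the same tensor, so only `query` is
                # handed to the fused kernel and the in-projection weight/bias are re-ordered per head
                # before the call. Cross-attention falls through to the *_diff_qkv kernels in the
                # else branch below.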

                if torch.equal(query, key) and torch.equal(query, value):
                    if training:
                        index = query.device.index
                        if attn_mask is not None:
                            attn_mask = attn_mask.contiguous()
                        else:
                            attn_mask = torch.tensor([])
                        if key_padding_mask is not None:
                            key_padding_mask = key_padding_mask.contiguous()
                        else:
                            key_padding_mask = torch.tensor([])
                        in_proj_weight = in_proj_weight.reshape(
                            3, num_heads, embed_dim // num_heads, embed_dim
                        ).transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
                        in_proj_bias = in_proj_bias.reshape(
                            3, num_heads, embed_dim // num_heads
                        ).transpose(0, 1).contiguous().reshape(3 * embed_dim)
                        run_attn = get_attention()
                        output, qkv, attnPre, dropout_mask, attn_softmax, attn, output2 = run_attn.forward(
                            query.contiguous(), in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                            out_proj_weight.contiguous(), out_proj_bias.contiguous(),
                            attn_mask, key_padding_mask,
                            num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index
                        )

                        ctx.run_attn = run_attn
                        ctx.num_heads = num_heads
                        ctx.embed_dim = embed_dim
                        ctx.dropout = dropout
                        ctx.batch_first = batch_first
                        ctx.batch_size = batch_size
                        ctx.seq_len = seq_len
                        ctx.seq_len_k = seq_len_k
                        ctx.index = index
                        ctx.need_softmax_temperature = need_softmax_temperature
                        ctx.attn_mask_need_grad = attn_mask_need_grad
                        ctx.key_mask_need_grad = key_mask_need_grad
                        ctx.attn_method = 0
                        ctx.save_for_backward(output, out_proj_weight, qkv, attnPre, dropout_mask, attn_softmax, attn,
                                              query,
                                              in_proj_weight, attn_mask, key_padding_mask, scaleT)

                    else:
                        index = query.device.index
                        if attn_mask is not None:
                            attn_mask = attn_mask.contiguous()
                        else:
                            attn_mask = torch.tensor([])
                        if key_padding_mask is not None:
                            key_padding_mask = key_padding_mask.contiguous()
                        else:
                            key_padding_mask = torch.tensor([])
                        in_proj_weight = in_proj_weight.reshape(
                            3, num_heads, embed_dim // num_heads, embed_dim
                        ).transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
                        in_proj_bias = in_proj_bias.reshape(
                            3, num_heads, embed_dim // num_heads
                        ).transpose(0, 1).contiguous().reshape(3 * embed_dim)
                        run_attn = get_attention()
                        output2, attn = run_attn.inference(query.contiguous(), in_proj_weight.contiguous(),
                                                           in_proj_bias.contiguous(), out_proj_weight.contiguous(),
                                                           out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                                                           num_heads, dropout, need_softmax_temperature,
                                                           scaleT.contiguous(), index)

                else:
                    if training:
                        index = query.device.index
                        if attn_mask is not None:
                            attn_mask = attn_mask.contiguous()
                        else:
                            attn_mask = torch.tensor([])
                        if key_padding_mask is not None:
                            key_padding_mask = key_padding_mask.contiguous()
                        else:
                            key_padding_mask = torch.tensor([])
                        # in_proj_weight = in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim).transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
                        # in_proj_bias = in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0, 1).contiguous().reshape(3 * embed_dim)
                        run_attn = get_attention()
                        output, q1, k1, v1, attnPre, dropout_mask, attn_softmax, attn, output2 = run_attn.forward_diff_qkv(
                            query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                            in_proj_bias.contiguous(),
                            out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                            num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index
                        )

                        ctx.run_attn = run_attn
                        ctx.num_heads = num_heads
                        ctx.dropout = dropout
                        ctx.embed_dim = embed_dim
                        ctx.index = index
                        ctx.batch_first = batch_first
                        ctx.batch_size = batch_size
                        ctx.seq_len = seq_len
                        ctx.seq_len_k = seq_len_k
                        ctx.need_softmax_temperature = need_softmax_temperature
                        ctx.attn_mask_need_grad = attn_mask_need_grad
                        ctx.key_mask_need_grad = key_mask_need_grad
                        ctx.attn_method = 1
                        ctx.save_for_backward(output, out_proj_weight, q1, k1, v1, attnPre, dropout_mask, attn_softmax,
                                              attn,
                                              query, key, value, in_proj_weight, attn_mask, key_padding_mask, scaleT)

                    else:
                        index = query.device.index
                        if attn_mask is not None:
                            attn_mask = attn_mask.contiguous()
                        else:
                            attn_mask = torch.tensor([])
                        if key_padding_mask is not None:
                            key_padding_mask = key_padding_mask.contiguous()
                        else:
                            key_padding_mask = torch.tensor([])
                        # in_proj_weight = in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads, embed_dim).transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
                        # in_proj_bias = in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0, 1).contiguous().reshape(3 * embed_dim)
                        run_attn = get_attention()
                        output2, attn = run_attn.inference_diff_qkv(
                            query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                            in_proj_bias.contiguous(),
                            out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                            num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index
                        )

            if batch_first:
                output2 = output2.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)

            if need_weights:
                if compact and training:
                    # attn, _ = torch._fused_dropout(attn_softmax, p=(1.-dropout))
                    if dropout > 0.0001:
                        attn = torch._masked_scale(attn_softmax, dropout_mask, 1.0 / (1.0 - dropout))
                    else:
                        attn = attn_softmax
                attn = attn.view(batch_size, num_heads, seq_len, seq_len_k)
                return output2, attn.sum(dim=1) / num_heads
            else:
                return output2, torch.ones(1, device=index)

    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output2, grad_weight):
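        # One gradient is returned per argument of forward (22 in total); entries for
        # non-differentiable inputs are None. grad_weight, the incoming gradient of the averaged
        # attention weights returned by forward, is not used: gradients only flow through the
        # first output.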
        num_heads = ctx.num_heads
        embed_dim = ctx.embed_dim
        dropout = ctx.dropout
        index = ctx.index
        batch_first = ctx.batch_first
        batch_size = ctx.batch_size
        seq_len = ctx.seq_len
        seq_len_k = ctx.seq_len_k
        need_softmax_temperature = ctx.need_softmax_temperature
        attn_mask_need_grad = ctx.attn_mask_need_grad
        key_mask_need_grad = ctx.key_mask_need_grad
        attn_method = ctx.attn_method

        if batch_first:
            grad_output2 = grad_output2.transpose(0, 1)

        if attn_method == 0:
            output, out_proj_weight, qkv, attnPre, dropout_mask, attn_softmax, attn, inputs, in_proj_weight, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
            grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention.backward(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
                dropout_mask.contiguous(), attn_softmax.contiguous(), attn.contiguous(), inputs.contiguous(),
                in_proj_weight.contiguous(),
                attn_mask, key_padding_mask, num_heads, dropout, need_softmax_temperature, scaleT.contiguous(),
                attnPre.contiguous(), attn_mask_need_grad or key_mask_need_grad, index
            )
            grad_in_proj_weight = grad_in_proj_weight.reshape(
                num_heads, 3, embed_dim // num_heads, embed_dim
            ).transpose(0, 1).contiguous().view(3 * embed_dim, embed_dim)
            grad_in_proj_bias = grad_in_proj_bias.reshape(
                num_heads, 3, embed_dim // num_heads
            ).transpose(0, 1).contiguous().view(3 * embed_dim)

            grad_attn_mask = None
            if attn_mask_need_grad:
                grad_attn_mask = grad_mask
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)

            if batch_first:
                grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
            if need_softmax_temperature:
                return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
            else:
                return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 1:
            output, out_proj_weight, q1, k1, v1, attnPre, dropout_mask, attn_softmax, attn, query, key, value, in_proj_weight, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
            grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention.backward_diff_qkv(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
                k1.contiguous(), v1.contiguous(),
                dropout_mask.contiguous(), attn_softmax.contiguous(), attn.contiguous(), query.contiguous(),
                key.contiguous(), value.contiguous(),
                in_proj_weight.contiguous(), attn_mask, key_padding_mask, num_heads, dropout, need_softmax_temperature,
                scaleT.contiguous(), attnPre.contiguous(), attn_mask_need_grad or key_mask_need_grad, index
            )
            # grad_in_proj_weight = grad_in_proj_weight.reshape(3, num_heads, embed_dim//num_heads, embed_dim).permute(1, 0, 2, 3).contiguous().view(3 * embed_dim, embed_dim)
            # grad_in_proj_bias = grad_in_proj_bias.reshape(3, num_heads, embed_dim//num_heads).permute(1, 0, 2).contiguous().view(3 * embed_dim)

            grad_attn_mask = None
            if attn_mask_need_grad:
                grad_attn_mask = grad_mask
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)

            if batch_first:
                grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
                grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
                grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
            if need_softmax_temperature:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
            else:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 2:
            output, out_proj_weight, out_proj_bias, attnPre, dropout_mask, attn_softmax, query, key, value, in_proj_weight, in_proj_bias, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
            grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention_compact.backward(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(),
                out_proj_bias.contiguous(), dropout_mask.contiguous(), attn_softmax.contiguous(),
                query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                in_proj_bias.contiguous(), attn_mask, key_padding_mask,
                num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), attnPre.contiguous(),
                attn_mask_need_grad or key_mask_need_grad, index
            )

            grad_attn_mask = None
            if attn_mask_need_grad:
                grad_attn_mask = grad_mask
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)

            if batch_first:
                grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
                grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
                grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
            if need_softmax_temperature:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
            else:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None


class MyLinear(nn.Module):
    """Minimal parameter holder for the output projection; unlike nn.Linear it performs no
    initialization of its own, which is handled by MultiheadAttention._reset_parameters()."""

    def __init__(self, in_features, out_features, device=None, dtype=None):
        super(MyLinear, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs))


class MultiheadAttention(nn.MultiheadAttention):
    """
    A more efficient MultiheadAttention operator.

    Currently ``q``, ``k`` and ``v`` must share the same ``embed_dim`` and ``add_zero_attn=False``;
    everything else matches `PyTorch's MultiheadAttention
    <https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=multihead#torch.nn.MultiheadAttention>`_.

    .. note::

        Two additional extension modes are supported:

        1) Low-precision mode, for extra performance.

           When ``hfai.nn.context.SetAttnAllowConversion(True)`` is set, MultiheadAttention runs
           some of its matrix multiplications at reduced precision while keeping the error
           bounded, which further improves performance.

        2) Low-memory mode, for training with a larger batch_size or more parameters.

           Set the ``compact`` argument of MultiheadAttention to True (default: False).

    Examples:

    .. code-block:: python

        attn = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1).cuda()

        # low-memory mode
        attn_compact = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1, compact=True).cuda()

        query0 = torch.randn(5, 4, 10).cuda()
        key0 = torch.randn(3, 4, 10).cuda()
        value0 = torch.randn(3, 4, 10).cuda()
        output0 = attn(query0, key0, value0)[0]

        # low-precision mode
        hfai.nn.context.SetAttnAllowConversion(True)
        query1 = torch.randn(5, 4, 10).cuda()
        key1 = torch.randn(3, 4, 10).cuda()
        value1 = torch.randn(3, 4, 10).cuda()
        output1 = attn(query1, key1, value1)[0]
        hfai.nn.context.SetAttnAllowConversion(False)

    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, need_softmax_temperature=False, compact=False,
                 device=None, dtype=None):
        super(nn.MultiheadAttention, self).__init__()
        assert no_attention == False and no_attention_half == False, \
            'hfai attention kernels not found, please contact HFAI'
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.batch_first = batch_first
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.compact = compact
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim is False:
            self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty((3 * embed_dim), **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = MyLinear(embed_dim, embed_dim, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        # scaleT = 1.0 / temperature
        self.need_softmax_temperature = need_softmax_temperature
        if need_softmax_temperature:
            self.scaleT = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.scaleT = torch.ones(1)

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        kaiming_uniform_(self.out_proj.weight, a=math.sqrt(5))
        if self.out_proj.bias is not None:
            fan_in, _ = _calculate_fan_in_and_fan_out(self.out_proj.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            uniform_(self.out_proj.bias, -bound, bound)

        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
            # self.in_proj_weight.data = self.in_proj_weight.data.reshape(3, self.num_heads, self.embed_dim//self.num_heads, self.embed_dim).permute(1, 0, 2, 3).contiguous().view(3 * self.embed_dim, self.embed_dim)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        if not query.is_cuda:
            return super().forward(query, key, value, key_padding_mask, need_weights, attn_mask)

        return MultiheadAttentionFunc.apply(query, key, value, key_padding_mask, need_weights, attn_mask,
                                            self.embed_dim, self.num_heads, self.in_proj_weight, self.in_proj_bias,
                                            self.bias_k, self.bias_v, self.add_zero_attn,
                                            self.dropout if self.training else 0.0,
                                            self.out_proj.weight, self.out_proj.bias,
                                            self._qkv_same_embed_dim, self.batch_first,
                                            self.need_softmax_temperature, self.scaleT, self.compact,
                                            torch.is_grad_enabled() and self.training)
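

# Usage sketch (illustrative, not part of the module): learnable softmax temperature, following
# the same calling convention as the docstring examples above.
#
#     attn = hfai.nn.MultiheadAttention(embed_dim=16, num_heads=4,
#                                       need_softmax_temperature=True).cuda()
#     query = torch.randn(5, 8, 16).cuda()    # (seq_len, batch_size, embed_dim)
#     key = torch.randn(3, 8, 16).cuda()
#     value = torch.randn(3, 8, 16).cuda()
#     key_padding_mask = torch.zeros(8, 3, dtype=torch.bool).cuda()   # (batch_size, seq_len_k)
#     output, weights = attn(query, key, value, key_padding_mask=key_padding_mask)
#     # attn.scaleT is an nn.Parameter of shape (num_heads, 1, 1) acting as 1/temperature on the
#     # softmax and is trained jointly (see grad_scaleT in backward above).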