Shortcuts

Source code for hfai.nn.modules.fast_multihead_attention

from typing import Tuple, Union, Optional
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn.init import xavier_uniform_, xavier_normal_, constant_, kaiming_uniform_, _calculate_fan_in_and_fan_out, \
    uniform_
import math
from .context import context
import torch.nn.functional as F

# Availability flags for the optional hfai CUDA extensions imported below.
# Each flag flips to True when the matching import fails; in that case the
# module alias (``attention``, ``fused_mha``, ...) is left undefined.
no_attention = False
no_attention_half = False
no_attention_compact = False
no_flash_attn_fp32 = False
no_flash_attn_tf32 = False
no_fused_mha = False
# Older name for ``no_fused_mha``; kept (and kept in sync) for compatibility.
no_fused_attn = False
no_qkv_inference = False

# Every extension is optional: a failed import only disables the kernels that
# need it.  ``except Exception`` (not a bare ``except``) so that
# KeyboardInterrupt / SystemExit raised during import still propagate.
try:
    import hfai.hfcuda.nn_attention as attention
except Exception:
    no_attention = True

try:
    import hfai.hfcuda.nn_attention_half as attention_half
except Exception:
    no_attention_half = True

try:
    import hfai.hfcuda.nn_attention_compact as attention_compact
except Exception:
    no_attention_compact = True

try:
    import hfai.hfcuda.flash_attention_fp32 as flash_attn_fp32
except Exception:
    no_flash_attn_fp32 = True

try:
    import hfai.hfcuda.flash_attention_tf32 as flash_attn_tf32
except Exception:
    no_flash_attn_tf32 = True

try:
    import hfai.hfcuda.fused_mha as fused_mha
except Exception:
    # Previously only ``no_fused_mha`` was assigned here, and it was never
    # initialised above, while the declared flag ``no_fused_attn`` never changed.
    no_fused_mha = True
    no_fused_attn = True

try:
    import hfai.hfcuda.qkv_inference as qkv_infer
except Exception:
    no_qkv_inference = True


def get_attention():
    """Pick the hfai attention extension used by the generic (non-flash) path.

    Returns ``attention_half`` (``hfai.hfcuda.nn_attention_half``) when
    ``context.GetAttnAllowConversion()`` is truthy, and ``attention``
    (``hfai.hfcuda.nn_attention``) otherwise.  Raises NameError if the selected
    extension failed to import at module load.
    """
    allow_conversion = context.GetAttnAllowConversion()
    return attention_half if allow_conversion else attention


class MultiheadAttentionFunc(torch.autograd.Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, query, key, value, key_padding_mask1, need_weights, attn_mask1,
                embed_dim, num_heads, in_proj_weight, in_proj_bias, bias_k, bias_v,
                add_zero_attn, dropout, out_proj_weight, out_proj_bias, _qkv_same_embed_dim, batch_first,
                need_softmax_temperature,
                scaleT, compact,
                training):
        """Run multi-head attention through one of the hfai CUDA kernels.

        Kernel selection, in order of preference (``ctx.attn_method`` in
        parentheses; the first number is used when query, key and value are the
        same tensor object, the second otherwise):

        * ``fused_mha`` (5 / 6): ``seq_len == 64``, head dim 64, 96 or 128,
          TF32 matmul enabled, ``need_weights`` False, and any supplied masks
          are float32 (``key_padding_mask`` also shaped ``(batch, seq_len_k)``).
        * ``flash_attn_tf32`` / ``flash_attn_fp32`` (3 / 4): head dim <= 128
          with TF32 (<= 64 without), neither ``need_weights`` nor
          ``need_softmax_temperature`` set, and any masks float32 without
          ``requires_grad`` (``key_padding_mask`` shaped ``(batch, seq_len_k)``).
        * otherwise ``attention_compact`` (2) when ``compact`` is set, else the
          module returned by ``get_attention()`` (0 / 1).

        Args:
            ctx: autograd context; receives the scalars and tensors read back by
                ``backward``.
            query, key, value: ``(seq, batch, embed)`` tensors, or
                ``(batch, seq, embed)`` when ``batch_first``.  Object identity
                (``is``) between them selects the same-qkv kernels.
            key_padding_mask1: optional uint8/bool/float32 mask; the generic
                path requires shape ``(batch, seq_len_k)``.
            need_weights: also return head-averaged attention weights (only the
                generic path supports this).
            attn_mask1: optional uint8/bool/float32 mask; the generic path
                accepts ``(seq_len, seq_len_k)``,
                ``(batch * num_heads, seq_len, seq_len_k)`` or
                ``(num_heads, seq_len, seq_len_k)``.
            embed_dim: ignored; recomputed from ``query.size()``.
            num_heads: number of attention heads.
            in_proj_weight, in_proj_bias: packed in-projection, reshaped as
                ``(3, num_heads, head_dim, embed_dim)`` / ``(3, num_heads, head_dim)``
                for the same-qkv kernels.
            bias_k, bias_v: must be None.
            add_zero_attn: must be False.
            dropout: dropout probability forwarded to the kernels.
            out_proj_weight, out_proj_bias: output projection.
            _qkv_same_embed_dim: must be True.
            batch_first: see ``query``.
            need_softmax_temperature, scaleT: forwarded to the kernels (``scaleT``
                is a tensor and is saved for backward).
            compact: use ``attention_compact`` on the generic path.
            training: select training kernels and store backward state
                (the fused / flash paths call ``save_for_backward`` regardless).

        Returns:
            ``(output, weights)``.  ``weights`` is the attention averaged over
            heads, shape ``(batch, seq_len, seq_len_k)``, when ``need_weights``
            is set; otherwise a one-element tensor of ones.

        Raises:
            NotImplementedError: if ``_qkv_same_embed_dim`` is False.
            AssertionError / RuntimeError: on unsupported arguments or masks.
        """
        assert bias_k is None, \
            'hfai.nn.MultiheadAttention不支持bias_k'
        assert bias_v is None, \
            'hfai.nn.MultiheadAttention不支持bias_v'
        assert add_zero_attn == False, \
            'hfai.nn.MultiheadAttention不支持add_zero_attn'
        assert query.dtype == torch.float32, \
            'hfai.nn.MultiheadAttention暂时只支持float32'

        if not _qkv_same_embed_dim:
            # Only the packed-projection configuration is implemented below;
            # previously this case fell through and forward returned None.
            raise NotImplementedError(
                'hfai.nn.MultiheadAttention only supports _qkv_same_embed_dim=True')

        # Everything below treats inputs as (seq, batch, embed).  Transpose while
        # preserving the identity relations that select the same-qkv kernels.
        same_qkv = query is key and key is value
        if batch_first:
            if same_qkv:
                query = key = value = query.transpose(1, 0)
            elif key is value:
                query, key = query.transpose(1, 0), key.transpose(1, 0)
                value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        seq_len, batch_size, embed_dim = query.size()
        seq_len_k, _, _ = key.size()
        head_dim = embed_dim // num_heads
        index = query.device.index
        allow_tf32 = torch.backends.cuda.matmul.allow_tf32

        # ``attn_mask`` / ``key_padding_mask`` keep the caller's value (None when
        # absent); the ``*1`` variants become empty tensors instead so they can
        # be passed to kernels and saved for backward as tensors.
        attn_mask = attn_mask1
        key_padding_mask = key_padding_mask1
        if attn_mask1 is None:
            attn_mask1 = torch.tensor([])
        if key_padding_mask1 is None:
            key_padding_mask1 = torch.tensor([])

        def contiguous_or_empty(mask):
            # Kernels take an empty tensor in place of an absent (None) mask.
            return torch.tensor([]) if mask is None else mask.contiguous()

        def head_major(weight, bias):
            # Regroup the (3, num_heads, ...) packed projection as
            # (num_heads, 3, ...) so each head's three row-blocks are adjacent.
            weight = weight.reshape(3, num_heads, head_dim, embed_dim) \
                .transpose(0, 1).contiguous().reshape(3 * embed_dim, embed_dim)
            bias = bias.reshape(3, num_heads, head_dim) \
                .transpose(0, 1).contiguous().reshape(3 * embed_dim)
            return weight, bias

        def save_common_ctx(attn_mask_need_grad, key_mask_need_grad):
            # Scalars that backward() reads from ctx for every attn_method.
            ctx.num_heads = num_heads
            ctx.dropout = dropout
            ctx.embed_dim = embed_dim
            ctx.index = index
            ctx.batch_first = batch_first
            ctx.batch_size = batch_size
            ctx.seq_len = seq_len
            ctx.seq_len_k = seq_len_k
            ctx.need_softmax_temperature = need_softmax_temperature
            ctx.attn_mask_need_grad = attn_mask_need_grad
            ctx.key_mask_need_grad = key_mask_need_grad

        # NOTE(review): the fused and flash paths return the kernel output without
        # the batch_first transpose applied on the generic path below; presumably
        # the kernels handle the layout -- confirm.

        # ---- Path 1: fused_mha kernels (attn_method 5 / 6) ----
        can_use_fused = (seq_len == 64 and head_dim in (64, 96, 128)
                         and not need_weights and allow_tf32)
        if attn_mask is not None and attn_mask.dtype != torch.float32:
            can_use_fused = False
        if key_padding_mask is not None and (
                key_padding_mask.dtype != torch.float32
                or key_padding_mask.shape != (batch_size, seq_len_k)):
            # Fix: this used to compare the mask tensor itself with
            # torch.float32 rather than its dtype, so the dtype was never
            # actually checked.  The shape check mirrors the flash path.
            can_use_fused = False

        if can_use_fused:
            attn_mask_need_grad = False
            key_mask_need_grad = False
            if attn_mask1.numel() > 0:
                if attn_mask1.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
                attn_mask_need_grad = attn_mask1.requires_grad
            if key_padding_mask1.numel() > 0:
                key_mask_need_grad = key_padding_mask1.requires_grad
            attn_mask = contiguous_or_empty(attn_mask)
            key_padding_mask = contiguous_or_empty(key_padding_mask)

            if training:
                save_common_ctx(attn_mask_need_grad, key_mask_need_grad)

            # NOTE(review): the positional ``False`` before ``training`` is
            # interpreted by the kernel binding -- confirm its meaning there.
            if same_qkv:
                if num_heads != 1:
                    in_proj_weight, in_proj_bias = head_major(in_proj_weight, in_proj_bias)
                out_tensors, seeds = fused_mha.mha_forward_same_qkv(
                    query.contiguous(), in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(),
                    key_padding_mask, attn_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT, False, training)
                output, qkv, output2, softmax_sum, softmax_max = out_tensors
                ctx.attn_method = 5
                ctx.seeds = seeds
                ctx.save_for_backward(output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight,
                                      key_padding_mask1, attn_mask1, scaleT)
            else:
                out_tensors, seeds = fused_mha.mha_forward(
                    query.contiguous(), key.contiguous(), value.contiguous(),
                    in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(),
                    key_padding_mask, attn_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT, False, training)
                output, q1, k1, v1, output2, softmax_sum, softmax_max = out_tensors
                ctx.attn_method = 6
                ctx.seeds = seeds
                ctx.save_for_backward(output, query, key, value, q1, k1, v1, softmax_sum, softmax_max,
                                      in_proj_weight, out_proj_weight, key_padding_mask1, attn_mask1, scaleT)
            return output2, torch.ones(1, device=query.device)

        # ---- Path 2: flash-attention kernels (attn_method 3 / 4) ----
        can_use_flash = head_dim <= (128 if allow_tf32 else 64)
        if need_weights or need_softmax_temperature:
            can_use_flash = False
        if key_padding_mask1.numel() > 0 and (
                key_padding_mask1.shape != (batch_size, seq_len_k)
                or key_padding_mask1.requires_grad
                or key_padding_mask1.dtype != torch.float32):
            can_use_flash = False
        if attn_mask1.numel() > 0 and (
                attn_mask1.dtype != torch.float32 or attn_mask1.requires_grad):
            can_use_flash = False

        if can_use_flash:
            run_attn = flash_attn_tf32 if allow_tf32 else flash_attn_fp32
            if attn_mask1.numel() > 0 and attn_mask1.dim() == 2:
                attn_mask1 = attn_mask1.unsqueeze(0)

            if training:
                save_common_ctx(False, False)
                ctx.run_attn = run_attn

            if same_qkv:
                in_proj_weight, in_proj_bias = head_major(in_proj_weight, in_proj_bias)
                out_tensors, seeds = run_attn.mha_forward_same_qkv(
                    query.contiguous(), in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(),
                    key_padding_mask1.contiguous(), attn_mask1.contiguous(),
                    num_heads, dropout, False, training)
                output, qkv, output2, softmax_sum, softmax_max = out_tensors
                ctx.attn_method = 3
                ctx.seeds = seeds
                ctx.save_for_backward(output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight,
                                      key_padding_mask1, attn_mask1)
            else:
                out_tensors, seeds = run_attn.mha_forward(
                    query.contiguous(), key.contiguous(), value.contiguous(),
                    in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(),
                    key_padding_mask1.contiguous(), attn_mask1.contiguous(),
                    num_heads, dropout, False, training)
                output, q1, k1, v1, output2, softmax_sum, softmax_max = out_tensors
                ctx.attn_method = 4
                ctx.seeds = seeds
                ctx.save_for_backward(output, query, key, value, q1, k1, v1, softmax_sum, softmax_max,
                                      in_proj_weight, out_proj_weight, key_padding_mask1, attn_mask1)
            return output2, torch.ones(1, device=query.device)

        # ---- Path 3: generic kernels (attn_method 0 / 1 / 2) ----
        attn_mask_need_grad = False
        key_mask_need_grad = False
        if attn_mask is not None:
            assert attn_mask.dtype in (torch.uint8, torch.bool, torch.float32), \
                'hfai.nn.MultiheadAttention的attn_mask只支持uint8,bool和float32, not {}'.format(attn_mask.dtype)
            if attn_mask.dtype == torch.uint8:
                attn_mask = attn_mask.to(torch.bool)
            if attn_mask.dtype == torch.float32:
                attn_mask_need_grad = attn_mask.requires_grad
            ctx.attn_mask_shape = attn_mask.shape
            if attn_mask.dim() == 2:
                correct_2d_size = (seq_len, seq_len_k)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (batch_size * num_heads, seq_len, seq_len_k)
                correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                if attn_mask.shape != correct_3d_size and attn_mask.shape != correct_3d_size1:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size} or {correct_3d_size1}.")
                if attn_mask.shape == correct_3d_size1:
                    # Per-head mask: repeat it for every batch element.
                    attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
                        .reshape(batch_size * num_heads, seq_len, seq_len_k)
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if key_padding_mask is not None:
            assert key_padding_mask.dtype in (torch.uint8, torch.bool, torch.float32), \
                'hfai.nn.MultiheadAttention的key_padding_mask只支持uint8,bool和float32, not {}'.format(
                    key_padding_mask.dtype)
            assert key_padding_mask.shape == (batch_size, seq_len_k), \
                f"expecting key_padding_mask shape of {(batch_size, seq_len_k)}, but got {key_padding_mask.shape}"
            if key_padding_mask.dtype == torch.uint8:
                key_padding_mask = key_padding_mask.to(torch.bool)
            if key_padding_mask.dtype == torch.float32:
                key_mask_need_grad = key_padding_mask.requires_grad
            # One row per (batch, head): (batch * num_heads, 1, seq_len_k).
            key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
                .reshape(-1, seq_len_k).unsqueeze(1)

        if attn_mask is not None and key_padding_mask is not None:
            assert attn_mask.dtype == key_padding_mask.dtype, "attn_mask 和 key_padding_mask 类型要一致"

        attn_mask = contiguous_or_empty(attn_mask)
        key_padding_mask = contiguous_or_empty(key_padding_mask)

        if compact:
            if training:
                output, attnPre, dropout_mask, attn_softmax, output2 = attention_compact.forward(
                    query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                    in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
                save_common_ctx(attn_mask_need_grad, key_mask_need_grad)
                ctx.attn_method = 2
                ctx.save_for_backward(output, out_proj_weight, out_proj_bias, attnPre, dropout_mask, attn_softmax,
                                      query, key, value, in_proj_weight, in_proj_bias, attn_mask1, key_padding_mask1,
                                      scaleT)
            else:
                output2, attn = attention_compact.inference(
                    query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                    in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
        elif same_qkv:
            in_proj_weight, in_proj_bias = head_major(in_proj_weight, in_proj_bias)
            run_attn = get_attention()
            if training:
                output, qkv, attnPre, dropout_mask, attn_softmax, attn, output2 = run_attn.forward(
                    query.contiguous(), in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(),
                    attn_mask, key_padding_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
                save_common_ctx(attn_mask_need_grad, key_mask_need_grad)
                ctx.run_attn = run_attn
                ctx.attn_method = 0
                ctx.save_for_backward(output, out_proj_weight, qkv, attnPre, dropout_mask, attn_softmax, attn,
                                      query, in_proj_weight, attn_mask1, key_padding_mask1, scaleT)
            else:
                output2, attn = run_attn.inference(
                    query.contiguous(), in_proj_weight.contiguous(), in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
        else:
            run_attn = get_attention()
            if training:
                output, q1, k1, v1, attnPre, dropout_mask, attn_softmax, attn, output2 = run_attn.forward_diff_qkv(
                    query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                    in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)
                save_common_ctx(attn_mask_need_grad, key_mask_need_grad)
                ctx.run_attn = run_attn
                ctx.attn_method = 1
                ctx.save_for_backward(output, out_proj_weight, q1, k1, v1, attnPre, dropout_mask, attn_softmax,
                                      attn, query, key, value, in_proj_weight, attn_mask1, key_padding_mask1, scaleT)
            else:
                output2, attn = run_attn.inference_diff_qkv(
                    query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                    in_proj_bias.contiguous(),
                    out_proj_weight.contiguous(), out_proj_bias.contiguous(), attn_mask, key_padding_mask,
                    num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), index)

        if batch_first:
            output2 = output2.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)

        if not need_weights:
            return output2, torch.ones(1, device=index)

        if compact and training:
            # The compact training kernel does not return ``attn``; rebuild it from
            # ``attn_softmax`` and ``dropout_mask`` with 1 / (1 - p) scaling
            # (assumes attn_softmax is the pre-dropout probabilities -- confirm).
            if dropout > 0.0001:
                attn = torch._masked_scale(attn_softmax, dropout_mask, 1.0 / (1.0 - dropout))
            else:
                attn = attn_softmax
        attn = attn.view(batch_size, num_heads, seq_len, seq_len_k)
        return output2, attn.sum(dim=1) / num_heads

    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output2, grad_weight):
        num_heads = ctx.num_heads
        embed_dim = ctx.embed_dim
        dropout = ctx.dropout
        index = ctx.index
        batch_first = ctx.batch_first
        batch_size = ctx.batch_size
        seq_len = ctx.seq_len
        seq_len_k = ctx.seq_len_k
        need_softmax_temperature = ctx.need_softmax_temperature
        attn_mask_need_grad = ctx.attn_mask_need_grad
        key_mask_need_grad = ctx.key_mask_need_grad
        attn_method = ctx.attn_method

        if batch_first:
            grad_output2 = grad_output2.transpose(0, 1)

        if attn_method == 0:
            output, out_proj_weight, qkv, attnPre, dropout_mask, attn_softmax, attn, inputs, in_proj_weight, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
            if attn_mask.numel() > 0:
                if attn_mask.dtype == torch.uint8:
                    attn_mask = attn_mask.to(torch.bool)
                if attn_mask.dtype == torch.bool:
                    if attn_mask.dim() == 2:
                        attn_mask = attn_mask.unsqueeze(0)
                    elif attn_mask.dim() == 3:
                        correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                        if attn_mask.shape == correct_3d_size1:
                            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
                                .reshape(batch_size * num_heads, seq_len, seq_len_k)
                attn_mask = attn_mask.contiguous()
            if key_padding_mask.numel() > 0:
                if key_padding_mask.dtype == torch.uint8:
                    key_padding_mask = key_padding_mask.to(torch.bool)
                if key_padding_mask.dtype == torch.bool:
                    key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
                        .reshape(-1, seq_len_k).unsqueeze(1).contiguous()

            grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention.backward(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
                dropout_mask.contiguous(), attn_softmax.contiguous(), attn.contiguous(), inputs.contiguous(),
                in_proj_weight.contiguous(),
                attn_mask, key_padding_mask, num_heads, dropout, need_softmax_temperature, scaleT.contiguous(),
                attnPre.contiguous(), attn_mask_need_grad or key_mask_need_grad, index
            )
            grad_in_proj_weight = grad_in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads,
                                                              embed_dim).transpose(0, 1).contiguous().view(
                3 * embed_dim, embed_dim)
            grad_in_proj_bias = grad_in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0,
                                                                                                          1).contiguous().view(
                3 * embed_dim)

            grad_attn_mask = None
            if attn_mask_need_grad:
                attn_mask_shape = ctx.attn_mask_shape
                correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                grad_attn_mask = grad_mask
                if attn_mask_shape == correct_3d_size1:
                    grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)

            if batch_first:
                grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
            if need_softmax_temperature:
                return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
            else:
                return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 1:
            output, out_proj_weight, q1, k1, v1, attnPre, dropout_mask, attn_softmax, attn, query, key, value, in_proj_weight, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
            if attn_mask.numel() > 0:
                if attn_mask.dtype == torch.uint8:
                    attn_mask = attn_mask.to(torch.bool)
                if attn_mask.dtype == torch.bool:
                    if attn_mask.dim() == 2:
                        attn_mask = attn_mask.unsqueeze(0)
                    elif attn_mask.dim() == 3:
                        correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                        if attn_mask.shape == correct_3d_size1:
                            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
                                .reshape(batch_size * num_heads, seq_len, seq_len_k)
                attn_mask = attn_mask.contiguous()
            if key_padding_mask.numel() > 0:
                if key_padding_mask.dtype == torch.uint8:
                    key_padding_mask = key_padding_mask.to(torch.bool)
                if key_padding_mask.dtype == torch.bool:
                    key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
                        .reshape(-1, seq_len_k).unsqueeze(1).contiguous()

            grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention.backward_diff_qkv(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
                k1.contiguous(), v1.contiguous(),
                dropout_mask.contiguous(), attn_softmax.contiguous(), attn.contiguous(), query.contiguous(),
                key.contiguous(), value.contiguous(),
                in_proj_weight.contiguous(), attn_mask, key_padding_mask, num_heads, dropout, need_softmax_temperature,
                scaleT.contiguous(), attnPre.contiguous(), attn_mask_need_grad or key_mask_need_grad, index
            )
            # grad_in_proj_weight = grad_in_proj_weight.reshape(3, num_heads, embed_dim//num_heads, embed_dim).permute(1, 0, 2, 3).contiguous().view(3 * embed_dim, embed_dim)
            # grad_in_proj_bias = grad_in_proj_bias.reshape(3, num_heads, embed_dim//num_heads).permute(1, 0, 2).contiguous().view(3 * embed_dim)

            grad_attn_mask = None
            if attn_mask_need_grad:
                attn_mask_shape = ctx.attn_mask_shape
                correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                grad_attn_mask = grad_mask
                if attn_mask_shape == correct_3d_size1:
                    grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)

            if batch_first:
                grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
                grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
                grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
            if need_softmax_temperature:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
            else:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 2:
            output, out_proj_weight, out_proj_bias, attnPre, dropout_mask, attn_softmax, query, key, value, in_proj_weight, in_proj_bias, attn_mask, key_padding_mask, scaleT = ctx.saved_tensors
            if attn_mask.numel() > 0:
                if attn_mask.dtype == torch.uint8:
                    attn_mask = attn_mask.to(torch.bool)
                if attn_mask.dtype == torch.bool:
                    if attn_mask.dim() == 2:
                        attn_mask = attn_mask.unsqueeze(0)
                    elif attn_mask.dim() == 3:
                        correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                        if attn_mask.shape == correct_3d_size1:
                            attn_mask = attn_mask.unsqueeze(0).expand(batch_size, num_heads, seq_len, seq_len_k) \
                                .reshape(batch_size * num_heads, seq_len, seq_len_k)
                attn_mask = attn_mask.contiguous()
            if key_padding_mask.numel() > 0:
                if key_padding_mask.dtype == torch.uint8:
                    key_padding_mask = key_padding_mask.to(torch.bool)
                if key_padding_mask.dtype == torch.bool:
                    key_padding_mask = key_padding_mask.unsqueeze(1).expand(batch_size, num_heads, seq_len_k) \
                        .reshape(-1, seq_len_k).unsqueeze(1).contiguous()

            grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_scaleT, grad_mask = attention_compact.backward(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(),
                out_proj_bias.contiguous(), dropout_mask.contiguous(), attn_softmax.contiguous(),
                query.contiguous(), key.contiguous(), value.contiguous(), in_proj_weight.contiguous(),
                in_proj_bias.contiguous(), attn_mask, key_padding_mask,
                num_heads, dropout, need_softmax_temperature, scaleT.contiguous(), attnPre.contiguous(),
                attn_mask_need_grad or key_mask_need_grad, index
            )

            grad_attn_mask = None
            if attn_mask_need_grad:
                attn_mask_shape = ctx.attn_mask_shape
                correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                grad_attn_mask = grad_mask
                if attn_mask_shape == correct_3d_size1:
                    grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)

            if batch_first:
                grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
                grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
                grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
            if need_softmax_temperature:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_scaleT, None, None
            else:
                return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 3:
            output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask = ctx.saved_tensors
            run_attn = ctx.run_attn
            seeds = ctx.seeds
            grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias = run_attn.mha_backward_same_qkv(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
                query.contiguous(), in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
                key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout, seeds)

            grad_in_proj_weight = grad_in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads,
                                                              embed_dim).transpose(0, 1).contiguous().view(
                3 * embed_dim, embed_dim)
            grad_in_proj_bias = grad_in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0,
                                                                                                          1).contiguous().view(
                3 * embed_dim)
            if batch_first:
                grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
            return grad_input, None, None, None, None, None, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 4:
            output, query, key, value, q1, k1, v1, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask = ctx.saved_tensors
            run_attn = ctx.run_attn
            seeds = ctx.seeds

            grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias = run_attn.mha_backward(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
                k1.contiguous(), v1.contiguous(), query.contiguous(), key.contiguous(), value.contiguous(),
                in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
                key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout, seeds)
            if batch_first:
                grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
                grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
                grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
            return grad_query, grad_key, grad_value, None, None, None, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, None, None, None

        elif attn_method == 5:
            output, query, qkv, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask, scaleT = ctx.saved_tensors
            seeds = ctx.seeds
            if attn_mask.numel() > 0:
                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
            grad_input, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_mask, grad_scaleT = fused_mha.mha_backward_same_qkv(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), qkv.contiguous(),
                query.contiguous(), in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
                key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout,
                attn_mask_need_grad or key_mask_need_grad,
                need_softmax_temperature, scaleT, seeds)
            if num_heads != 1:
                grad_in_proj_weight = grad_in_proj_weight.reshape(num_heads, 3, embed_dim // num_heads,
                                                                  embed_dim).transpose(0, 1).contiguous().view(
                    3 * embed_dim, embed_dim)
                grad_in_proj_bias = grad_in_proj_bias.reshape(num_heads, 3, embed_dim // num_heads).transpose(0,
                                                                                                              1).contiguous().view(
                    3 * embed_dim)
            grad_attn_mask = None
            if attn_mask_need_grad:
                correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                grad_attn_mask = grad_mask
                if attn_mask.shape == correct_3d_size1:
                    grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
            grad_t = None
            if need_softmax_temperature:
                grad_t = grad_scaleT.reshape(num_heads, 1, 1)
            if batch_first:
                grad_input = grad_input.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
            return grad_input, None, None, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_t, None, None

        elif attn_method == 6:
            output, query, key, value, q1, k1, v1, softmax_sum, softmax_max, in_proj_weight, out_proj_weight, key_padding_mask, attn_mask, scaleT = ctx.saved_tensors
            seeds = ctx.seeds
            if attn_mask.numel() > 0:
                if attn_mask.dim() == 2:
                    attn_mask = attn_mask.unsqueeze(0)
            grad_query, grad_key, grad_value, grad_in_proj_weight, grad_in_proj_bias, grad_out_proj_weight, grad_out_proj_bias, grad_mask, grad_scaleT = fused_mha.mha_backward(
                grad_output2.contiguous(), output.contiguous(), out_proj_weight.contiguous(), q1.contiguous(),
                k1.contiguous(), v1.contiguous(), query.contiguous(), key.contiguous(), value.contiguous(),
                in_proj_weight.contiguous(), softmax_sum.contiguous(), softmax_max.contiguous(),
                key_padding_mask.contiguous(), attn_mask.contiguous(), num_heads, dropout,
                attn_mask_need_grad or key_mask_need_grad,
                need_softmax_temperature, scaleT, seeds)

            grad_attn_mask = None
            if attn_mask_need_grad:
                correct_3d_size1 = (num_heads, seq_len, seq_len_k)
                grad_attn_mask = grad_mask
                if attn_mask.shape == correct_3d_size1:
                    grad_attn_mask = grad_attn_mask.reshape(batch_size, num_heads, seq_len, seq_len_k)
            grad_key_mask = None
            if key_mask_need_grad:
                grad_key_mask = grad_mask.reshape(batch_size, num_heads * seq_len, seq_len_k).sum(dim=1)
            grad_t = None
            if need_softmax_temperature:
                grad_t = grad_scaleT.reshape(num_heads, 1, 1)
            if batch_first:
                grad_query = grad_query.transpose(0, 1).contiguous().view(batch_size, seq_len, embed_dim)
                grad_key = grad_key.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
                grad_value = grad_value.transpose(0, 1).contiguous().view(batch_size, seq_len_k, embed_dim)
            return grad_query, grad_key, grad_value, grad_key_mask, None, grad_attn_mask, None, None, grad_in_proj_weight, grad_in_proj_bias, None, None, None, None, grad_out_proj_weight, grad_out_proj_bias, None, None, None, grad_t, None, None


class MyLinear(nn.Module):
    """Bare parameter holder with the same ``weight``/``bias`` layout as ``nn.Linear``.

    It defines no ``forward``; the owning module reads ``weight`` and ``bias``
    directly. Both tensors are created with ``torch.empty`` (uninitialized), so the
    owner is responsible for initializing them.

    Args:
        in_features: size of each input sample (columns of ``weight``).
        out_features: size of each output sample (rows of ``weight`` / length of ``bias``).
        device: optional device for the parameters.
        dtype: optional dtype for the parameters.
    """

    def __init__(self, in_features, out_features, device=None, dtype=None):
        super().__init__()
        weight_shape = (out_features, in_features)
        self.weight = nn.Parameter(torch.empty(weight_shape, device=device, dtype=dtype))
        self.bias = nn.Parameter(torch.empty(out_features, device=device, dtype=dtype))


class MultiheadAttention(nn.MultiheadAttention):
    """
    A more efficient MultiheadAttention operator.

    Currently supports the case where ``query``, ``key`` and ``value`` share the same
    ``embed_dim`` and ``add_zero_attn=False``; everything else matches
    `PyTorch's MultiheadAttention <https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=multihead#torch.nn.MultiheadAttention>`_.

    The accelerated path is only taken for CUDA ``float32`` queries; any other input is
    handed to ``torch.nn.MultiheadAttention.forward``. The accelerated path
    (``MultiheadAttentionFunc``) asserts that ``bias_k``/``bias_v`` are ``None``, i.e.
    ``add_bias_kv=True`` is not supported there.

    .. note::
        Two additional modes are supported:

        1) Low-precision mode, for extra speed.
           After ``hfai.nn.context.SetAttnAllowConversion(True)``, MultiheadAttention runs
           some matrix multiplications at reduced precision while keeping the error
           bounded, further improving performance.

        2) Low-memory mode, for larger ``batch_size`` or larger models.
           Set the ``compact`` argument of MultiheadAttention to ``True`` (default: ``False``).

    Examples:

    .. code-block:: python

        attn = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1).cuda()
        # low-memory mode
        attn_compact = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1, compact=True).cuda()

        query0 = torch.randn(5, 4, 10).cuda()
        key0 = torch.randn(3, 4, 10).cuda()
        value0 = torch.randn(3, 4, 10).cuda()
        output0 = attn(query0, key0, value0)[0]

        # low-precision mode
        hfai.nn.context.SetAttnAllowConversion(True)
        query1 = torch.randn(5, 4, 10).cuda()
        key1 = torch.randn(3, 4, 10).cuda()
        value1 = torch.randn(3, 4, 10).cuda()
        output1 = attn(query1, key1, value1)[0]
        hfai.nn.context.SetAttnAllowConversion(False)
    """
    bias_k: Optional[torch.Tensor]
    bias_v: Optional[torch.Tensor]

    def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False,
                 kdim=None, vdim=None, batch_first=False, need_softmax_temperature=False, compact=False,
                 device=None, dtype=None):
        # Deliberately skip nn.MultiheadAttention.__init__ (it would build its own
        # parameters) and only initialize nn.Module; all parameters are created below.
        super(nn.MultiheadAttention, self).__init__()
        assert not no_attention and not no_attention_half, '未找到hfai attention,请联系HFAI'
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.batch_first = batch_first
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        self.compact = compact
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if not self._qkv_same_embed_dim:
            self.q_proj_weight = nn.Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs))
            self.k_proj_weight = nn.Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = nn.Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)
        else:
            self.in_proj_weight = nn.Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs))
            self.register_parameter('q_proj_weight', None)
            self.register_parameter('k_proj_weight', None)
            self.register_parameter('v_proj_weight', None)

        if bias:
            self.in_proj_bias = nn.Parameter(torch.empty((3 * embed_dim), **factory_kwargs))
        else:
            self.register_parameter('in_proj_bias', None)
        # NOTE(review): unlike torch.nn.MultiheadAttention, out_proj always has a bias,
        # even when bias=False (in that case it keeps its uniform init because
        # _reset_parameters only zeroes it alongside in_proj_bias) — confirm intended.
        self.out_proj = MyLinear(embed_dim, embed_dim, **factory_kwargs)

        if add_bias_kv:
            self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
            self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
        else:
            self.bias_k = self.bias_v = None

        # scaleT = 1.0 / temperature, one learnable value per head when enabled.
        self.need_softmax_temperature = need_softmax_temperature
        if need_softmax_temperature:
            self.scaleT = nn.Parameter(torch.ones((num_heads, 1, 1), **factory_kwargs))
        else:
            # A non-persistent buffer (instead of a plain tensor attribute) follows the
            # module through .cuda()/.to(), so the kernels never receive it on the wrong
            # device; it is excluded from state_dict, so checkpoints are unaffected.
            self.register_buffer('scaleT', torch.ones(1, **factory_kwargs), persistent=False)

        self.add_zero_attn = add_zero_attn

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize all parameters.

        ``out_proj`` first receives ``nn.Linear``'s default initialization (``MyLinear``
        has none of its own); the rest mirrors ``torch.nn.MultiheadAttention``.
        """
        kaiming_uniform_(self.out_proj.weight, a=math.sqrt(5))
        if self.out_proj.bias is not None:
            fan_in, _ = _calculate_fan_in_and_fan_out(self.out_proj.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            uniform_(self.out_proj.bias, -bound, bound)

        if self._qkv_same_embed_dim:
            xavier_uniform_(self.in_proj_weight)
        else:
            xavier_uniform_(self.q_proj_weight)
            xavier_uniform_(self.k_proj_weight)
            xavier_uniform_(self.v_proj_weight)

        if self.in_proj_bias is not None:
            constant_(self.in_proj_bias, 0.)
            constant_(self.out_proj.bias, 0.)
        if self.bias_k is not None:
            xavier_normal_(self.bias_k)
        if self.bias_v is not None:
            xavier_normal_(self.bias_v)

    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        """Compute multi-head attention.

        CUDA ``float32`` queries go through the fused ``MultiheadAttentionFunc``; any
        other query falls back to ``torch.nn.MultiheadAttention.forward`` with the same
        arguments. Dropout is only applied in training mode.

        Returns:
            Whatever the selected implementation returns (the class examples take
            ``[0]`` as the attention output).
        """
        if not (query.is_cuda and query.dtype == torch.float32):
            return super().forward(query, key, value, key_padding_mask, need_weights, attn_mask)

        return MultiheadAttentionFunc.apply(query, key, value, key_padding_mask, need_weights, attn_mask,
                                            self.embed_dim, self.num_heads, self.in_proj_weight,
                                            self.in_proj_bias, self.bias_k, self.bias_v, self.add_zero_attn,
                                            self.dropout if self.training else 0.0,
                                            self.out_proj.weight, self.out_proj.bias,
                                            self._qkv_same_embed_dim, self.batch_first,
                                            self.need_softmax_temperature, self.scaleT, self.compact,
                                            # The training flag tells the kernel whether
                                            # backward state is needed.
                                            torch.is_grad_enabled() and self.training)

    def inference_init(self, max_seq_len):
        """Placeholder; currently does nothing and returns ``None``."""
        ...

    def inference_increment(self, qkv: Tensor, max_seq_len: int, now_seq_len: int,
                            cache: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
        """
        Efficient incremental attention for inference.

        ``qkv`` is used as query, key and value (self-attention). It is written into
        ``cache`` at positions ``[now_seq_len - seq_len, now_seq_len)``. Keys and values
        are then re-projected from the first ``now_seq_len`` cached positions. Output is
        computed only for the last position of ``qkv``. No masks and no dropout are
        applied on this path.

        Args:
            qkv: input of shape ``(seq_len, batch, embed_dim)``, or
                ``(batch, seq_len, embed_dim)`` when ``batch_first`` is set.
            max_seq_len: capacity of the cache allocated on the first call.
            now_seq_len: total number of filled positions after this call.
            cache: the cache returned by the previous call, or ``None`` on the first call.

        Returns:
            ``(out, cache)``. ``out`` is the attention output for the last position, with
            shape ``(1, batch, embed_dim)`` (``(batch, 1, embed_dim)`` if ``batch_first``).
            ``cache`` is the ``(batch, max_seq_len, embed_dim)`` buffer of raw inputs to
            pass to the next call.

        Examples:

        .. code-block:: python

            mha = hfai.nn.MultiheadAttention(128, 4).cuda()
            x = torch.randn(64, 100, 128).cuda()
            output, cache = mha.inference_increment(x, max_seq_len=128, now_seq_len=64)
            for step in range(64):
                x = torch.randn(1, 100, 128).cuda()
                output, cache = mha.inference_increment(x, max_seq_len=128, now_seq_len=65 + step, cache=cache)
        """
        if not self.batch_first:
            qkv = qkv.transpose(1, 0)

        batch_size, seq_len, embed_dim = qkv.shape
        head_dim = embed_dim // self.num_heads

        if cache is None:
            # The first inference step: allocate storage for all future positions.
            cache = torch.empty(batch_size, max_seq_len, embed_dim, dtype=qkv.dtype, device=qkv.device)
        cache[:, now_seq_len - seq_len:now_seq_len, :] = qkv

        wq, wk, wv = self.in_proj_weight.chunk(3, dim=0)
        # Modules built with bias=False have no in_proj_bias; F.linear accepts bias=None.
        if self.in_proj_bias is None:
            bq = bk = bv = None
        else:
            bq, bk, bv = self.in_proj_bias.chunk(3, dim=0)

        # With a single query position, [B, 1, E] -> [B, H, 1, D] needs no transpose.
        q = F.linear(qkv[:, -1:, :], wq, bq)
        q = q.view(batch_size, self.num_heads, 1, head_dim)
        q = q.contiguous()

        k = F.linear(cache[:, :now_seq_len, :], wk, bk)
        k = k.view(batch_size, now_seq_len, self.num_heads, head_dim).transpose(1, 2)

        attn = q @ k.transpose(-2, -1)
        del q, k  # release intermediates early to lower peak memory
        attn.mul_(1.0 / math.sqrt(head_dim))
        if self.need_softmax_temperature:
            attn.mul_(self.scaleT)
        attn = F.softmax(attn, dim=-1)

        v = F.linear(cache[:, :now_seq_len, :], wv, bv)
        v = v.view(batch_size, now_seq_len, self.num_heads, head_dim).transpose(1, 2)

        out = attn @ v
        del attn, v
        out = out.reshape(batch_size, 1, embed_dim)  # [batch_size, 1, embed_dim]
        out = F.linear(out, self.out_proj.weight, self.out_proj.bias)  # [batch_size, 1, embed_dim]

        if not self.batch_first:
            out = out.transpose(0, 1)

        return out, cache