
Source code for hfai.nn.modules.lstm.LSTM

from typing import Union, Tuple, Optional
import numbers
from torch import nn
from torch.autograd import Function
import torch
from torch.nn.utils.rnn import PackedSequence
from ..context import context
from ..dropout import Dropout

# Load the CUDA LSTM kernels. Each variant is tried first from a local hfcuda
# build and then from the packaged hfai.hfcuda module; if neither is available,
# the name is simply left undefined and get_LSTM() falls back to another kernel.
try:
    import hfcuda.hf_a100_lstm_cuda_onchip_fp as hf_a100_lstm_cuda_onchip_fp
except ImportError:
    try:
        import hfai.hfcuda.hf_a100_lstm_cuda_onchip_fp as hf_a100_lstm_cuda_onchip_fp
    except ImportError:
        pass
try:
    import hfcuda.hf_a100_lstm_cuda_onchip_tf as hf_a100_lstm_cuda_onchip_tf
except ImportError:
    try:
        import hfai.hfcuda.hf_a100_lstm_cuda_onchip_tf as hf_a100_lstm_cuda_onchip_tf
    except ImportError:
        pass
try:
    import hfcuda.hf_a100_lstm_cuda_onchip_bf16 as hf_a100_lstm_cuda_onchip_bf16
except ImportError:
    try:
        import hfai.hfcuda.hf_a100_lstm_cuda_onchip_bf16 as hf_a100_lstm_cuda_onchip_bf16
    except ImportError:
        pass
try:
    import hfcuda.hf_a100_lstm_cuda_offchip as hf_a100_lstm_cuda_offchip
except ImportError:
    try:
        import hfai.hfcuda.hf_a100_lstm_cuda_offchip as hf_a100_lstm_cuda_offchip
    except ImportError:
        pass

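# Illustrative sketch (assumption, not part of the original module): a failed
# import above simply leaves the corresponding name undefined at module level,
# so kernel availability can be probed before use, for example:
#
#     has_persistent_tf32 = "hf_a100_lstm_cuda_onchip_tf" in globals()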

def get_LSTM(bs, hidden_size, device):
    # Select a CUDA kernel for the current GPU, precision settings and problem
    # size. The persistent ("onchip") kernels require an A100-class GPU
    # (compute capability 8.0 with at least 108 SMs); anything else uses the
    # generic offchip kernel.
    if torch.cuda.get_device_capability() != (
            8, 0
    ) or torch.cuda.get_device_properties(device).multi_processor_count < 108:
        return hf_a100_lstm_cuda_offchip
    elif not torch.backends.cuda.matmul.allow_tf32:
        # Float32 mode: full-precision matmuls, persistent kernel for small batches.
        if bs <= 16 and hidden_size <= 1728:
            try:
                return hf_a100_lstm_cuda_onchip_fp
            except NameError:  # kernel module was not importable
                return hf_a100_lstm_cuda_offchip
        else:
            return hf_a100_lstm_cuda_offchip
    elif context.GetRnnAllowConversion():
        # BFloat16 mode: enabled via hfai.nn.context.SetRnnAllowConversion(True).
        if bs <= 72 and hidden_size <= 1728:
            try:
                return hf_a100_lstm_cuda_onchip_bf16
            except NameError:  # kernel module was not importable
                return hf_a100_lstm_cuda_offchip
        else:
            return hf_a100_lstm_cuda_offchip
    else:
        # TF32 mode (default).
        if bs <= 64 and hidden_size <= 1728:
            try:
                return hf_a100_lstm_cuda_onchip_tf
            except NameError:  # kernel module was not importable
                return hf_a100_lstm_cuda_offchip
        else:
            return hf_a100_lstm_cuda_offchip

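# A hedged usage sketch (assumption, not part of the original source): get_LSTM
# can be called directly to see which kernel a given configuration would pick.
# On an A100 (compute capability (8, 0), >= 108 SMs) with TF32 enabled and
# RnnAllowConversion left off:
#
#     torch.backends.cuda.matmul.allow_tf32 = True
#     kernel = get_LSTM(bs=8, hidden_size=1024, device=torch.device("cuda"))
#     # -> hf_a100_lstm_cuda_onchip_tf (persistent TF32 kernel)
#
#     kernel = get_LSTM(bs=128, hidden_size=1024, device=torch.device("cuda"))
#     # -> hf_a100_lstm_cuda_offchip (batch too large for the persistent kernels)
#
# On any other GPU the offchip kernel is always returned.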

class LSTMFunction(Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0, training):
        # Pick the kernel from the batch size (x is time-major: seq, batch, input)
        # and the hidden size (weight_hh has shape (4 * hidden_size, hidden_size)).
        ctx.LSTM = get_LSTM(x.size()[1], weight_hh.size()[1], x.device)
        if training:
            h_1, c_1, y, cells, linear_gates = ctx.LSTM.forward(
                x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0)
            ctx.save_for_backward(x, weight_ih, weight_hh, h_0, c_0, y, cells, linear_gates)
            # The kernel returns y with one extra leading step; drop it so the
            # output covers exactly seq_len steps.
            return h_1, c_1, y.narrow(0, 1, x.size(0))
        else:
            # Inference path: no intermediate activations are saved.
            h_1, c_1, y = ctx.LSTM.forward_infer(
                x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0)
            return h_1, c_1, y

    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch.autograd.function.once_differentiable
    def backward(ctx, dh_1, dc_1, dh_layer):
        dh_1 = dh_1.contiguous()
        dc_1 = dc_1.contiguous()
        dh_layer = dh_layer.contiguous()
        variables = ctx.saved_tensors
        dx, dh_0, dc_0, dweight_ih, dweight_hh, dbias_ih, dbias_hh = ctx.LSTM.backward(
            *variables, dh_layer, dh_1, dc_1, None)
        # Reorder to match forward's argument order; `training` gets no gradient.
        ret = dx, dweight_ih, dweight_hh, dbias_ih, dbias_hh, dh_0, dc_0, None
        return ret

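# A hedged shape sketch (assumption, not part of the original source): a single
# unidirectional layer can be driven through LSTMFunction directly with weights
# in the usual nn.LSTM layout. Expected shapes:
#
#     x:          (seq_len, batch, input_size)   time-major
#     weight_ih:  (4 * hidden_size, input_size)
#     weight_hh:  (4 * hidden_size, hidden_size)
#     bias_ih/hh: (4 * hidden_size,)
#     h_0, c_0:   (batch, hidden_size)
#
#     ref = nn.LSTM(input_size=16, hidden_size=64).cuda()
#     x = torch.randn(12, 8, 16, device="cuda")
#     h0 = torch.zeros(8, 64, device="cuda")
#     c0 = torch.zeros(8, 64, device="cuda")
#     h1, c1, y = LSTMFunction.apply(
#         x, ref.weight_ih_l0, ref.weight_hh_l0, ref.bias_ih_l0, ref.bias_hh_l0,
#         h0, c0, False)        # training=False -> inference kernel
#     # y: (seq_len, batch, hidden_size); h1, c1: (batch, hidden_size)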

class LSTM(nn.LSTM):
    """
    An efficient LSTM operator.

    Usage is identical to `PyTorch's LSTM operator
    <https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html?highlight=lstm#torch.nn.LSTM>`_;
    the ``proj_size`` parameter is not supported.

    .. note::

        An extra ``drop_connect`` parameter is supported. If ``0 < drop_connect <= 1``,
        a ``Dropout(p=drop_connect)`` layer is applied immediately after every ``weight_hh``.

    .. note::

        Three precision modes are supported:

        1) TF32 mode (default): matrix multiplications inside the LSTM are accelerated
           with TF32. When ``batch_size <= 64 && hidden_size <= 1728``, the LSTM is
           accelerated with the persistent method.

        2) Float32 mode: matrix multiplications inside the LSTM use full precision.
           Requires ``torch.backends.cuda.matmul.allow_tf32 = False``. When
           ``batch_size <= 16 && hidden_size <= 1728``, the LSTM is accelerated with
           the persistent method.

        3) BFloat16 mode: matrix multiplications inside the LSTM are accelerated with
           BFloat16. Requires ``hfai.nn.context.SetRnnAllowConversion(True)`` and
           ``batch_size <= 72 && hidden_size <= 1728``.

    .. note::

        Performance is best when ``hidden_size`` is a multiple of 64.

    Examples:

    .. code-block:: python

        lstm = hfai.nn.LSTM(input_size=10, hidden_size=20).cuda()

        input0 = torch.randn(5, 100, 10).cuda()
        output, (hn, cn) = lstm(input0, None)  # TF32 mode, without the persistent method

        hfai.nn.context.SetRnnAllowConversion(True)
        input1 = torch.randn(5, 64, 10).cuda()
        output, (hn, cn) = lstm(input1, None)  # BFloat16 mode, with the persistent method

        hfai.nn.context.SetRnnAllowConversion(False)
        input2 = torch.randn(5, 8, 10).cuda()
        output, (hn, cn) = lstm(input2, None)  # TF32 mode, with the persistent method

    """

    def __init__(self, *args, **kwargs):
        drop_connect = kwargs.pop('drop_connect', 0)
        if not isinstance(drop_connect, numbers.Number) or not 0 <= drop_connect <= 1 or \
                isinstance(drop_connect, bool):
            raise ValueError("drop_connect should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        super().__init__(*args, **kwargs)
        if self.dropout != 0:
            self.drop = Dropout(self.dropout)
        else:
            self.drop = None
        if drop_connect != 0:
            self.drop_connect = Dropout(drop_connect)
        else:
            self.drop_connect = None

    def forward(self,
                input: Union[torch.Tensor, PackedSequence],
                hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        # CPU tensors and packed sequences fall back to the stock nn.LSTM path.
        if not input.is_cuda:
            return super().forward(input, hx)
        if isinstance(input, PackedSequence):
            return super().forward(input, hx)

        bs = input.size()[0] if self.batch_first else input.size()[1]
        D = 2 if self.bidirectional else 1
        if hx is None:
            h = torch.zeros(self.num_layers * D, bs, self.hidden_size, device=input.device)
            c = torch.zeros(self.num_layers * D, bs, self.hidden_size, device=input.device)
        else:
            h, c = hx
            h = h.contiguous()
            c = c.contiguous()
        self.check_forward_args(input, (h, c), None)

        # The kernels expect time-major input: (seq_len, batch, input_size).
        if self.batch_first:
            input = input.transpose(0, 1)
        input = input.contiguous()
        seq = input.size()[0]
        if not self.bias:
            bias = torch.zeros(4 * self.hidden_size, device=input.device)

        h_out = []
        c_out = []
        y = input
        for i in range(self.num_layers):
            h1, c1, y1 = LSTMFunction.apply(
                y,
                self.__getattr__(f'weight_ih_l{i}'),
                self.__getattr__(f'weight_hh_l{i}')
                if self.drop_connect is None else self.drop_connect(
                    self.__getattr__(f'weight_hh_l{i}')),
                self.__getattr__(f'bias_ih_l{i}') if self.bias else bias,
                self.__getattr__(f'bias_hh_l{i}') if self.bias else bias,
                h[i * D], c[i * D],
                self.training and torch.is_grad_enabled())
            h_out.append(h1)
            c_out.append(c1)
            if self.bidirectional:
                # The reverse direction runs the same kernel on the flipped sequence.
                h1_reverse, c1_reverse, y1_reverse = LSTMFunction.apply(
                    y if seq == 1 else y.flip(0),
                    self.__getattr__(f'weight_ih_l{i}_reverse'),
                    self.__getattr__(f'weight_hh_l{i}_reverse')
                    if self.drop_connect is None else self.drop_connect(
                        self.__getattr__(f'weight_hh_l{i}_reverse')),
                    self.__getattr__(f'bias_ih_l{i}_reverse') if self.bias else bias,
                    self.__getattr__(f'bias_hh_l{i}_reverse') if self.bias else bias,
                    h[i * 2 + 1], c[i * 2 + 1],
                    self.training and torch.is_grad_enabled())
                h_out.append(h1_reverse)
                c_out.append(c1_reverse)
                y1 = torch.cat(
                    (y1, y1_reverse if seq == 1 else y1_reverse.flip(0)), 2)
            y = y1
            # Inter-layer dropout, as in nn.LSTM (not applied after the last layer).
            if self.drop is not None and i != self.num_layers - 1:
                y = self.drop(y)

        return y.transpose(0, 1) if self.batch_first else y, (torch.stack(
            h_out, 0), torch.stack(c_out, 0))
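
# A hedged usage sketch (assumption, not from the original source): drop_connect
# combines with the standard nn.LSTM arguments, and outputs keep the same shapes
# as torch.nn.LSTM.
#
#     lstm = hfai.nn.LSTM(input_size=32, hidden_size=64, num_layers=2,
#                         batch_first=True, bidirectional=True,
#                         drop_connect=0.1).cuda()
#     x = torch.randn(8, 50, 32).cuda()           # (batch, seq_len, input_size)
#     output, (hn, cn) = lstm(x)
#     # output: (8, 50, 2 * 64); hn, cn: (2 * 2, 8, 64)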