Shortcuts

Source code for hfai.nn.modules.rnn

from typing import Union, Tuple, Optional
import numbers
import torch
from torch import nn
from torch.autograd import Function
from torch.nn.utils.rnn import PackedSequence
from .context import context
from .dropout import Dropout

try:
    import hfai.hfcuda.hf_a100_lstm_cuda_onchip_fp as hf_a100_lstm_cuda_onchip_fp
except:
    pass
try:
    import hfai.hfcuda.hf_a100_lstm_cuda_onchip_tf as hf_a100_lstm_cuda_onchip_tf
except:
    pass
try:
    import hfai.hfcuda.hf_a100_lstm_cuda_onchip_tf_small_h as hf_a100_lstm_cuda_onchip_tf_small_h
except:
    pass
try:
    import hfai.hfcuda.hf_a100_lstm_cuda_onchip_bf16 as hf_a100_lstm_cuda_onchip_bf16
except:
    pass
try:
    import hfai.hfcuda.hf_a100_lstm_cuda_offchip as hf_a100_lstm_cuda_offchip
except:
    pass


def get_LSTM(bs, hidden_size, device):
    """Select the hfai LSTM CUDA extension module for a problem size.

    Args:
        bs: batch size of the (sequence-major) input.
        hidden_size: LSTM hidden size.
        device: CUDA device the computation runs on.

    Returns:
        One of the ``hf_a100_lstm_cuda_*`` extension modules. On-chip (persistent)
        kernels are only considered on compute-capability 8.0 devices with at
        least 108 SMs, and only inside the size limits of the active precision
        mode (FP32 when TF32 matmuls are disabled, BF16 when
        ``context.GetRnnAllowConversion()`` is set, TF32 otherwise). In every
        other case -- including when the preferred on-chip extension failed to
        import at module load -- the off-chip kernel is returned.

    Raises:
        NameError: if the off-chip fallback extension itself could not be imported.
    """
    # Query the device the tensors live on, not whatever device is current.
    if (torch.cuda.get_device_capability(device) != (8, 0)
            or torch.cuda.get_device_properties(device).multi_processor_count < 108):
        return hf_a100_lstm_cuda_offchip

    allow_tf32 = torch.backends.cuda.matmul.allow_tf32
    # BF16 conversion is only consulted when TF32 is allowed (same order as before).
    allow_bf16 = allow_tf32 and context.GetRnnAllowConversion()

    try:
        if not allow_tf32:
            if bs <= 16 and hidden_size <= 1728:
                return hf_a100_lstm_cuda_onchip_fp
        elif allow_bf16:
            if bs <= 72 and hidden_size <= 1728:
                return hf_a100_lstm_cuda_onchip_bf16
        elif bs <= 64 and hidden_size <= 1728:
            return hf_a100_lstm_cuda_onchip_tf
        elif bs <= 512 and hidden_size <= 512:
            return hf_a100_lstm_cuda_onchip_tf_small_h
    except NameError:
        # The preferred on-chip extension was not importable (its module-level
        # `import` was skipped), so its name is unbound; use the off-chip kernel.
        pass
    return hf_a100_lstm_cuda_offchip


class LSTMFunction(Function):
    """Autograd wrapper running one direction of one LSTM layer on an hfai CUDA kernel.

    The kernel module is chosen by ``get_LSTM`` from the batch size and hidden
    size, and is kept on ``ctx`` so that ``backward`` uses the same module as
    ``forward``.
    """

    @staticmethod
    def forward(ctx, x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0, training):
        # x.size()[1] is passed as the batch size (callers hand in sequence-major
        # input); weight_hh.size()[1] is passed as hidden_size.
        kernel = get_LSTM(x.size()[1], weight_hh.size()[1], x.device)
        ctx.LSTM = kernel
        inputs = (x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0)

        if not training:
            # Inference kernel: nothing is saved for backward.
            h_last, c_last, outputs = kernel.forward_infer(*inputs)
            return h_last, c_last, outputs

        h_last, c_last, outputs, cells, linear_gates = kernel.forward(*inputs)
        ctx.save_for_backward(x, weight_ih, weight_hh, h_0, c_0, outputs, cells, linear_gates)
        # The training kernel's `outputs` has one extra leading step
        # (presumably h_0); expose only the x.size(0) steps that follow it.
        return h_last, c_last, outputs.narrow(0, 1, x.size(0))

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, dh_1, dc_1, dh_layer):
        # The kernel expects contiguous gradient buffers.
        grad_h_last = dh_1.contiguous()
        grad_c_last = dc_1.contiguous()
        grad_outputs = dh_layer.contiguous()
        saved = ctx.saved_tensors

        (grad_x, grad_h_0, grad_c_0, grad_w_ih, grad_w_hh,
         grad_b_ih, grad_b_hh) = ctx.LSTM.backward(
            *saved, grad_outputs, grad_h_last, grad_c_last, None)

        # One gradient per forward input; `training` is not differentiable.
        return (grad_x, grad_w_ih, grad_w_hh, grad_b_ih, grad_b_hh,
                grad_h_0, grad_c_0, None)


class LSTM(nn.LSTM):
    """
    Efficient LSTM operator.

    Used exactly like `PyTorch's LSTM <https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html>`_.
    The fused kernels handle 3-D float32 CUDA tensors without ``proj_size``. Every
    other input (CPU tensors, other dtypes, ``PackedSequence``, unbatched 2-D
    input, ``proj_size > 0``) is forwarded to ``torch.nn.LSTM.forward``.

    .. note::
        Additionally supports a ``drop_connect`` argument. If ``0 < drop_connect <= 1``,
        a ``Dropout(p=drop_connect)`` is applied to every ``weight_hh`` before it is
        used. ``drop_connect`` only takes effect on the fused CUDA path; the PyTorch
        fallback ignores it.

    .. note::
        The fused path supports 3 precision modes:

        1) TF32 mode (when ``torch.backends.cuda.matmul.allow_tf32`` is ``True``; the
           flag's default depends on the PyTorch version): matrix multiplications
           use TF32. The persistent kernel is used when
           ``batch_size <= 64 && hidden_size <= 1728`` or
           ``batch_size <= 512 && hidden_size <= 512``.
        2) Float32 mode: full-precision matrix multiplications. Requires
           ``torch.backends.cuda.matmul.allow_tf32 = False``. The persistent kernel
           is used when ``batch_size <= 16 && hidden_size <= 1728``.
        3) BFloat16 mode: matrix multiplications use BFloat16. Requires
           ``hfai.nn.context.SetRnnAllowConversion(True)`` (with TF32 allowed) and
           ``batch_size <= 72 && hidden_size <= 1728``.

        Persistent kernels are only considered on compute-capability 8.0 GPUs with
        at least 108 SMs; otherwise an off-chip kernel is used.

    .. note::
        Performance is best when ``hidden_size`` is a multiple of 64.

    Examples:

        .. code-block:: python

            lstm = hfai.nn.LSTM(input_size=10, hidden_size=20).cuda()

            input0 = torch.randn(5, 100, 10).cuda()
            output, (hn, cn) = lstm(input0, None)  # TF32 mode, persistent small-hidden kernel (batch_size <= 512, hidden_size <= 512)

            hfai.nn.context.SetRnnAllowConversion(True)
            input1 = torch.randn(5, 64, 10).cuda()
            output, (hn, cn) = lstm(input1, None)  # BFloat16 mode, persistent method

            hfai.nn.context.SetRnnAllowConversion(False)
            input2 = torch.randn(5, 8, 10).cuda()
            output, (hn, cn) = lstm(input2, None)  # TF32 mode, persistent method
    """

    def __init__(self, *args, **kwargs):
        drop_connect = kwargs.pop('drop_connect', 0)
        # bool is a subclass of int, so it must be rejected explicitly.
        if (not isinstance(drop_connect, numbers.Number)
                or isinstance(drop_connect, bool)
                or not 0 <= drop_connect <= 1):
            raise ValueError("drop_connect should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        super().__init__(*args, **kwargs)
        # Inter-layer dropout, applied between layers on the fused path.
        self.drop = Dropout(self.dropout) if self.dropout != 0 else None
        # nn.Identity (not a lambda) keeps the module picklable for torch.save / multiprocessing.
        self.drop_connect = Dropout(drop_connect) if drop_connect != 0 else nn.Identity()

    def _run_direction(self, x, suffix, h_0, c_0, zero_bias, training):
        """Run one direction of one layer (parameters named ``*_{suffix}``) through LSTMFunction."""
        if self.bias:
            bias_ih = getattr(self, f'bias_ih_{suffix}')
            bias_hh = getattr(self, f'bias_hh_{suffix}')
        else:
            bias_ih = bias_hh = zero_bias
        return LSTMFunction.apply(
            x,
            getattr(self, f'weight_ih_{suffix}'),
            self.drop_connect(getattr(self, f'weight_hh_{suffix}')),
            bias_ih,
            bias_hh,
            h_0,
            c_0,
            training,
        )

    def forward(self, input: Union[torch.Tensor, PackedSequence],
                hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """Run the LSTM.

        Args:
            input: ``(seq_len, batch, input_size)`` tensor (``(batch, seq_len, input_size)``
                when ``batch_first``), or anything ``torch.nn.LSTM`` accepts.
            hx: optional ``(h_0, c_0)``, each ``(num_layers * num_directions, batch, hidden_size)``;
                zeros when omitted.

        Returns:
            ``output, (h_n, c_n)`` laid out as in ``torch.nn.LSTM``.
        """
        if (isinstance(input, PackedSequence) or not input.is_cuda
                or input.dtype != torch.float or input.dim() != 3
                or getattr(self, 'proj_size', 0) > 0):
            return super().forward(input, hx)

        num_directions = 2 if self.bidirectional else 1
        batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            state_shape = (self.num_layers * num_directions, batch_size, self.hidden_size)
            h = input.new_zeros(state_shape)
            c = input.new_zeros(state_shape)
        else:
            h, c = hx
            h = h.contiguous()
            c = c.contiguous()
        self.check_forward_args(input, (h, c), None)

        # The kernels work on sequence-major, contiguous input.
        if self.batch_first:
            input = input.transpose(0, 1)
        input = input.contiguous()
        seq_len = input.size(0)
        zero_bias = None if self.bias else input.new_zeros(4 * self.hidden_size)
        training = self.training and torch.is_grad_enabled()

        h_out, c_out = [], []
        layer_input = input
        for layer in range(self.num_layers):
            idx = layer * num_directions
            h_n, c_n, output = self._run_direction(
                layer_input, f'l{layer}', h[idx], c[idx], zero_bias, training)
            h_out.append(h_n)
            c_out.append(c_n)
            if self.bidirectional:
                # The reverse direction processes time-flipped input; flip its output back.
                reversed_input = layer_input if seq_len == 1 else layer_input.flip(0)
                h_n_rev, c_n_rev, output_rev = self._run_direction(
                    reversed_input, f'l{layer}_reverse', h[idx + 1], c[idx + 1], zero_bias, training)
                h_out.append(h_n_rev)
                c_out.append(c_n_rev)
                output = torch.cat((output, output_rev if seq_len == 1 else output_rev.flip(0)), 2)
            if self.drop is not None and layer != self.num_layers - 1:
                output = self.drop(output)
            layer_input = output

        if self.batch_first:
            layer_input = layer_input.transpose(0, 1)
        return layer_input, (torch.stack(h_out, 0), torch.stack(c_out, 0))
class LSTM_fullcFunction(Function):
    """Autograd wrapper running one LSTM direction/layer and also returning per-step cell states.

    Unlike ``LSTMFunction``, inference also uses the kernel's training ``forward``,
    because that is the call that produces ``cells``. Tensors are only saved for
    backward when ``training`` is true.
    """

    @staticmethod
    def forward(ctx, x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0, training):
        kernel = get_LSTM(x.size()[1], weight_hh.size()[1], x.device)
        ctx.LSTM = kernel
        h_last, c_last, outputs, cells, linear_gates = kernel.forward(
            x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, c_0)
        if training:
            ctx.save_for_backward(x, weight_ih, weight_hh, h_0, c_0, outputs, cells, linear_gates)
        # `outputs` carries one extra leading step, which is dropped here.
        # NOTE(review): `cells` is returned un-narrowed; confirm whether it also
        # includes a leading c_0 step.
        return h_last, c_last, outputs.narrow(0, 1, x.size(0)), cells

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, dh_1, dc_1, dh_layer, dcells):
        grad_h_last = dh_1.contiguous()
        grad_c_last = dc_1.contiguous()
        grad_outputs = dh_layer.contiguous()
        grad_cells = dcells.contiguous()
        saved = ctx.saved_tensors

        (grad_x, grad_h_0, grad_c_0, grad_w_ih, grad_w_hh,
         grad_b_ih, grad_b_hh) = ctx.LSTM.backward(
            *saved, grad_outputs, grad_h_last, grad_c_last, grad_cells)

        # One gradient per forward input; `training` is not differentiable.
        return (grad_x, grad_w_ih, grad_w_hh, grad_b_ih, grad_b_hh,
                grad_h_0, grad_c_0, None)
class LSTM_fullc(nn.LSTM):
    """
    Efficient LSTM operator that also returns the full sequence of cell states ``c``.

    Model parameters and inputs match `PyTorch's LSTM <https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html>`_.
    The fused kernels handle 3-D float32 CUDA tensors without ``proj_size``. Every
    other input (CPU tensors, other dtypes, ``PackedSequence``, unbatched 2-D
    input, ``proj_size > 0``) is forwarded to ``torch.nn.LSTM.forward``, and
    ``full_c`` is then ``None``.

    Outputs: output, (h_n, c_n), full_c
        * **output**: same as PyTorch
        * **h_n**: same as PyTorch
        * **c_n**: same as PyTorch
        * **full_c**: :math:`(seq\\_len, D * num\\_layers, batch\\_size, hidden\\_size)`, the cell
          states of all seq_len steps (``None`` on the PyTorch fallback path). For
          reverse directions the steps are in processing (time-reversed) order.

    .. note::
        Additionally supports a ``drop_connect`` argument. If ``0 < drop_connect <= 1``,
        a ``Dropout(p=drop_connect)`` is applied to every ``weight_hh`` before it is
        used (fused CUDA path only).

    .. note::
        The fused path supports 3 precision modes:

        1) TF32 mode (when ``torch.backends.cuda.matmul.allow_tf32`` is ``True``; the
           flag's default depends on the PyTorch version): matrix multiplications
           use TF32. The persistent kernel is used when
           ``batch_size <= 64 && hidden_size <= 1728`` or
           ``batch_size <= 512 && hidden_size <= 512``.
        2) Float32 mode: full-precision matrix multiplications. Requires
           ``torch.backends.cuda.matmul.allow_tf32 = False``. The persistent kernel
           is used when ``batch_size <= 16 && hidden_size <= 1728``.
        3) BFloat16 mode: matrix multiplications use BFloat16. Requires
           ``hfai.nn.context.SetRnnAllowConversion(True)`` (with TF32 allowed) and
           ``batch_size <= 72 && hidden_size <= 1728``.

        Persistent kernels are only considered on compute-capability 8.0 GPUs with
        at least 108 SMs; otherwise an off-chip kernel is used.

    .. note::
        Performance is best when ``hidden_size`` is a multiple of 64.

    Examples:

        .. code-block:: python

            lstm_fullc = hfai.nn.LSTM_fullc(input_size=10, hidden_size=20).cuda()

            input0 = torch.randn(5, 100, 10).cuda()
            output, (hn, cn), full_c = lstm_fullc(input0, None)  # TF32 mode, persistent small-hidden kernel (batch_size <= 512, hidden_size <= 512)

            hfai.nn.context.SetRnnAllowConversion(True)
            input1 = torch.randn(5, 64, 10).cuda()
            output, (hn, cn), full_c = lstm_fullc(input1, None)  # BFloat16 mode, persistent method

            hfai.nn.context.SetRnnAllowConversion(False)
            input2 = torch.randn(5, 8, 10).cuda()
            output, (hn, cn), full_c = lstm_fullc(input2, None)  # TF32 mode, persistent method
    """

    def __init__(self, *args, **kwargs):
        drop_connect = kwargs.pop('drop_connect', 0)
        # bool is a subclass of int, so it must be rejected explicitly.
        if (not isinstance(drop_connect, numbers.Number)
                or isinstance(drop_connect, bool)
                or not 0 <= drop_connect <= 1):
            raise ValueError("drop_connect should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        super().__init__(*args, **kwargs)
        # Inter-layer dropout, applied between layers on the fused path.
        self.drop = Dropout(self.dropout) if self.dropout != 0 else None
        # nn.Identity (not a lambda) keeps the module picklable for torch.save / multiprocessing.
        self.drop_connect = Dropout(drop_connect) if drop_connect != 0 else nn.Identity()

    def _run_direction(self, x, suffix, h_0, c_0, zero_bias, training):
        """Run one direction of one layer (parameters named ``*_{suffix}``) through LSTM_fullcFunction."""
        if self.bias:
            bias_ih = getattr(self, f'bias_ih_{suffix}')
            bias_hh = getattr(self, f'bias_hh_{suffix}')
        else:
            bias_ih = bias_hh = zero_bias
        return LSTM_fullcFunction.apply(
            x,
            getattr(self, f'weight_ih_{suffix}'),
            self.drop_connect(getattr(self, f'weight_hh_{suffix}')),
            bias_ih,
            bias_hh,
            h_0,
            c_0,
            training,
        )

    def forward(self, input: Union[torch.Tensor, PackedSequence],
                hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """Run the LSTM and collect the per-step cell states.

        Args:
            input: ``(seq_len, batch, input_size)`` tensor (``(batch, seq_len, input_size)``
                when ``batch_first``), or anything ``torch.nn.LSTM`` accepts.
            hx: optional ``(h_0, c_0)``, each ``(num_layers * num_directions, batch, hidden_size)``;
                zeros when omitted.

        Returns:
            ``output, (h_n, c_n), full_c``; ``full_c`` is ``None`` when the input is
            handled by the ``torch.nn.LSTM`` fallback.
        """
        if (isinstance(input, PackedSequence) or not input.is_cuda
                or input.dtype != torch.float or input.dim() != 3
                or getattr(self, 'proj_size', 0) > 0):
            # Keep the documented 3-tuple shape; PyTorch does not expose per-step cells.
            output, state = super().forward(input, hx)
            return output, state, None

        num_directions = 2 if self.bidirectional else 1
        batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            state_shape = (self.num_layers * num_directions, batch_size, self.hidden_size)
            h = input.new_zeros(state_shape)
            c = input.new_zeros(state_shape)
        else:
            h, c = hx
            h = h.contiguous()
            c = c.contiguous()
        self.check_forward_args(input, (h, c), None)

        # The kernels work on sequence-major, contiguous input.
        if self.batch_first:
            input = input.transpose(0, 1)
        input = input.contiguous()
        seq_len = input.size(0)
        zero_bias = None if self.bias else input.new_zeros(4 * self.hidden_size)
        training = self.training and torch.is_grad_enabled()

        h_out, c_out, cells_out = [], [], []
        layer_input = input
        for layer in range(self.num_layers):
            idx = layer * num_directions
            h_n, c_n, output, cells = self._run_direction(
                layer_input, f'l{layer}', h[idx], c[idx], zero_bias, training)
            h_out.append(h_n)
            c_out.append(c_n)
            cells_out.append(cells)
            if self.bidirectional:
                reversed_input = layer_input if seq_len == 1 else layer_input.flip(0)
                h_n_rev, c_n_rev, output_rev, cells_rev = self._run_direction(
                    reversed_input, f'l{layer}_reverse', h[idx + 1], c[idx + 1], zero_bias, training)
                h_out.append(h_n_rev)
                c_out.append(c_n_rev)
                # Reverse cells stay in processing order; only the output is flipped back.
                cells_out.append(cells_rev)
                output = torch.cat((output, output_rev if seq_len == 1 else output_rev.flip(0)), 2)
            if self.drop is not None and layer != self.num_layers - 1:
                output = self.drop(output)
            layer_input = output

        if self.batch_first:
            layer_input = layer_input.transpose(0, 1)
        return (layer_input,
                (torch.stack(h_out, 0), torch.stack(c_out, 0)),
                torch.stack(cells_out, 1))
# Optional fused GRU kernels: each import is best effort, so a missing or
# unbuilt extension only leaves its module name unbound.
try:
    import hfai.hfcuda.hf_a100_gru_cuda_onchip_tf as hf_a100_gru_cuda_onchip_tf
except Exception:
    pass
try:
    import hfai.hfcuda.hf_a100_gru_cuda_offchip as hf_a100_gru_cuda_offchip
except Exception:
    pass


def get_GRU(bs, hidden_size, device):
    """Select the hfai GRU CUDA extension module for a problem size.

    Args:
        bs: batch size of the (sequence-major) input.
        hidden_size: GRU hidden size.
        device: CUDA device the computation runs on.

    Returns:
        The on-chip (persistent) TF32 kernel when running on a compute-capability
        8.0 device with at least 108 SMs, TF32 matmuls are allowed,
        ``bs <= 64`` and ``hidden_size <= 1728``, and that extension was
        importable; otherwise the off-chip kernel.

    Raises:
        NameError: if the off-chip fallback extension could not be imported.
    """
    # Query the device the tensors live on, not whatever device is current.
    if (torch.cuda.get_device_capability(device) != (8, 0)
            or torch.cuda.get_device_properties(device).multi_processor_count < 108):
        return hf_a100_gru_cuda_offchip
    if torch.backends.cuda.matmul.allow_tf32 and bs <= 64 and hidden_size <= 1728:
        try:
            return hf_a100_gru_cuda_onchip_tf
        except NameError:
            # The on-chip extension failed to import at module load.
            pass
    return hf_a100_gru_cuda_offchip


class GRUFunction(Function):
    """Autograd wrapper running one direction of one GRU layer on an hfai CUDA kernel.

    The kernel module chosen by ``get_GRU`` is kept on ``ctx`` so that
    ``backward`` uses the same module as ``forward``.
    """

    @staticmethod
    def forward(ctx, x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, training):
        # x.size()[1] is passed as the batch size (sequence-major input);
        # weight_hh.size()[1] is passed as hidden_size.
        ctx.GRU = get_GRU(x.size()[1], weight_hh.size()[1], x.device)
        if training:
            h_1, y, linear_gates = ctx.GRU.forward(
                x, weight_ih, weight_hh, bias_ih, bias_hh, h_0)
            ctx.save_for_backward(x, weight_ih, weight_hh, h_0, y, linear_gates)
            # The training kernel's `y` has one extra leading step; drop it.
            return h_1, y.narrow(0, 1, x.size(0))
        h_1, y = ctx.GRU.forward_infer(
            x, weight_ih, weight_hh, bias_ih, bias_hh, h_0)
        return h_1, y

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, dh_1, dh_layer):
        dh_1 = dh_1.contiguous()
        dh_layer = dh_layer.contiguous()
        variables = ctx.saved_tensors
        dx, dh_0, dweight_ih, dweight_hh, dbias_ih, dbias_hh = ctx.GRU.backward(
            *variables, dh_layer, dh_1)
        # One gradient per forward input; `training` is not differentiable.
        return dx, dweight_ih, dweight_hh, dbias_ih, dbias_hh, dh_0, None
class GRU(nn.GRU):
    """
    Efficient GRU operator.

    Used exactly like `PyTorch's GRU <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html>`_.
    The fused kernels handle 3-D float32 CUDA tensors. Every other input (CPU
    tensors, other dtypes, ``PackedSequence``, unbatched 2-D input) is forwarded
    to ``torch.nn.GRU.forward``.

    .. note::
        Additionally supports a ``drop_connect`` argument. If ``0 < drop_connect <= 1``,
        a ``Dropout(p=drop_connect)`` is applied to every ``weight_hh`` before it is
        used. ``drop_connect`` only takes effect on the fused CUDA path; the PyTorch
        fallback ignores it.

    .. note::
        The fused path supports 2 precision modes:

        1) TF32 mode (when ``torch.backends.cuda.matmul.allow_tf32`` is ``True``; the
           flag's default depends on the PyTorch version): matrix multiplications
           use TF32. The persistent kernel is used when
           ``batch_size <= 64 && hidden_size <= 1728``.
        2) Float32 mode: full-precision matrix multiplications. Requires
           ``torch.backends.cuda.matmul.allow_tf32 = False``. Always uses the
           off-chip kernel.

        Persistent kernels are only considered on compute-capability 8.0 GPUs with
        at least 108 SMs.

    .. note::
        Performance is best when ``hidden_size`` is a multiple of 64.

    Examples:

        .. code-block:: python

            gru = hfai.nn.GRU(input_size=10, hidden_size=20).cuda()

            input0 = torch.randn(5, 100, 10).cuda()
            output, hn = gru(input0, None)  # TF32 mode, persistent method not used

            input2 = torch.randn(5, 8, 10).cuda()
            output, hn = gru(input2, None)  # TF32 mode, persistent method
    """

    def __init__(self, *args, **kwargs):
        drop_connect = kwargs.pop('drop_connect', 0)
        # bool is a subclass of int, so it must be rejected explicitly.
        if (not isinstance(drop_connect, numbers.Number)
                or isinstance(drop_connect, bool)
                or not 0 <= drop_connect <= 1):
            raise ValueError("drop_connect should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        super().__init__(*args, **kwargs)
        # Inter-layer dropout, applied between layers on the fused path.
        self.drop = Dropout(self.dropout) if self.dropout != 0 else None
        # nn.Identity (not a lambda) keeps the module picklable for torch.save / multiprocessing.
        self.drop_connect = Dropout(drop_connect) if drop_connect != 0 else nn.Identity()

    def _run_direction(self, x, suffix, h_0, zero_bias, training):
        """Run one direction of one layer (parameters named ``*_{suffix}``) through GRUFunction."""
        if self.bias:
            bias_ih = getattr(self, f'bias_ih_{suffix}')
            bias_hh = getattr(self, f'bias_hh_{suffix}')
        else:
            bias_ih = bias_hh = zero_bias
        return GRUFunction.apply(
            x,
            getattr(self, f'weight_ih_{suffix}'),
            self.drop_connect(getattr(self, f'weight_hh_{suffix}')),
            bias_ih,
            bias_hh,
            h_0,
            training,
        )

    def forward(self, input: Union[torch.Tensor, PackedSequence], hx: Optional[torch.Tensor] = None):
        """Run the GRU.

        Args:
            input: ``(seq_len, batch, input_size)`` tensor (``(batch, seq_len, input_size)``
                when ``batch_first``), or anything ``torch.nn.GRU`` accepts.
            hx: optional ``h_0`` of shape ``(num_layers * num_directions, batch, hidden_size)``;
                zeros when omitted.

        Returns:
            ``output, h_n`` laid out as in ``torch.nn.GRU``.
        """
        if (isinstance(input, PackedSequence) or not input.is_cuda
                or input.dtype != torch.float or input.dim() != 3):
            return super().forward(input, hx)

        num_directions = 2 if self.bidirectional else 1
        batch_size = input.size(0) if self.batch_first else input.size(1)
        if hx is None:
            h = input.new_zeros((self.num_layers * num_directions, batch_size, self.hidden_size))
        else:
            h = hx.contiguous()
        self.check_forward_args(input, h, None)

        # The kernels work on sequence-major, contiguous input.
        if self.batch_first:
            input = input.transpose(0, 1)
        input = input.contiguous()
        seq_len = input.size(0)
        zero_bias = None if self.bias else input.new_zeros(3 * self.hidden_size)
        training = self.training and torch.is_grad_enabled()

        h_out = []
        layer_input = input
        for layer in range(self.num_layers):
            idx = layer * num_directions
            h_n, output = self._run_direction(layer_input, f'l{layer}', h[idx], zero_bias, training)
            h_out.append(h_n)
            if self.bidirectional:
                # The reverse direction processes time-flipped input; flip its output back.
                reversed_input = layer_input if seq_len == 1 else layer_input.flip(0)
                h_n_rev, output_rev = self._run_direction(
                    reversed_input, f'l{layer}_reverse', h[idx + 1], zero_bias, training)
                h_out.append(h_n_rev)
                output = torch.cat((output, output_rev if seq_len == 1 else output_rev.flip(0)), 2)
            if self.drop is not None and layer != self.num_layers - 1:
                output = self.drop(output)
            layer_input = output

        if self.batch_first:
            layer_input = layer_input.transpose(0, 1)
        return layer_input, torch.stack(h_out, 0)