Source code for hfai.nn.modules.gru.GRU

from typing import Union, Tuple, Optional
import numbers
from torch import nn
from torch.autograd import Function
import torch
from torch.nn.utils.rnn import PackedSequence
from ..dropout import Dropout

# Prefer the locally built hfcuda extensions; fall back to the copies shipped with hfai.
# If neither import succeeds, get_GRU below will fail at call time on the missing module.
try:
    import hfcuda.hf_a100_gru_cuda_onchip_tf as hf_a100_gru_cuda_onchip_tf
except ImportError:
    try:
        import hfai.hfcuda.hf_a100_gru_cuda_onchip_tf as hf_a100_gru_cuda_onchip_tf
    except ImportError:
        pass
try:
    import hfcuda.hf_a100_gru_cuda_offchip as hf_a100_gru_cuda_offchip
except ImportError:
    try:
        import hfai.hfcuda.hf_a100_gru_cuda_offchip as hf_a100_gru_cuda_offchip
    except ImportError:
        pass


def get_GRU(bs, hidden_size, device):
    """Select the CUDA GRU kernel for the given batch size, hidden size and device."""
    # The persistent ("onchip") TF32 kernel requires compute capability 8.0 and at least
    # 108 SMs (a full A100); every other device uses the general offchip kernel.
    if torch.cuda.get_device_capability() != (
            8, 0
    ) or torch.cuda.get_device_properties(device).multi_processor_count < 108:
        return hf_a100_gru_cuda_offchip
    if bs <= 64 and hidden_size <= 1728:
        try:
            return hf_a100_gru_cuda_onchip_tf
        except NameError:
            # The onchip extension failed to import; fall back to the offchip kernel.
            return hf_a100_gru_cuda_offchip
    else:
        return hf_a100_gru_cuda_offchip
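
# Illustrative sketch (not part of the original module): the dispatch rule above in
# isolation. The persistent ("onchip") TF32 kernel is only selected on a device with
# compute capability 8.0 and at least 108 SMs, and only when bs <= 64 and
# hidden_size <= 1728; everything else gets the offchip kernel.
#
#   kernel = get_GRU(bs=32, hidden_size=1024, device=torch.device("cuda:0"))
#   # -> hf_a100_gru_cuda_onchip_tf on a full A100, hf_a100_gru_cuda_offchip otherwise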


class GRUFunction(Function):
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, x, weight_ih, weight_hh, bias_ih, bias_hh, h_0, training):
        # Select the kernel from the batch size (x is seq-first: (seq_len, batch, input))
        # and the hidden size (weight_hh is (3 * hidden_size, hidden_size)).
        ctx.GRU = get_GRU(x.size()[1], weight_hh.size()[1], x.device)
        if training:
            # The training kernel also returns the per-step outputs y and the
            # pre-activation gates, which are saved for the backward pass.
            h_1, y, linear_gates = ctx.GRU.forward(
                x, weight_ih, weight_hh, bias_ih, bias_hh, h_0)
            ctx.save_for_backward(x, weight_ih, weight_hh, h_0, y, linear_gates)
            # y carries one extra leading step; return only the seq_len step outputs.
            return h_1, y.narrow(0, 1, x.size(0))
        else:
            h_1, y = ctx.GRU.forward_infer(
                x, weight_ih, weight_hh, bias_ih, bias_hh, h_0)
            return h_1, y

    @staticmethod
    @torch.cuda.amp.custom_bwd
    @torch.autograd.function.once_differentiable
    def backward(ctx, dh_1, dh_layer):
        dh_1 = dh_1.contiguous()
        dh_layer = dh_layer.contiguous()
        variables = ctx.saved_tensors
        dx, dh_0, dweight_ih, dweight_hh, dbias_ih, dbias_hh = ctx.GRU.backward(
            *variables, dh_layer, dh_1)
        ret = dx, dweight_ih, dweight_hh, dbias_ih, dbias_hh, dh_0, None
        return ret
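

# Illustrative sketch (not part of the original module): how GRUFunction.apply is called
# for one unidirectional layer. Running it requires the hfcuda extensions and a CUDA
# device; the shapes follow the (seq_len, batch, feature) layout used by GRU.forward below.
#
#   x         : (seq_len, batch, input_size)        float32, CUDA
#   weight_ih : (3 * hidden_size, input_size)
#   weight_hh : (3 * hidden_size, hidden_size)
#   bias_ih   : (3 * hidden_size,)
#   bias_hh   : (3 * hidden_size,)
#   h_0       : (batch, hidden_size)
#
#   h_1, y = GRUFunction.apply(x, weight_ih, weight_hh, bias_ih, bias_hh, h_0,
#                              True)  # last argument is the training flag
#   # h_1: (batch, hidden_size) final hidden state; y: (seq_len, batch, hidden_size) outputs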


class GRU(nn.GRU):
    """
    Efficient GRU operator.

    Usage is identical to `PyTorch's GRU operator <https://pytorch.org/docs/stable/generated/torch.nn.GRU.html?highlight=lstm#torch.nn.GRU>`_.

    .. note::

        An extra ``drop_connect`` argument is supported. If ``0 < drop_connect <= 1``,
        a ``Dropout(p=drop_connect)`` layer is applied right after every ``weight_hh``.

    .. note::

        Two precision modes are supported:

        1) TF32 mode (default): matrix multiplications inside the GRU are accelerated with TF32.

           When ``batch_size <= 64 && hidden_size <= 1728``, the GRU additionally uses the
           persistent method for acceleration.

        2) Float32 mode: matrix multiplications inside the GRU use full precision.

           Requires setting ``torch.backends.cuda.matmul.allow_tf32 = False``.

    .. note::

        Performance is best when ``hidden_size`` is a multiple of 64.

    Examples:

    .. code-block:: python

        gru = hfai.nn.GRU(input_size=10, hidden_size=20).cuda()
        input0 = torch.randn(5, 100, 10).cuda()
        output, hn = gru(input0, None)  # TF32 mode, persistent method not used

        input2 = torch.randn(5, 8, 10).cuda()
        output, hn = gru(input2, None)  # TF32 mode, persistent method used

    """

    def __init__(self, *args, **kwargs):
        drop_connect = kwargs.pop('drop_connect', 0)
        if not isinstance(drop_connect, numbers.Number) or not 0 <= drop_connect <= 1 or \
                isinstance(drop_connect, bool):
            raise ValueError("drop_connect should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        super().__init__(*args, **kwargs)
        if self.dropout != 0:
            self.drop = Dropout(self.dropout)
        else:
            self.drop = None
        if drop_connect != 0:
            self.drop_connect = Dropout(drop_connect)
        else:
            self.drop_connect = None

    def forward(self,
                input: Union[torch.Tensor, PackedSequence],
                hx: Optional[torch.Tensor] = None):
        # Fall back to the stock PyTorch implementation on CPU or for packed input.
        if not input.is_cuda:
            return super().forward(input, hx)
        if isinstance(input, PackedSequence):
            return super().forward(input, hx)

        bs = input.size()[0] if self.batch_first else input.size()[1]
        D = 2 if self.bidirectional else 1
        if hx is None:
            h = torch.zeros(self.num_layers * D, bs, self.hidden_size, device=input.device)
        else:
            h = hx
        h = h.contiguous()
        self.check_forward_args(input, h, None)

        # The CUDA kernels expect contiguous (seq_len, batch, feature) input.
        if self.batch_first:
            input = input.transpose(0, 1)
        input = input.contiguous()
        seq = input.size()[0]
        if not self.bias:
            bias = torch.zeros(3 * self.hidden_size, device=input.device)

        h_out = []
        y = input
        for i in range(self.num_layers):
            h1, y1 = GRUFunction.apply(
                y, self.__getattr__(f'weight_ih_l{i}'),
                self.__getattr__(f'weight_hh_l{i}')
                if self.drop_connect is None else self.drop_connect(
                    self.__getattr__(f'weight_hh_l{i}')),
                self.__getattr__(f'bias_ih_l{i}') if self.bias else bias,
                self.__getattr__(f'bias_hh_l{i}') if self.bias else bias,
                h[i * D],
                self.training and torch.is_grad_enabled())
            h_out.append(h1)
            if self.bidirectional:
                # Run the reverse direction on the time-flipped sequence.
                h1_reverse, y1_reverse = GRUFunction.apply(
                    y if seq == 1 else y.flip(0),
                    self.__getattr__(f'weight_ih_l{i}_reverse'),
                    self.__getattr__(f'weight_hh_l{i}_reverse')
                    if self.drop_connect is None else self.drop_connect(
                        self.__getattr__(f'weight_hh_l{i}_reverse')),
                    self.__getattr__(f'bias_ih_l{i}_reverse') if self.bias else bias,
                    self.__getattr__(f'bias_hh_l{i}_reverse') if self.bias else bias,
                    h[i * 2 + 1],
                    self.training and torch.is_grad_enabled())
                h_out.append(h1_reverse)
                y1 = torch.cat(
                    (y1, y1_reverse if seq == 1 else y1_reverse.flip(0)), 2)
            y = y1
            # Inter-layer dropout, as in nn.GRU, is not applied after the last layer.
            if self.drop is not None and i != self.num_layers - 1:
                y = self.drop(y)

        return y.transpose(0, 1) if self.batch_first else y, torch.stack(h_out, 0)
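

# Illustrative usage sketch (not part of the original module). It assumes a CUDA device
# and the hfcuda extensions are available; drop_connect is the extra argument documented
# in the class docstring above.
if __name__ == "__main__":
    gru = GRU(input_size=10, hidden_size=64, num_layers=2, drop_connect=0.2).cuda()
    x = torch.randn(5, 32, 10).cuda()  # (seq_len, batch, input_size); batch_first=False
    output, h_n = gru(x, None)         # output: (5, 32, 64), h_n: (2, 32, 64)
    print(output.shape, h_n.shape)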