Shortcuts

Source code for hfai.nn.modules.convert

from .fast_multihead_attention import *
from .hf_layernorm import *
from .lstm import *
from .gru import *
from .dropout import *
from .activation import *
import torch
import copy

# Conversion tables consumed by ``_to_hfai`` / ``to_hfai`` and
# ``_to_torch`` / ``to_torch``.  Each entry is ``(source_type, converter)``;
# the callers apply ``converter`` only when ``type(module)`` is exactly
# ``source_type`` (subclasses are deliberately not matched).

# torch.nn module type -> converter producing the hfai counterpart.
hfai_list = [
    # attention / normalization / recurrent layers
    (torch.nn.MultiheadAttention, MultiheadAttention_to_hfai),
    (torch.nn.LayerNorm, LayerNorm_to_hfai),
    (torch.nn.LSTM, LSTM_to_hfai),
    (torch.nn.GRU, GRU_to_hfai),
    # regularization
    (torch.nn.Dropout, Dropout_to_hfai),
    # activations
    (torch.nn.ReLU, ReLU_to_hfai),
    (torch.nn.Hardtanh, Hardtanh_to_hfai),
    (torch.nn.ReLU6, ReLU6_to_hfai),
    (torch.nn.Softplus, Softplus_to_hfai),
    (torch.nn.Softmin, Softmin_to_hfai),
    (torch.nn.Softmax, Softmax_to_hfai),
    (torch.nn.Softmax2d, Softmax2d_to_hfai),
    (torch.nn.LogSoftmax, LogSoftmax_to_hfai),
    (torch.nn.Hardshrink, Hardshrink_to_hfai),
    (torch.nn.Softshrink, Softshrink_to_hfai),
    (torch.nn.Threshold, Threshold_to_hfai),
    (torch.nn.RReLU, RReLU_to_hfai),
    (torch.nn.LeakyReLU, LeakyReLU_to_hfai),
    (torch.nn.Hardsigmoid, Hardsigmoid_to_hfai),
]

# hfai module type (brought in by the star imports above) -> converter
# producing the plain torch.nn counterpart.  Same order as ``hfai_list``.
torch_list = [
    # attention / normalization / recurrent layers
    (MultiheadAttention, MultiheadAttention_to_torch),
    (LayerNorm, LayerNorm_to_torch),
    (LSTM, LSTM_to_torch),
    (GRU, GRU_to_torch),
    # regularization
    (Dropout, Dropout_to_torch),
    # activations
    (ReLU, ReLU_to_torch),
    (Hardtanh, Hardtanh_to_torch),
    (ReLU6, ReLU6_to_torch),
    (Softplus, Softplus_to_torch),
    (Softmin, Softmin_to_torch),
    (Softmax, Softmax_to_torch),
    (Softmax2d, Softmax2d_to_torch),
    (LogSoftmax, LogSoftmax_to_torch),
    (Hardshrink, Hardshrink_to_torch),
    (Softshrink, Softshrink_to_torch),
    (Threshold, Threshold_to_torch),
    (RReLU, RReLU_to_torch),
    (LeakyReLU, LeakyReLU_to_torch),
    (Hardsigmoid, Hardsigmoid_to_torch),
]


def _to_hfai(model, verbose, prefix, ignore):
    """Recursively replace the children of ``model`` with hfai equivalents.

    Every direct child whose exact type appears in ``hfai_list`` (and is not
    listed in ``ignore``) is passed to the matching converter and the result is
    re-registered under the same name.  Children that are not converted are
    recursed into.  Conversion is best-effort: if a converter (or the
    re-registration) raises, the original child is kept.

    Args:
        model (nn.Module): module whose children are converted in place.
        verbose (bool): print a line for every child that was replaced.
        prefix (str): dotted path of ``model``; only used in verbose output.
        ignore: container of torch module types that must not be converted.

    Returns:
        nn.Module: ``model`` itself, modified in place.
    """
    # Snapshot the children: the loop re-registers entries of the very
    # module dict that ``named_children`` iterates over.
    for name, module in list(model.named_children()):
        converted = False
        for torch_type, converter in hfai_list:
            # Exact type match on purpose: subclasses of torch modules are
            # not converted (documented in ``to_hfai``).
            if type(module) != torch_type or torch_type in ignore:
                continue
            converted = True
            try:
                new_module = converter(module)
                model.add_module(name, new_module)
            except Exception as err:
                # Best-effort: the original child stays registered unchanged.
                if verbose:
                    print(f'{prefix}.{name} failed to convert to hfai, '
                          f'keeping original module: {err!r}')
            else:
                if verbose and new_module is not module:
                    print(f'{prefix}.{name} convert to hfai! type:{torch_type.__name__}')
            # Types in ``hfai_list`` are unique; no need to keep scanning.
            break

        if not converted:
            model.add_module(name,
                             _to_hfai(module, verbose, prefix + '.' + name, ignore))

    return model


def _make_params_contiguous(model):
    """Pack all parameters of ``model`` into one flat buffer (plus a flat grad buffer).

    After the call each parameter's ``.data`` and ``.grad`` are views into
    ``model.param_buffer`` / ``model.grad_buffer``, and
    ``model.contiguous_param`` is ``[model.param_buffer]``.

    Raises:
        ValueError: if the model has no parameters, or its parameters do not
            all share the same dtype and device (packing them into one buffer
            would silently cast / move some of them).
    """
    params = list(model.parameters())
    if not params:
        raise ValueError("contiguous_param=True requires a model with parameters")
    p_type = params[0].dtype
    p_device = params[0].device
    if any(p.dtype != p_type or p.device != p_device for p in params):
        raise ValueError(
            "contiguous_param=True requires all parameters to share one dtype and device")

    size = sum(p.numel() for p in params)
    model.param_buffer = torch.zeros(size, dtype=p_type, device=p_device)
    model.grad_buffer = torch.zeros(size, dtype=p_type, device=p_device)
    index = 0
    for p in params:
        size_p = p.numel()
        # ``reshape`` also handles non-contiguous parameters, where ``view`` raises.
        model.param_buffer[index:index + size_p] = p.data.reshape(-1)
        p.data = model.param_buffer[index:index + size_p].view(p.data.shape)
        p.grad = model.grad_buffer[index:index + size_p].view(p.data.shape)
        index += size_p
    model.param_buffer.grad = model.grad_buffer
    model.contiguous_param = [model.param_buffer]


def to_hfai(model, contiguous_param=False, verbose=False, inplace=False, ignore=None):
    """Replace torch operators in ``model`` with hfai-optimized operators.

    Args:
        model (nn.Module): the model to convert.
        contiguous_param (bool): whether to pack the model's parameters into a
            single contiguous buffer to speed up the optimizer; some cases are
            not supported yet (default ``False``).
        verbose (bool): whether to print every layer that was replaced
            (default ``False``).
        inplace (bool): whether to convert in place; when ``False`` the model
            is deep-copied first (default ``False``).
        ignore (list/tuple/set of types): torch layer types that must not be
            converted (default ``None``, i.e. nothing is ignored).

    Returns:
        model (nn.Module): the model with hfai operators, restored to the
        train/eval mode ``model`` had on entry.

    Raises:
        ValueError: if ``model`` itself is a single convertible layer and
            ``inplace=True``; or, with ``contiguous_param=True``, if the model
            has no parameters or its parameters differ in dtype/device.

    .. note::
        Subclasses of torch modules are not converted. For example, given
        ``class M(torch.nn.LSTM)``, an instance ``m = M(...)`` is not
        converted into an instance of ``hfai.nn.LSTM``.

    Examples:

    .. code-block:: python

        from hfai.nn import to_hfai
        torch_model = Model(...)
        hfai_model = to_hfai(torch_model, contiguous_param=False, verbose=False, inplace=False, ignore=[torch.nn.Dropout])

    """
    ignore = () if ignore is None else ignore
    training_type = model.training
    model_copy = model if inplace else copy.deepcopy(model)
    prefix = 'Model'

    # ``model`` may itself be a single convertible layer (e.g. a bare LSTM).
    one_layer = False
    for torch_type, converter in hfai_list:
        # Exact type match on purpose (see the note above).
        if type(model_copy) != torch_type or torch_type in ignore:
            continue
        if inplace:
            raise ValueError(
                "one layer module can't be converted when inplace=True,please use inplace=False"
            )
        one_layer = True
        try:
            converted = converter(model_copy)
        except Exception as err:
            # Best-effort: keep the (copied) torch layer unchanged.
            if verbose:
                print(f'{prefix} failed to convert to hfai, keeping original module: {err!r}')
        else:
            # Compare with the pre-conversion object: ``model_copy`` is a
            # deepcopy, so comparing with ``model`` would always differ.
            if verbose and converted is not model_copy:
                print(f'{prefix} convert to hfai! type:{torch_type.__name__}')
            model_copy = converted
        break

    if not one_layer:
        model_copy = _to_hfai(model_copy, verbose, prefix, ignore)

    if contiguous_param:
        _make_params_contiguous(model_copy)

    # Newly created hfai modules may not carry the original train/eval mode.
    model_copy.train(training_type)
    return model_copy
def _to_torch(model, verbose, prefix):
    """Recursively replace the hfai children of ``model`` with torch equivalents.

    Every direct child whose exact type appears in ``torch_list`` is passed to
    the matching converter and the result is re-registered under the same
    name.  Children that are not converted are recursed into.  Conversion is
    best-effort: if a converter (or the re-registration) raises, the original
    child is kept.

    Args:
        model (nn.Module): module whose children are converted in place.
        verbose (bool): print a line for every child that was replaced.
        prefix (str): dotted path of ``model``; only used in verbose output.

    Returns:
        nn.Module: ``model`` itself, modified in place.
    """
    # Snapshot the children: the loop re-registers entries of the very
    # module dict that ``named_children`` iterates over.
    for name, module in list(model.named_children()):
        converted = False
        for hfai_type, converter in torch_list:
            if type(module) != hfai_type:
                continue
            converted = True
            try:
                new_module = converter(module)
                model.add_module(name, new_module)
            except Exception as err:
                # Best-effort: the original child stays registered unchanged.
                if verbose:
                    print(f'{prefix}.{name} failed to convert to torch, '
                          f'keeping original module: {err!r}')
            else:
                if verbose and new_module is not module:
                    print(f'{prefix}.{name} convert to torch! type:{hfai_type.__name__}')
            # Types in ``torch_list`` are unique; no need to keep scanning.
            break

        if not converted:
            model.add_module(name, _to_torch(module, verbose, prefix + '.' + name))

    return model
def to_torch(model, verbose=False):
    """Replace hfai operators in ``model`` with plain torch operators.

    ``model`` is deep-copied first and the conversion is applied to the copy.

    Args:
        model (nn.Module): the model to convert.
        verbose (bool): whether to print every layer that was replaced
            (default ``False``).

    Returns:
        model (nn.Module): the model with torch operators, in the same
        train/eval mode as ``model``.

    Examples:

    .. code-block:: python

        from hfai.nn import to_torch
        torch_model = to_torch(hfai_model, verbose=False)

    """
    model_copy = copy.deepcopy(model)
    prefix = 'Model'

    # ``model`` may itself be a single convertible layer.
    for hfai_type, converter in torch_list:
        if type(model_copy) != hfai_type:
            continue
        try:
            converted = converter(model_copy)
        except Exception as err:
            # Best-effort: keep the (copied) hfai layer unchanged.
            if verbose:
                print(f'{prefix} failed to convert to torch, keeping original module: {err!r}')
        else:
            # Compare with the pre-conversion object: ``model_copy`` is a
            # deepcopy, so comparing with ``model`` would always differ.
            if verbose and converted is not model_copy:
                print(f'{prefix} convert to torch! type:{hfai_type.__name__}')
            model_copy = converted
        break

    model_copy = _to_torch(model_copy, verbose, prefix)
    # Apply the mode *after* conversion so freshly created torch modules
    # also inherit the original train/eval mode.
    model_copy.train(model.training)
    return model_copy