Shortcuts

hfai.nn

to_hfai

将模型中 torch 算子替换成 hfai 优化算子

to_torch

将 model 中 hfai 算子替换成 torch 的算子

bench

给定一个模型或函数和它的输入,对 forward 和 backward 进行计时。

compare

比较两个模型或函数的性能。

LSTM

高效的 LSTM 算子

LSTM_fullc

高效的 LSTM 算子,并输出完整的 c

GRU

高效的 GRU 算子

LayerNorm

更高效的LayerNorm算子

MultiheadAttention

更高效的MultiheadAttention算子

Dropout

压位 Dropout 算子, 训练时的 mask 用 1bit 储存, 以节省训练时的内存

Hardtanh

压位 Hardtanh 算子, 训练时的中间结果用 1bit 储存 [min_val <= x <= max_val], 以节省训练时的内存

LogSoftmax

LogSoftmax 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

Softmax

Softmax 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

Softmax2d

Softmax2d 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

Softmin

Softmin 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

Softplus

Softplus 算子, 与 torch 相比, 不用在训练中保存 output, 节省一倍内存

ReLU

压位 ReLU 算子, 训练时的中间结果用 1bit 储存 [x >= 0], 以节省训练时的内存

ReLU6

压位 ReLU6 算子, 训练时的中间结果用 1bit 储存 [0 <= x <= 6], 以节省训练时的内存

Threshold

压位 Threshold 算子, 训练时的中间结果用 1bit 储存 [x > threshold], 以节省训练时的内存

LeakyReLU

压位 LeakyReLU 算子, 训练时的中间结果用 1bit 储存 [x >= 0], 以节省训练时的内存

RReLU

压位 RReLU 算子, 训练时的中间结果用 1bit 储存 [x >= 0], 以节省训练时的内存

Hardsigmoid

压位 Hardsigmoid 算子, 训练时的中间结果用 1bit 储存 [-3 <= x <= 3], 以节省训练时的内存

Hardshrink

压位 Hardshrink 算子, 训练时的中间结果用 1bit 储存 [-lambda <= x <= lambda], 以节省训练时的内存

Softshrink

压位 Softshrink 算子, 训练时的中间结果用 1bit 储存 [-lambda <= x <= lambda], 以节省训练时的内存

hfai.nn.to_hfai(model, contiguous_param=False, verbose=False, inplace=False, ignore=[])[source]

将模型中 torch 算子替换成 hfai 优化算子

Parameters
  • model (nn.Module) – 要替换的 model

  • contiguous_param (bool) – 是否将 model 参数变成连续,以加速 optimizer,但目前不支持部分情形(默认为 False

  • verbose (bool) – 是否打印替换了的 Layer(默认为 False

  • inplace (bool) – 是否 inplace,不 inplace 会 deepcopy 一个新的 model(默认为 False

  • ignore (dict) – 不需要转化的Layer(默认为 None)

Returns

返回替换了 hfai 算子的模型

Return type

model (nn.Module)

Note

不会转化 torch 模型继承出的子类

比如: class M(torch.nn.LSTM), 则 M 的实例 m = M(...) 不会转化为 hfai.nn.LSTM 的实例

Examples:

from hfai.nn import to_hfai

torch_model = Model(...)
hfai_model = to_hfai(torch_model, contiguous_param=False, verbose=False, inplace=False, ignore=[torch.nn.Dropout])
hfai.nn.to_torch(model, verbose=False)[source]

将 model 中 hfai 算子替换成 torch 的算子

Parameters
  • model (nn.Module) – 要替换的 model

  • verbose (bool) – 是否打印替换了的 Layer(默认为 False

Returns

返回替换了 torch 算子的模型

Return type

model (nn.Module)

Examples:

from hfai.nn import to_torch

torch_model = to_torch(hfai_model, verbose=False)
hfai.nn.bench(model, inputs, optimizer=None, iters=100, warmup_iters=10, verbose=True, return_results=False, only_fwd=False)[source]

给定一个模型或函数和它的输入,对 forward 和 backward 进行计时。

Examples

>>> from torchvision import models
>>> model = models.resnet18().cuda(0)
>>> x = torch.randn(16, 3, 224, 224).cuda(0)
>>> hfai.nn.bench(model, (x,), iters=100)
+---------------+-------------+
| measurement   | time (us)   |
+===============+=============+
| forward       | 6241.785    |
+---------------+-------------+
| backward      | 8546.621    |
+---------------+-------------+
| fwd + bwd     | 14788.406   |
+---------------+-------------+
Parameters
  • model – PyTorch 模型或函数

  • inputs (tuple) – 输入,必须和 model 在同一个 device 上

  • optimizer (torch.optim.Optimizer) – 优化器对象。默认是 None

  • iters (int) – 迭代的次数,默认是 100

  • warmup_iters (int) – 预热的迭代次数,默认是 10

  • verbose (bool) – 是否打印计时的结果,默认是 True

  • return_results (bool) – 是否返回各项指标的结果,默认是 False

  • only_fwd (bool) – 是否只做 forward,默认是 False

hfai.nn.compare(model1, model2, inputs1, inputs2, optimizer1=None, optimizer2=None, iters=100, warmup_iters=10, compare_loss=True, rtol=1e-07, atol=1e-10, only_fwd=False)[source]

比较两个模型或函数的性能。

Examples

>>> import copy
>>> from torchvision import models
>>> model1 = models.resnet18().cuda(0)
>>> model2 = copy.deepcopy(model1).cuda(1)
>>> x1 = torch.randn(16, 3, 224, 224).cuda(0)
>>> x2 = x1.clone().cuda(1)
>>> hfai.nn.compare(model1, model2, (x1,), (x2,), iters=100)
+---------------+--------------------+--------------------+------------------------+
| measurement   | model1 time (us)   | model2 time (us)   | model1 / model2 time   |
+===============+====================+====================+========================+
| forward       | 5363.802           | 5450.234           | 98.41 %                |
+---------------+--------------------+--------------------+------------------------+
| backward      | 8114.942           | 7987.469           | 101.60 %               |
+---------------+--------------------+--------------------+------------------------+
| fwd + bwd     | 13478.745          | 13437.703          | 100.31 %               |
+---------------+--------------------+--------------------+------------------------+
Parameters
  • model1 – 模型或函数1

  • model2 – 模型或函数2

  • inputs1 (tuple) – model1 的输入,必须和 model1 在同一个 device 上

  • inputs2 (tuple) – model2 的输入,必须和 model2 在同一个 device 上

  • optimizer1 (torch.optim.Optimizer) – 优化器对象1。默认是 None

  • optimizer2 (torch.optim.Optimizer) – 优化器对象2。默认是 None

  • iters (int) – 迭代的次数,默认是 100

  • warmup_iters (int) – 预热的迭代次数,默认是 10

  • compare_loss (bool) – 如果是 True,还会比较两个模型或函数的输出结果和 backward 的梯度是否相同,默认是 True

  • rtol (float) – 允许的最大相对误差,默认是 1e-7

  • atol (float) – 允许的最大绝对误差,默认是 1e-10

  • only_fwd (bool) – 是否只比较 forward,默认是 False

class hfai.nn.LSTM(*args, **kwargs)[source]

高效的 LSTM 算子

使用方式与 PyTorch 的 LSTM 算子 一致

不支持 proj_size 参数

Note

额外支持 drop_connect 参数. 如果 0 < drop_connect <= 1, 会在所有的 weight_hh 后面紧接着增加一层 Dropout(p=drop_connect)

Note

支持 3 种精度模式:

  1. TF32 模式 (默认): LSTM 中的矩阵乘法使用 TF32 加速

    batch_size <= 64 && hidden_size <= 1728batch_size <= 512 && hidden_size <= 512 时, LSTM 使用 persistent 方法加速

  2. Float32 模式: LSTM 中的矩阵乘法使用完整精度

    需要指定 torch.backends.cuda.matmul.allow_tf32 = False

    batch_size <= 16 && hidden_size <= 1728 时, LSTM 使用 persistent 方法加速

  3. BFloat16 模式: LSTM 中的矩阵乘法使用 BFloat16 加速

    需要指定 hfai.nn.context.SetRnnAllowConversion(True)batch_size <= 72 && hidden_size <= 1728

Note

hidden_size 是 64 的倍数时性能最好

Examples:

lstm = hfai.nn.LSTM(input_size=10, hidden_size=20).cuda()

input0 = torch.randn(5, 100, 10).cuda()
output, (hn, cn) = lstm(input0, None)  # TF32 模式, 不使用 persistent 方法

hfai.nn.context.SetRnnAllowConversion(True)
input1 = torch.randn(5, 64, 10).cuda()
output, (hn, cn) = lstm(input1, None)  # BFloat16 模式, 使用 persistent 方法
hfai.nn.context.SetRnnAllowConversion(False)

input2 = torch.randn(5, 8, 10).cuda()
output, (hn, cn) = lstm(input2, None)  # TF32 模式, 使用 persistent 方法
class hfai.nn.LSTM_fullc(*args, **kwargs)[source]

高效的 LSTM 算子,并输出完整的 c

模型参数和 Inputs 与 PyTorch 的 LSTM 算子 一致

不支持 proj_size 参数

Outputs: output, (h_n, c_n), full_c
  • output: 与 PyTorch 一致

  • h_n: 与 PyTorch 一致

  • c_n: 与 PyTorch 一致

  • full_c: $(seq_len, D * num_layers, batch_size, hidden_size)$, 包含了完整的共 seq_len 层的 c

Note

额外支持 drop_connect 参数. 如果 0 < drop_connect <= 1, 会在所有的 weight_hh 后面紧接着增加一层 Dropout(p=drop_connect)

Note

支持 3 种精度模式:

  1. TF32 模式 (默认): LSTM 中的矩阵乘法使用 TF32 加速

    batch_size <= 64 && hidden_size <= 1728 时, LSTM 使用 persistent 方法加速

  2. Float32 模式: LSTM 中的矩阵乘法使用完整精度

    需要指定 torch.backends.cuda.matmul.allow_tf32 = False

    batch_size <= 16 && hidden_size <= 1728 时, LSTM 使用 persistent 方法加速

  3. BFloat16 模式: LSTM 中的矩阵乘法使用 BFloat16 加速

    需要指定 hfai.nn.context.SetRnnAllowConversion(True)batch_size <= 72 && hidden_size <= 1728

Note

hidden_size 是 64 的倍数时性能最好

Examples:

lstm_fullc = hfai.nn.LSTM_fullc(input_size=10, hidden_size=20).cuda()

input0 = torch.randn(5, 100, 10).cuda()
output, (hn, cn), full_c = lstm_fullc(input0, None)  # TF32 模式, 不使用 persistent 方法

hfai.nn.context.SetRnnAllowConversion(True)
input1 = torch.randn(5, 64, 10).cuda()
output, (hn, cn), full_c = lstm_fullc(input1, None)  # BFloat16 模式, 使用 persistent 方法
hfai.nn.context.SetRnnAllowConversion(False)

input2 = torch.randn(5, 8, 10).cuda()
output, (hn, cn), full_c = lstm_fullc(input2, None)  # TF32 模式, 使用 persistent 方法
class hfai.nn.GRU(*args, **kwargs)[source]

高效的 GRU 算子

使用方式与 PyTorch 的 GRU 算子 一致

Note

额外支持 drop_connect 参数. 如果 0 < drop_connect <= 1, 会在所有的 weight_hh 后面紧接着增加一层 Dropout(p=drop_connect)

Note

支持 2 种精度模式:

  1. TF32 模式 (默认): GRU 中的矩阵乘法使用 TF32 加速

    batch_size <= 64 && hidden_size <= 1728 时, GRU 使用 persistent 方法加速

  2. Float32 模式: GRU 中的矩阵乘法使用完整精度

    需要指定 torch.backends.cuda.matmul.allow_tf32 = False

Note

hidden_size 是 64 的倍数时性能最好

Examples:

gru = hfai.nn.GRU(input_size=10, hidden_size=20).cuda()

input0 = torch.randn(5, 100, 10).cuda()
output, hn = gru(input0, None)  # TF32 模式, 不使用 persistent 方法

input2 = torch.randn(5, 8, 10).cuda()
output, hn = gru(input2, None)  # TF32 模式, 使用 persistent 方法
class hfai.nn.LayerNorm(normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None)[source]

更高效的LayerNorm算子

接口和 PyTorch的LayerNorm算子 一致

class hfai.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, need_softmax_temperature=False, compact=False, device=None, dtype=None)[source]

更高效的MultiheadAttention算子

目前支持 qkvembed_dim 相同且 add_zero_attn=False,其余和 PyTorch的MultiheadAttention 算子一致

Note

支持两种额外扩展模式

  1. 低精度模式,进一步提升性能

    指定 hfai.nn.context.SetAttnAllowConversion(True) 时, MultiheadAttention 在保证误差的情况下对于部分矩阵乘法降低精度进行运算, 进一步提升性能

  2. 低显存模式,适应 batch_size 或者参数更大的训练

    设置 MultiheadAttention 中 compact 参数为True (default: False)

Examples:

attn = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1).cuda()

# 低显存模式
attn_compact = hfai.nn.MultiheadAttention(embed_dim=10, num_heads=2, dropout=0.1, compact=True).cuda()

query0 = torch.randn(5, 4, 10).cuda()
key0 = torch.randn(3, 4, 10).cuda()
value0 = torch.randn(3, 4, 10).cuda()

output0 = attn(query0, key0, value0)[0]

# 低精度模式
hfai.nn.context.SetAttnAllowConversion(True)
query1 = torch.randn(5, 4, 10).cuda()
key1 = torch.randn(3, 4, 10).cuda()
value1 = torch.randn(3, 4, 10).cuda()

output1 = attn(query1, key1, value1)[0]
hfai.nn.context.SetRnnAllowConversion(False)
class hfai.nn.Dropout(p=0.5, inplace=False)[source]

压位 Dropout 算子, 训练时的 mask 用 1bit 储存, 以节省训练时的内存

使用方式与 PyTorch 的 Dropout 一致

Parameters
  • p (float, optional) – 元素为零的概率. 默认: 0.5

  • inplace (bool, optional) – 如果是 True, 进行原地操作. 默认: False

Examples:

m = hfai.nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)
class hfai.nn.Hardtanh(min_val=- 1.0, max_val=1.0, inplace=False, min_value=None, max_value=None)[source]

压位 Hardtanh 算子, 训练时的中间结果用 1bit 储存 [min_val <= x <= max_val], 以节省训练时的内存

使用方式与 PyTorch 的 Hardtanh 一致

Examples:

m = hfai.nn.Hardtanh(-2, 2)
input = torch.randn(20, 16, 10)
output = m(input)
class hfai.nn.LogSoftmax(dim=None)[source]

LogSoftmax 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

使用方式与 PyTorch 的 LogSoftmax 一致

Examples:

m = hfai.nn.LogSoftmax()
input = torch.randn(2, 3)
output = m(input)
class hfai.nn.Softmax(dim=None)[source]

Softmax 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

使用方式与 PyTorch 的 Softmax 一致

Examples:

m = hfai.nn.Softmax(dim=1)
input = torch.randn(2, 3)
output = m(input)
class hfai.nn.Softmax2d[source]

Softmax2d 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

使用方式与 PyTorch 的 Softmax2d 一致

Examples:

m = hfai.nn.Softmax2d()
input = torch.randn(2, 3, 12, 13)
output = m(input)
class hfai.nn.Softmin(dim=None)[source]

Softmin 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存

使用方式与 PyTorch 的 Softmin 一致

Examples:

m = hfai.nn.Softmin(dim=1)
input = torch.randn(2, 3)
output = m(input)
class hfai.nn.Softplus(beta=1, threshold=20)[source]

Softplus 算子, 与 torch 相比, 不用在训练中保存 output, 节省一倍内存

使用方式与 PyTorch 的 Softplus 一致

Examples:

m = hfai.nn.Softplus(beta=2, threshold=19)
input = torch.randn(2, 3)
output = m(input)
class hfai.nn.ReLU(inplace=False)[source]

压位 ReLU 算子, 训练时的中间结果用 1bit 储存 [x >= 0], 以节省训练时的内存

使用方式与 PyTorch 的 ReLU 一致

Examples:

m = hfai.nn.ReLU(p=0.2)
input = torch.randn(20, 16, 10)
output = m(input)
class hfai.nn.ReLU6(inplace=False)[source]

压位 ReLU6 算子, 训练时的中间结果用 1bit 储存 [0 <= x <= 6], 以节省训练时的内存

使用方式与 PyTorch 的 ReLU6 一致

Examples:

m = hfai.nn.ReLU6()
input = torch.randn(20, 16, 10)
output = m(input)
class hfai.nn.Threshold(threshold, value, inplace=False)[source]

压位 Threshold 算子, 训练时的中间结果用 1bit 储存 [x > threshold], 以节省训练时的内存

使用方式与 PyTorch 的 Threshold 一致

Examples:

m = hfai.nn.Threshold(0.1, 20)
input = torch.randn(2)
output = m(input)
class hfai.nn.LeakyReLU(negative_slope=0.01, inplace=False)[source]

压位 LeakyReLU 算子, 训练时的中间结果用 1bit 储存 [x >= 0], 以节省训练时的内存

使用方式与 PyTorch 的 LeakyReLU 一致

Examples:

m = hfai.nn.LeakyReLU(0.1)
input = torch.randn(2)
output = m(input)
class hfai.nn.RReLU(lower=0.125, upper=0.3333333333333333, inplace=False)[source]

压位 RReLU 算子, 训练时的中间结果用 1bit 储存 [x >= 0], 以节省训练时的内存

使用方式与 PyTorch 的 RReLU 一致

Examples:

m = hfai.nn.RReLU(0.1, 0.3)
input = torch.randn(2)
output = m(input)
class hfai.nn.Hardsigmoid(inplace=False)[source]

压位 Hardsigmoid 算子, 训练时的中间结果用 1bit 储存 [-3 <= x <= 3], 以节省训练时的内存

使用方式与 PyTorch 的 Hardsigmoid 一致

Examples:

m = hfai.nn.Hardsigmoid()
input = torch.randn(2)
output = m(input)
class hfai.nn.Hardshrink(lambd=0.5)[source]

压位 Hardshrink 算子, 训练时的中间结果用 1bit 储存 [-lambda <= x <= lambda], 以节省训练时的内存

使用方式与 PyTorch 的 Hardshrink 一致

Examples:

m = hfai.nn.Hardshrink(lambda=0.6)
input = torch.randn(2)
output = m(input)
class hfai.nn.Softshrink(lambd=0.5)[source]

压位 Softshrink 算子, 训练时的中间结果用 1bit 储存 [-lambda <= x <= lambda], 以节省训练时的内存

使用方式与 PyTorch 的 Softshrink 一致

Examples:

m = hfai.nn.Softshrink(lambda=0.6)
input = torch.randn(2)
output = m(input)