Source code for hfai.nn.modules.activation
from torch import nn
import torch
from .. import functional as F
[docs]class ReLU(nn.ReLU):
"""
压位 ReLU 算子, 训练时的中间结果用 1bit 储存 `[x >= 0]`, 以节省训练时的内存
使用方式与 `PyTorch 的 ReLU <https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.ReLU(p=0.2)
input = torch.randn(20, 16, 10)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.relu(input, self.inplace)
[docs]class Hardtanh(nn.Hardtanh):
"""
压位 Hardtanh 算子, 训练时的中间结果用 1bit 储存 `[min_val <= x <= max_val]`, 以节省训练时的内存
使用方式与 `PyTorch 的 Hardtanh <https://pytorch.org/docs/stable/generated/torch.nn.Hardtanh.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Hardtanh(-2, 2)
input = torch.randn(20, 16, 10)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.hardtanh(input, self.min_val, self.max_val, self.inplace)
[docs]class ReLU6(nn.ReLU6):
"""
压位 ReLU6 算子, 训练时的中间结果用 1bit 储存 `[0 <= x <= 6]`, 以节省训练时的内存
使用方式与 `PyTorch 的 ReLU6 <https://pytorch.org/docs/stable/generated/torch.nn.ReLU6.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.ReLU6()
input = torch.randn(20, 16, 10)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.relu6(input, self.inplace)
[docs]class Softplus(nn.Softplus):
"""
Softplus 算子, 与 torch 相比, 不用在训练中保存 output, 节省一倍内存
使用方式与 `PyTorch 的 Softplus <https://pytorch.org/docs/stable/generated/torch.nn.Softplus.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Softplus(beta=2, threshold=19)
input = torch.randn(2, 3)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.softplus(input, self.beta, self.threshold)
[docs]class Softmin(nn.Softmin):
"""
Softmin 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存
使用方式与 `PyTorch 的 Softmin <https://pytorch.org/docs/stable/generated/torch.nn.Softmin.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Softmin(dim=1)
input = torch.randn(2, 3)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.softmin(input, self.dim, _stacklevel=5)
[docs]class Softmax(nn.Softmax):
"""
Softmax 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存
使用方式与 `PyTorch 的 Softmax <https://pytorch.org/docs/stable/generated/torch.nn.Softmax.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Softmax(dim=1)
input = torch.randn(2, 3)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.softmax(input, self.dim, _stacklevel=5)
[docs]class Softmax2d(nn.Softmax2d):
"""
Softmax2d 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存
使用方式与 `PyTorch 的 Softmax2d <https://pytorch.org/docs/stable/generated/torch.nn.Softmax2d.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Softmax2d()
input = torch.randn(2, 3, 12, 13)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
assert input.dim() == 4 or input.dim() == 3, 'Softmax2d requires a 3D or 4D tensor as input'
return F.softmax(input, dim=-3, _stacklevel=5)
[docs]class LogSoftmax(nn.LogSoftmax):
"""
LogSoftmax 算子, 与 torch 相比, 不用在训练中保存 input, 节省一倍内存
使用方式与 `PyTorch 的 LogSoftmax <https://pytorch.org/docs/stable/generated/torch.nn.LogSoftmax.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.LogSoftmax()
input = torch.randn(2, 3)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.log_softmax(input, self.dim, _stacklevel=5)
[docs]class Threshold(nn.Threshold):
"""
压位 Threshold 算子, 训练时的中间结果用 1bit 储存 `[x > threshold]`, 以节省训练时的内存
使用方式与 `PyTorch 的 Threshold <https://pytorch.org/docs/stable/generated/torch.nn.Threshold.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Threshold(0.1, 20)
input = torch.randn(2)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.threshold(input, self.threshold, self.value, self.inplace)
[docs]class RReLU(nn.RReLU):
"""
压位 RReLU 算子, 训练时的中间结果用 1bit 储存 `[x >= 0]`, 以节省训练时的内存
使用方式与 `PyTorch 的 RReLU <https://pytorch.org/docs/stable/generated/torch.nn.RReLU.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.RReLU(0.1, 0.3)
input = torch.randn(2)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.rrelu(input, self.lower, self.upper, self.training, self.inplace)
[docs]class LeakyReLU(nn.LeakyReLU):
"""
压位 LeakyReLU 算子, 训练时的中间结果用 1bit 储存 `[x >= 0]`, 以节省训练时的内存
使用方式与 `PyTorch 的 LeakyReLU <https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.LeakyReLU(0.1)
input = torch.randn(2)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.leaky_relu(input, self.negative_slope, self.inplace)
[docs]class Hardsigmoid(nn.Hardsigmoid):
"""
压位 Hardsigmoid 算子, 训练时的中间结果用 1bit 储存 `[-3 <= x <= 3]`, 以节省训练时的内存
使用方式与 `PyTorch 的 Hardsigmoid <https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Hardsigmoid()
input = torch.randn(2)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.hardsigmoid(input, self.inplace)
[docs]class Hardshrink(nn.Hardshrink):
"""
压位 Hardshrink 算子, 训练时的中间结果用 1bit 储存 `[-lambda <= x <= lambda]`, 以节省训练时的内存
使用方式与 `PyTorch 的 Hardshrink <https://pytorch.org/docs/stable/generated/torch.nn.Hardshrink.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Hardshrink(lambda=0.6)
input = torch.randn(2)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.hardshrink(input, self.lambd)
[docs]class Softshrink(nn.Softshrink):
"""
压位 Softshrink 算子, 训练时的中间结果用 1bit 储存 `[-lambda <= x <= lambda]`, 以节省训练时的内存
使用方式与 `PyTorch 的 Softshrink <https://pytorch.org/docs/stable/generated/torch.nn.Softshrink.html>`_ 一致
Examples:
.. code-block:: python
m = hfai.nn.Softshrink(lambda=0.6)
input = torch.randn(2)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.softshrink(input, self.lambd)
class SwiGLU(nn.Module):
def __init__(self, dim=-1):
super().__init__()
self.dim = dim
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.swiglu(input, self.dim)