import torch
import warnings
from torch import Tensor
from typing import Optional, Tuple, List, Union, Callable, TypeVar
from torch.types import Number
from torch import _VF
from torch.overrides import has_torch_function, handle_torch_function
import torch.nn.functional as torch_F
from torch.nn.functional import *
from .sync_function import sync
from .functional_utils import _cuda_only, _disable_set_replace_torch
from . import _pytree as pytree
from functools import partial
try:
import hfai.hfcuda.dropout as hf_dropout
except:
pass
class DropoutFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, p, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, seeds = hf_dropout.dropout_forward(input, p, inplace)
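        # Only the RNG seeds are kept for backward; the dropout mask itself is presumably
        # re-generated from these seeds inside hf_dropout.dropout_backward, so no full-size
        # mask tensor has to be stored.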
ctx.p = p
ctx.seeds = seeds
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
dinput = hf_dropout.dropout_backward(doutput, ctx.p, ctx.seeds)
ret = dinput, None, None
return ret
@_cuda_only
def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor:
    """
    dropout function; see :class:`~hfai.nn.Dropout`.
    """
if has_torch_function((input,)):
return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace)
if p < 0. or p > 1.:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
if not training:
return input
if p == 0:
return input
if p == 1:
return input.zero_() if inplace else torch.zeros_like(input)
if not torch.is_grad_enabled() or not input.requires_grad:
return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
if inplace and not input.is_contiguous():
return _VF.dropout_(input, p, training)
return DropoutFunction.apply(input, p, inplace)
try:
import hfai.hfcuda.bitmask as hf_bitmask
except:
pass
class ReLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, compressed_mask = hf_bitmask.relu_forward(input, inplace)
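        # compressed_mask packs the activation pattern at 1 bit per element, so the backward
        # pass does not need to keep the full-precision activation around.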
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.relu_backward(doutput, compressed_mask)
return dinput, None
@_cuda_only
@_disable_set_replace_torch
def relu(input: Tensor, inplace: bool = False) -> Tensor:
    """
    relu function; see :class:`~hfai.nn.ReLU`.
    """
if has_torch_function((input,)):
return handle_torch_function(relu, (input,), input, inplace=inplace)
if not torch.is_grad_enabled() or not input.requires_grad:
return input.relu_() if inplace else input.relu()
if inplace and not input.is_contiguous():
return input.relu_()
return ReLUFunction.apply(input, inplace)
@_cuda_only
@_disable_set_replace_torch
def relu_(input: Tensor) -> Tensor:
    """
    In-place version of the relu function; see :class:`~hfai.nn.ReLU`.
    """
return relu(input, inplace=True)
def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor:
    """
    hardtanh function; see :class:`~hfai.nn.Hardtanh`.
    """
if has_torch_function((input,)):
return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace)
return clamp(input, min_val, max_val, inplace)
def hardtanh_(input: Tensor, min_val: float = -1., max_val: float = 1.) -> Tensor:
    """
    In-place version of the hardtanh function; see :class:`~hfai.nn.Hardtanh`.
    """
if has_torch_function((input,)):
return handle_torch_function(hardtanh_, (input,), input, min_val=min_val, max_val=max_val)
return clamp_(input, min_val, max_val)
def relu6(input: Tensor, inplace: bool = False) -> Tensor:
    """
    relu6 function; see :class:`~hfai.nn.ReLU6`.
    """
return hardtanh(input, 0, 6, inplace)
try:
import hfai.hfcuda.softplus as hf_softplus
except:
pass
class SoftplusFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, beta, threshold):
output = hf_softplus.softplus_forward(input, beta, threshold)
ctx.beta = beta
ctx.threshold = threshold
ctx.save_for_backward(input)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
input = ctx.saved_tensors[0]
dinput = hf_softplus.softplus_backward(doutput, input, ctx.beta, ctx.threshold)
return dinput, None, None
@_cuda_only
def softplus(input: Tensor, beta: int = 1, threshold: int = 20) -> Tensor:
    """
    softplus function; see :class:`~hfai.nn.Softplus`.
    """
if has_torch_function((input,)):
return handle_torch_function(softplus, (input,), input, beta=beta, threshold=threshold)
if not torch.is_grad_enabled() or not input.requires_grad:
return torch._C._nn.softplus(input, beta, threshold)
return SoftplusFunction.apply(input, beta, threshold)
def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
warnings.warn("Implicit dimension choice for {} has been deprecated. "
"Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel)
if ndim == 0 or ndim == 1 or ndim == 3:
ret = 0
else:
ret = 1
return ret
try:
import hfai.hfcuda.softmax as hf_softmax
except:
pass
class SoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim):
output = hf_softmax.softmax_forward(input, dim)
ctx.dim = dim
ctx.input_dtype = input.dtype
ctx.save_for_backward(output)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
output = ctx.saved_tensors[0]
fake_input = torch.empty([1], dtype=ctx.input_dtype, device=output.device)
dinput = hf_softmax.softmax_backward(doutput, output, ctx.dim, fake_input)
return dinput, None
def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor:
    """
    softmin function; see :class:`~hfai.nn.Softmin`.
    """
return softmax(-input, dim, _stacklevel, dtype)
@_disable_set_replace_torch
def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor:
    """
    softmax function; see :class:`~hfai.nn.Softmax`.
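
    A minimal usage sketch (``x`` is an arbitrary floating-point example tensor):

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.softmax(x, dim=-1)
        # same as: y = torch.softmax(x, dim=-1)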
"""
if has_torch_function((input,)):
return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
if dim is None:
dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
if dtype is not None:
input = input.type(dtype)
if not torch.is_grad_enabled() or not input.requires_grad:
return input.softmax(dim)
return SoftmaxFunction.apply(input, dim)
class LogSoftmaxFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim):
output = hf_softmax.log_softmax_forward(input, dim)
ctx.dim = dim
ctx.input_dtype = input.dtype
ctx.save_for_backward(output)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
output = ctx.saved_tensors[0]
fake_input = torch.empty([1], dtype=ctx.input_dtype, device=output.device)
dinput = hf_softmax.log_softmax_backward(doutput, output, ctx.dim, fake_input)
return dinput, None
@_disable_set_replace_torch
def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor:
    """
    log_softmax function; see :class:`~hfai.nn.LogSoftmax`.
    """
if has_torch_function((input,)):
return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
if dim is None:
dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
if dtype is not None:
input = input.type(dtype)
if not torch.is_grad_enabled() or not input.requires_grad:
return input.log_softmax(dim)
return LogSoftmaxFunction.apply(input, dim)
class ClampFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, min, max, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, compressed_mask = hf_bitmask.clamp_forward(input, min, max, inplace)
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.clamp_backward(doutput, compressed_mask)
return dinput, None, None, None
def clamp_min(input: Tensor, min: float, inplace: bool = False) -> Tensor:
    """
    Bit-packed clamp_min operator: the intermediate result kept for training is stored as
    1 bit per element, `[x >= min]`, to save training memory.

    Args:
        input: input Tensor
        min (float): lower bound of the output
        inplace (bool, optional): if ``True``, performs the operation in place. Default: ``False``

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.clamp_min(x, min=-0.5)
        # same as: y = x.clamp_min(min=-0.5)
        # same as: y = torch.max(x, -0.5 * torch.ones_like(x))
    """
return clamp(input, min=min, inplace=inplace)
def clamp_min_(input: Tensor, min: float) -> Tensor:
    """
    In-place version of the clamp_min function.
    """
return clamp_min(input, min=min, inplace=True)
def clamp_max(input: Tensor, max: float, inplace: bool = False) -> Tensor:
    """
    Bit-packed clamp_max operator: the intermediate result kept for training is stored as
    1 bit per element, `[x <= max]`, to save training memory.

    Args:
        input: input Tensor
        max (float): upper bound of the output
        inplace (bool, optional): if ``True``, performs the operation in place. Default: ``False``

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.clamp_max(x, max=0.5)
        # same as: y = x.clamp_max(max=0.5)
        # same as: y = torch.min(x, 0.5 * torch.ones_like(x))
    """
return clamp(input, max=max, inplace=inplace)
def clamp_max_(input: Tensor, max: float) -> Tensor:
    """
    In-place version of the clamp_max function.
    """
return clamp(input, max=max, inplace=True)
@_cuda_only
@_disable_set_replace_torch
def clamp(input: Tensor, min: Optional[float] = None, max: Optional[float] = None, inplace: bool = False) -> Tensor:
    """
    Bit-packed clamp operator: the intermediate result kept for training is stored as
    1 bit per element, `[min <= x <= max]`, to save training memory.

    Args:
        input: input Tensor
        min (float): lower bound of the output. Default: ``None``
        max (float): upper bound of the output. Default: ``None``
        inplace (bool, optional): if ``True``, performs the operation in place. Default: ``False``

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.clamp(x, min=-0.5, max=0.5)
        # same as: y = x.clamp(min=-0.5, max=0.5)
    """
if has_torch_function((input,)):
return handle_torch_function(clamp, (input,), input, min=min, max=max, inplace=inplace)
if min is None and max is None:
raise RuntimeError("At least one of 'min' or 'max' must not be None")
if min is None:
min = float('-inf')
if max is None:
max = float('inf')
if not torch.is_grad_enabled() or not input.requires_grad:
return input.clamp_(min, max) if inplace else input.clamp(min, max)
if inplace and not input.is_contiguous():
return input.clamp_(min, max)
return ClampFunc.apply(input, min, max, inplace)
@_cuda_only
@_disable_set_replace_torch
def clamp_(input: Tensor, min: Optional[float] = None, max: Optional[float] = None) -> Tensor:
    """
    In-place version of the clamp function.
    """
return clamp(input, min, max, inplace=True)
def clip(input: Tensor, min: Optional[float] = None, max: Optional[float] = None) -> Tensor:
    """
    clip function; see :func:`hfai.nn.functional.clamp`.
    """
return clamp(input, min, max)
def clip_(input: Tensor, min: Optional[float] = None, max: Optional[float] = None) -> Tensor:
    """
    In-place version of the clip function.
    """
return clamp_(input, min, max)
class ThresholdFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, threshold, value, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, compressed_mask = hf_bitmask.threshold_forward(input, threshold, value, inplace)
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.threshold_backward(doutput, compressed_mask)
return dinput, None, None, None
@_cuda_only
def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor:
"""
    threshold function; see :class:`~hfai.nn.Threshold`.
"""
if has_torch_function((input,)):
return handle_torch_function(_threshold, (input,), input, threshold=threshold, value=value, inplace=inplace)
if not torch.is_grad_enabled() or not input.requires_grad:
return _VF.threshold_(input, threshold, value) if inplace else _VF.threshold(input, threshold, value)
if inplace and not input.is_contiguous():
return _VF.threshold_(input, threshold, value)
return ThresholdFunc.apply(input, threshold, value, inplace)
# We define this function as _threshold because it takes an argument
# named threshold, which clobbers the recursive reference to the
# function needed for __torch_function__ support
threshold = _threshold
def threshold_(input: Tensor, threshold: float, value: float) -> Tensor:
    """
    In-place version of the threshold function; see :class:`~hfai.nn.Threshold`.
    """
return _threshold(input, threshold, value, inplace=True)
@_disable_set_replace_torch
def max(input: Tensor, value: Optional[Union[torch.Tensor, int, float]] = None, dim: Optional[int] = None,
        keepdim: bool = False) -> Tensor:
    """
    Bit-packed max operator; returns the element-wise larger of ``input`` and ``value``, or reduces along a dimension.

    If ``value`` is a Tensor, this calls ``hf_F.maximum(input, value)``;
    if ``value`` is a float, this calls ``hf_F.clamp(input, min=value)``;
    if ``value`` is an int, or ``dim`` or ``keepdim`` is given, this calls ``torch.max(input, dim=value or dim, keepdim=keepdim)``.

    Args:
        input (Tensor): input Tensor
        value (Tensor or float or int): the value to compare against, or the dimension to reduce over
        dim (int): the dimension to reduce over
        keepdim (bool): whether to keep the reduced dimension
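
    A minimal usage sketch of the dispatch above (``x`` is an arbitrary example tensor):

    .. code-block:: python

        import hfai.nn.functional as F

        F.max(x)                         # global maximum, same as x.max()
        F.max(x, 0.0)                    # float value: element-wise max with 0 via clamp(x, min=0.0)
        F.max(x, torch.zeros_like(x))    # Tensor value: element-wise maximum(x, other)
        F.max(x, dim=1, keepdim=True)    # reduction: same as torch.max(x, dim=1, keepdim=True)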
"""
if value is None and dim is None:
return input.max()
if dim is not None:
return input.max(dim=dim, keepdim=keepdim)
if isinstance(value, int):
return input.max(dim=value, keepdim=keepdim)
if isinstance(value, torch.Tensor):
return maximum(input, value)
return clamp(input, min=value)
@_disable_set_replace_torch
def min(input: Tensor, value: Optional[Union[torch.Tensor, int, float]] = None, dim: Optional[int] = None,
        keepdim: bool = False) -> Tensor:
    """
    Bit-packed min operator; returns the element-wise smaller of ``input`` and ``value``, or reduces along a dimension.

    If ``value`` is a Tensor, this calls ``hf_F.minimum(input, value)``;
    if ``value`` is a float, this calls ``hf_F.clamp(input, max=value)``;
    if ``value`` is an int, or ``dim`` or ``keepdim`` is given, this calls ``torch.min(input, dim=value or dim, keepdim=keepdim)``.

    Args:
        input (Tensor): input Tensor
        value (Tensor or float or int): the value to compare against, or the dimension to reduce over
        dim (int): the dimension to reduce over
        keepdim (bool): whether to keep the reduced dimension
    """
if value is None and dim is None:
return input.min()
if dim is not None:
return input.min(dim=dim, keepdim=keepdim)
if isinstance(value, int):
return input.min(dim=value, keepdim=keepdim)
if isinstance(value, torch.Tensor):
return minimum(input, value)
return clamp(input, max=value)
class RReluFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, lower, upper, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, noise, compressed_mask = hf_bitmask.rrelu_forward(input, lower, upper, inplace)
ctx.save_for_backward(noise, compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
noise, compressed_mask = ctx.saved_tensors
dinput = hf_bitmask.rrelu_backward(doutput, noise, compressed_mask)
return dinput, None, None, None
@_cuda_only
def rrelu(
    input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False
) -> Tensor:
    """
    rrelu function; see :class:`~hfai.nn.RReLU`.
    """
if has_torch_function((input,)):
return handle_torch_function(rrelu, (input,), input, lower=lower, upper=upper, training=training,
inplace=inplace)
if not training or not input.requires_grad or not torch.is_grad_enabled():
return torch._C._VariableFunctionsClass.rrelu_(input, lower, upper, training) if inplace \
else torch._C._VariableFunctionsClass.rrelu(input, lower, upper, training)
if inplace and not input.is_contiguous():
return torch._C._VariableFunctionsClass.rrelu_(input, lower, upper, training)
return RReluFunc.apply(input, lower, upper, inplace)
def rrelu_(
    input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False
) -> Tensor:
    """
    In-place version of the rrelu function.
    """
    return rrelu(input, lower, upper, training, inplace=True)
class LeakyReluFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, negative_slope, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, compressed_mask = hf_bitmask.leaky_relu_forward(input, negative_slope, inplace)
ctx.save_for_backward(compressed_mask)
ctx.negative_slope = negative_slope
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
negative_slope = ctx.negative_slope
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.leaky_relu_backward(doutput, compressed_mask, negative_slope)
return dinput, None, None, None
@_cuda_only
def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor:
    """
    leaky_relu function; see :class:`~hfai.nn.LeakyReLU`.
    """
if has_torch_function((input,)):
return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace)
if not torch.is_grad_enabled() or not input.requires_grad:
return torch._C._nn.leaky_relu_(input, negative_slope) if inplace \
else torch._C._nn.leaky_relu(input, negative_slope)
if inplace and not input.is_contiguous():
return torch._C._nn.leaky_relu_(input, negative_slope)
return LeakyReluFunc.apply(input, negative_slope, inplace)
def leaky_relu_(input: Tensor, negative_slope: float = 0.01) -> Tensor:
    """
    In-place version of the leaky_relu function.
    """
return leaky_relu(input, negative_slope, inplace=True)
class HardSigmoid(torch.autograd.Function):
@staticmethod
def forward(ctx, input, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, compressed_mask = hf_bitmask.hardsigmoid_forward(input, inplace)
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.hardsigmoid_backward(doutput, compressed_mask)
return dinput, None
@_cuda_only
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
    """
    hardsigmoid function; see :class:`~hfai.nn.Hardsigmoid`.
    """
if has_torch_function((input,)):
return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
if not torch.is_grad_enabled() or not input.requires_grad:
return torch._C._nn.hardsigmoid_(input) if inplace else torch._C._nn.hardsigmoid(input)
if inplace and not input.is_contiguous():
return torch._C._nn.hardsigmoid_(input)
return HardSigmoid.apply(input, inplace)
class HardshrinkFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, lambd):
input = input.contiguous()
output, compressed_mask = hf_bitmask.hardshrink_forward(input, lambd)
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.shrink_backward(doutput, compressed_mask)
return dinput, None
@_cuda_only
@_disable_set_replace_torch
def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor:
    """
    hardshrink function; see :class:`~hfai.nn.Hardshrink`.
    """
if has_torch_function((input,)):
return handle_torch_function(hardshrink, (input,), input, lambd=lambd)
if not torch.is_grad_enabled() or not input.requires_grad:
return torch.hardshrink(input, lambd)
return HardshrinkFunc.apply(input, lambd)
class SoftshrinkFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, lambd):
input = input.contiguous()
output, compressed_mask = hf_bitmask.softshrink_forward(input, lambd)
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask = ctx.saved_tensors[0]
dinput = hf_bitmask.shrink_backward(doutput, compressed_mask)
return dinput, None
@_cuda_only
def softshrink(input: Tensor, lambd: float = 0.5) -> Tensor:
    """
    softshrink function; see :class:`~hfai.nn.Softshrink`.
    """
if has_torch_function((input,)):
return handle_torch_function(softshrink, (input,), input, lambd=lambd)
if not torch.is_grad_enabled() or not input.requires_grad:
return torch._C._nn.softshrink(input, lambd)
return SoftshrinkFunc.apply(input, lambd)
class AbsFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, inplace):
if inplace:
ctx.mark_dirty(input)
input = input.contiguous()
output, compressed_mask, zero_mask = hf_bitmask.abs_forward(input, inplace)
ctx.save_for_backward(compressed_mask, zero_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
doutput = doutput.contiguous()
compressed_mask, zero_mask = ctx.saved_tensors
dinput = hf_bitmask.abs_backward(doutput, compressed_mask, zero_mask)
return dinput, None
@_cuda_only
@_disable_set_replace_torch
def abs(input: Tensor) -> Tensor:
    """
    Bit-packed abs function; usage is identical to :func:`torch.abs`.
    """
if not torch.is_grad_enabled() or not input.requires_grad:
return input.abs()
return AbsFunc.apply(input, False)
@_cuda_only
@_disable_set_replace_torch
def abs_(input: Tensor) -> Tensor:
    """
    In-place version of abs.
    """
if not torch.is_grad_enabled() or not input.requires_grad:
return input.abs_()
if not input.is_contiguous():
return input.abs_()
return AbsFunc.apply(input, True)
class MaximumFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input1, input2):
output, compressed_mask1, compressed_mask2 = hf_bitmask.maximum_forward(input1, input2)
ctx.save_for_backward(compressed_mask1, compressed_mask2)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
compressed_mask1, compressed_mask2 = ctx.saved_tensors
dinput1, dinput2 = hf_bitmask.max_min_backward(doutput, compressed_mask1, compressed_mask2)
return dinput1, dinput2
@_cuda_only
@_disable_set_replace_torch
def maximum(input1: Tensor, input2: Tensor) -> Tensor:
    """
    Bit-packed maximum function; usage is identical to :func:`torch.maximum`.
    """
if (not torch.is_grad_enabled()) or (input1.dtype != input2.dtype) or (
not input1.requires_grad and not input2.requires_grad):
return input1.maximum(input2)
x, y = torch.broadcast_tensors(input1, input2)
expend_numel = x.numel()
if input1.shape.numel() > input2.shape.numel():
max_numel = input1.shape.numel()
else:
max_numel = input2.shape.numel()
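    # The 1-bit masks are stored per element of the broadcast result; if broadcasting inflates
    # the element count by more than 32x relative to the larger input, the masks would likely
    # cost more memory than the compression saves, so fall back to the stock op.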
if (expend_numel > max_numel * 32):
return input1.maximum(input2)
return MaximumFunc.apply(x, y)
class MinimumFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input1, input2):
output, compressed_mask1, compressed_mask2 = hf_bitmask.minimum_forward(input1, input2)
ctx.save_for_backward(compressed_mask1, compressed_mask2)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
compressed_mask1, compressed_mask2 = ctx.saved_tensors
dinput1, dinput2 = hf_bitmask.max_min_backward(doutput, compressed_mask1, compressed_mask2)
return dinput1, dinput2
@_cuda_only
@_disable_set_replace_torch
def minimum(input1: Tensor, input2: Tensor) -> Tensor:
    """
    Bit-packed minimum function; usage is identical to :func:`torch.minimum`.
    """
if (not torch.is_grad_enabled()) or (input1.dtype != input2.dtype) or (
not input1.requires_grad and not input2.requires_grad):
return input1.minimum(input2)
x, y = torch.broadcast_tensors(input1, input2)
expend_numel = x.numel()
if input1.shape.numel() > input2.shape.numel():
max_numel = input1.shape.numel()
else:
max_numel = input2.shape.numel()
if (expend_numel > max_numel * 32):
return input1.minimum(input2)
return MinimumFunc.apply(x, y)
class WhereFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, condition, input1, input2):
output = input1.where(condition, input2)
compressed_mask = hf_bitmask.compress_mask(condition)
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
compressed_mask = ctx.saved_tensors[0]
dinput1, dinput2 = hf_bitmask.where_backward(doutput, compressed_mask)
return None, dinput1, dinput2
@_cuda_only
def where(condition: Tensor, input1: Optional[Union[Tensor, Number]] = None,
          input2: Optional[Union[Tensor, Number]] = None) -> Union[Tensor, Tuple[Tensor]]:
    """
    Bit-packed where function; usage is identical to :func:`torch.where`.
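
    A minimal usage sketch (``a`` and ``b`` are arbitrary example tensors of the same shape):

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.where(a > 0, a, b)
        # same as: y = torch.where(a > 0, a, b)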
"""
if input1 is None and input2 is None:
return torch.nonzero(condition, as_tuple=True)
if not isinstance(input1, Tensor):
input1 = torch.tensor(input1, device=condition.device)
if not isinstance(input2, Tensor):
input2 = torch.tensor(input2, device=condition.device)
if not torch.is_grad_enabled() or (not input1.requires_grad and not input2.requires_grad):
return input1.where(condition, input2)
x, y, z = torch.broadcast_tensors(condition, input1, input2)
if x.shape.numel() > condition.shape.numel() * 32:
return input1.where(condition, input2)
return WhereFunc.apply(x, y, z)
class MaskedFillFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, value, inplace):
if inplace:
ctx.mark_dirty(input)
output = input.masked_fill_(mask, value)
else:
output = input.masked_fill(mask, value)
compressed_mask = hf_bitmask.compress_mask(mask)
value_need_grad = False
if isinstance(value, Tensor) and value.requires_grad:
value_need_grad = True
ctx.value_need_grad = value_need_grad
ctx.save_for_backward(compressed_mask)
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
compressed_mask = ctx.saved_tensors[0]
value_need_grad = ctx.value_need_grad
dinput = hf_bitmask.masked_fill_backward(doutput, compressed_mask)
dvalue = None
if value_need_grad:
dvalue = doutput.sum() - dinput.sum()
return dinput, None, dvalue, None
@_cuda_only
@_disable_set_replace_torch
def masked_fill(input: Tensor, mask: Tensor, value: Union[Number, Tensor]) -> Tensor:
    """
    Bit-packed masked_fill function; usage is identical to :func:`torch.masked_fill`.
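
    A minimal usage sketch (``score`` is an arbitrary example tensor):

    .. code-block:: python

        import hfai.nn.functional as F

        mask = score < 0                     # boolean mask, broadcastable to score
        y = F.masked_fill(score, mask, 0.0)
        # same as: y = score.masked_fill(mask, 0.0)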
"""
if not torch.is_grad_enabled() or not input.requires_grad:
return input.masked_fill(mask, value)
x, y = torch.broadcast_tensors(input, mask)
if y.shape.numel() > mask.shape.numel() * 32:
return input.masked_fill(mask, value)
return MaskedFillFunc.apply(input, y, value, False)
@_cuda_only
@_disable_set_replace_torch
def masked_fill_(input: Tensor, mask: Tensor, value: Union[Number, Tensor]) -> Tensor:
    """
    In-place version of the bit-packed masked_fill function.
    """
if not torch.is_grad_enabled() or not input.requires_grad:
return input.masked_fill_(mask, value)
x, y = torch.broadcast_tensors(input, mask)
if y.shape.numel() > mask.shape.numel() * 32:
return input.masked_fill_(mask, value)
return MaskedFillFunc.apply(input, y, value, True)
class MaskedSelectFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask):
output = input.masked_select(mask)
compressed_mask = hf_bitmask.compress_mask(mask)
ctx.save_for_backward(compressed_mask)
ctx.n = mask.numel()
ctx.input_options = {'size': input.size(), 'dtype': input.dtype, 'layout': input.layout, 'device': input.device}
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
compressed_mask = ctx.saved_tensors[0]
mask = hf_bitmask.decompress_mask(compressed_mask, ctx.n)
input_options = ctx.input_options
dinput = torch.zeros(size=input_options['size'], dtype=input_options['dtype'],
layout=input_options['layout'], device=input_options['device'])
dinput.masked_scatter_(mask, doutput)
return dinput, None
@_cuda_only
@_disable_set_replace_torch
def masked_select(input: Tensor, mask: Tensor) -> Tensor:
    """
    Memory-saving masked_select function; usage is identical to :func:`torch.masked_select`.
    """
if not torch.is_grad_enabled() or not input.requires_grad:
return input.masked_select(mask)
x, y = torch.broadcast_tensors(input, mask)
if y.shape.numel() > mask.shape.numel() * 32:
return input.masked_select(mask)
return MaskedSelectFunc.apply(x, y)
class MaskedScatterFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input, mask, source, inplace):
if inplace:
ctx.mark_dirty(input)
output = input.masked_scatter_(mask, source)
else:
output = input.masked_scatter(mask, source)
compressed_mask = hf_bitmask.compress_mask(mask)
ctx.save_for_backward(compressed_mask)
ctx.n = mask.numel()
ctx.source_size = source.size()
ctx.source_numel = source.numel()
ctx.input_requires = input.requires_grad
ctx.source_requires = source.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, doutput):
compressed_mask = ctx.saved_tensors[0]
mask = hf_bitmask.decompress_mask(compressed_mask, ctx.n)
if ctx.input_requires:
dinput = doutput.masked_fill(mask, 0)
else:
dinput = None
if ctx.source_requires:
mask_selected = doutput.masked_select(mask)
diff_nelem = ctx.source_numel - mask_selected.numel()
if diff_nelem > 0:
                # masked_select returns a 1-d tensor whose length equals the number of True
                # mask elements, so pad the rest with zeros and reshape back to source's size.
mask_selected = torch_F.pad(mask_selected, [0, diff_nelem], value=0)
dsource = mask_selected.view(ctx.source_size)
else:
dsource = None
return dinput, None, dsource, None
@_cuda_only
@_disable_set_replace_torch
def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    """
    Memory-saving masked_scatter function; usage is identical to :func:`torch.masked_scatter`.
    """
if not torch.is_grad_enabled() or (not input.requires_grad and not source.requires_grad):
return input.masked_scatter(mask, source)
x, y = torch.broadcast_tensors(input, mask)
if y.shape.numel() > mask.shape.numel() * 32:
return input.masked_scatter(mask, source)
return MaskedScatterFunc.apply(input, y, source, False)
@_cuda_only
@_disable_set_replace_torch
def masked_scatter_(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    """
    In-place version of the memory-saving masked_scatter function.
    """
if not torch.is_grad_enabled() or (not input.requires_grad and not source.requires_grad):
return input.masked_scatter_(mask, source)
x, y = torch.broadcast_tensors(input, mask)
if y.shape.numel() > mask.shape.numel() * 32:
return input.masked_scatter_(mask, source)
return MaskedScatterFunc.apply(input, y, source, True)
def canonicalize_axis(dim: int, num_dims: int) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
if not -num_dims <= dim < num_dims:
raise ValueError(f"Dim {dim} is out of bounds for array of dimension {num_dims}")
if dim < 0:
dim = dim + num_dims
return dim
X = TypeVar('X')
Y = TypeVar('Y')
Hidden = TypeVar('Hidden')
def scan(f: Callable[[X, Hidden], Tuple[Y, Hidden]],
         xs: X,
         init: Hidden,
         dim: int = 0,
         reverse: bool = False) -> Tuple[Y, Hidden]:
    """
    Scan a function over data (similar to an RNN), passing a hidden state ``hidden`` between
    consecutive steps.

    Args:
        f (Callable[Tuple[X, Hidden], Tuple[Y, Hidden]]): the function executed at each step, of the
            form ``f: (x, hidden_in) -> (y, hidden_out)``.
            ``X`` and ``Y`` are each a ``Tensor`` or a pytree (nested tuple/list/dict) whose leaves
            are ``Tensor`` s. The ``hidden_out`` of step ``i`` becomes the ``hidden_in`` of step ``i + 1``.
        xs (X): the ``x`` values stacked along the scan dimension. The number of steps (calls to ``f``)
            equals the size of that dimension.
        init (Hidden): the ``hidden_in`` fed to the first step
        dim (int, optional): the scan dimension. Default: 0
        reverse (bool, optional): whether to scan ``xs`` in reverse order. Default: False

    Returns:
        A tuple ``(ys, h)``: ``ys`` is the stacked scan output and ``h`` is the ``hidden_out`` of the last step.
Examples:
.. code-block:: python
f = lambda x, y: (x + y, x + y)
a = torch.tensor([1, 2, 3, 4])
sum = hfai.nn.functional.scan(f, a, torch.tensor(0)) # (tensor([1, 3, 6, 10]), tensor(10))
sum = hfai.nn.functional.scan(f, a, torch.tensor(0), reverse=True) # (tensor([10, 9, 7, 4]), tensor(10))
.. code-block:: python
f = lambda x, y: ((x[0] + y, x[1] + y), x[0] + x[1] + y)
a = [torch.tensor([1, 2, 3, 4]), torch.tensor([4, 3, 2, 1])]
sum = hfai.nn.functional.scan(f, a, torch.tensor(0)) # ([tensor([1, 7, 13, 19]), tensor([4, 8, 12, 16])], tensor(20))
sum = hfai.nn.functional.scan(f, a, torch.tensor(0), reverse=True) # ([tensor([16, 12, 8, 4]), tensor([19, 13, 7, 1])], tensor(20))
"""
    if not callable(f):
        raise TypeError("f must be a callable")
xs_flat, tree_x = pytree.tree_flatten(xs)
if reverse:
xs_flat = [torch.flip(x, [dim]) for x in xs_flat]
# Check that all inputs have a consistent leading dimension `num_xs`.
dim = canonicalize_axis(dim, xs_flat[0].ndim)
num_xs = int(xs_flat[0].shape[dim])
if not all(int(x.shape[dim]) == num_xs for x in xs_flat[1:]):
        raise ValueError('All leaves of xs passed to scan must have the same size along dim (got shapes: {})'.format([x.shape for x in xs_flat]))
hidden = init
is_ys_init = False
for i in range(num_xs):
x_flat = [torch.select(x_flat, dim=dim, index=i) for x_flat in xs_flat]
y, hidden = f(pytree.tree_unflatten(x_flat, tree_x), hidden)
y_flat, tree_y = pytree.tree_flatten(y)
if not is_ys_init:
is_ys_init = True
ys_flat = [[y] for y in y_flat]
else:
for ys, y in zip(ys_flat, y_flat):
ys.append(y)
ys_flat = list(map(partial(torch.stack, dim=dim), ys_flat))
if reverse:
ys_flat = [torch.flip(ys, [dim]) for ys in ys_flat]
return pytree.tree_unflatten(ys_flat, tree_y), hidden
def associative_scan(f: Callable[[X, X], X],
                     xs: X,
                     dim: int = 0,
                     reverse: bool = False):
    """
    Scan data with an associative binary operation (similar to a prefix sum), executed in parallel.

    Args:
        f (Callable[Tuple[X, X], X]): a binary operation that must be associative, i.e.
            ``f(f(a, b), c) = f(a, f(b, c))``.
            ``X`` is a ``Tensor`` or a pytree (nested tuple/list/dict) whose leaves are ``Tensor`` s.
            The inputs and output of ``f`` must have the same structure (if ``X`` is a ``Tensor``,
            the inputs and output must have the same ``shape``; if ``X`` is a pytree, the input and
            output pytrees must be isomorphic and corresponding leaf ``Tensor`` s must have the same
            ``shape``).
        xs (X): the ``x`` values stacked along the scan dimension.
        dim (int, optional): the scan dimension. Default: 0
        reverse (bool, optional): whether to scan ``xs`` in reverse order. Default: False

    Returns:
        ``ys``, the scan result, with the same structure as ``xs``.
Examples:
.. code-block:: python
f = lambda x, y: x + y
a = torch.tensor([1, 2, 3, 4])
sum = hfai.nn.functional.associative_scan(f, a) # tensor([1, 3, 6, 10])
sum = hfai.nn.functional.associative_scan(f, a, reverse=True) # tensor([10, 9, 7, 4])
.. code-block:: python
            f = lambda x, y: (y[0] * x[0], y[0] * x[1] + y[1])  # associative
a = [torch.randn(5, 6, 7), torch.randn(5, 6, 7)]
sum = hfai.nn.functional.associative_scan(f, a, dim=-1, reverse=True)
"""
    if not callable(f):
        raise TypeError("f must be a callable")
xs_flat, tree = pytree.tree_flatten(xs)
if reverse:
xs_flat = [torch.flip(x, [dim]) for x in xs_flat]
# Check that all inputs have a consistent leading dimension `num_xs`.
dim = canonicalize_axis(dim, xs_flat[0].ndim)
num_xs = int(xs_flat[0].shape[dim])
if not all(int(x.shape[dim]) == num_xs for x in xs_flat[1:]):
        raise ValueError('All leaves of xs passed to associative_scan must have the same size along dim (got shapes: {})'.format([x.shape for x in xs_flat]))
def _combine(a_flat, b_flat):
a = pytree.tree_unflatten(a_flat, tree)
b = pytree.tree_unflatten(b_flat, tree)
c = f(a, b)
c_flat, _ = pytree.tree_flatten(c)
return c_flat
    def _slice_in_dim(start, limit, stride=1):
        # Build an index tuple that slices only along `dim`.
        slices = [slice(0, None)] * (dim + 1)
        slices[dim] = slice(start, limit, stride)
        return tuple(slices)
def _scan(xs):
num_xs = xs[0].shape[dim]
if num_xs <= 1:
return xs
# Combine adjacent pairs of elements.
reduced_xs = _combine([x[_slice_in_dim(0, -1, 2)] for x in xs],
[x[_slice_in_dim(1, None, 2)] for x in xs])
# Recursively compute scan for partially reduced tensors.
odd_xs = _scan(reduced_xs)
if num_xs % 2 == 0:
even_xs = _combine([odd_x[_slice_in_dim(0, -1)] for odd_x in odd_xs],
[x[_slice_in_dim(2, None, 2)] for x in xs])
else:
even_xs = _combine(odd_xs,
[x[_slice_in_dim(2, None, 2)] for x in xs])
ys = [torch.empty_like(x) for x in xs]
for (x, odd_x, even_x, y) in zip(xs, odd_xs, even_xs, ys):
y[_slice_in_dim(0, 1)] = x[_slice_in_dim(0, 1)]
y[_slice_in_dim(1, None, 2)] = odd_x
y[_slice_in_dim(2, None, 2)] = even_x
return ys
ys_flat = _scan(xs_flat)
if reverse:
ys_flat = [torch.flip(y, [dim]) for y in ys_flat]
return pytree.tree_unflatten(ys_flat, tree)
# Copyright 2020 Google LLC
# Copyright 2021 Teddy Koker
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
no_softrank = False
try:
import hfai.hfcuda.softrank as softrank
except:
no_softrank = True
def soft_rank(values, regularization="l2", regularization_strength=1.0):
if len(values.shape) != 2:
raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
if regularization not in ["l2", "kl"]:
        raise ValueError(f"'regularization' should be 'l2' or 'kl', but got {regularization}")
return SoftRank.apply(values, regularization, regularization_strength)
def soft_sort(values, regularization="l2", regularization_strength=1.0):
if len(values.shape) != 2:
raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
if regularization not in ["l2", "kl"]:
        raise ValueError(f"'regularization' should be 'l2' or 'kl', but got {regularization}")
return SoftSort.apply(values, regularization, regularization_strength)
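# A hedged usage sketch for soft_rank / soft_sort (assumes the hfai.hfcuda.softrank extension
# imported above is available and `values` lives on a CUDA device); both act row-wise on a
# 2-d batch and remain differentiable:
#
#   values = torch.randn(4, 8, device="cuda", requires_grad=True)
#   ranks = soft_rank(values, regularization="l2", regularization_strength=1.0)
#   sorted_values = soft_sort(values, regularization="kl", regularization_strength=0.1)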
def _arange_like(x, reverse=False):
# returns arange with len of x of the same dtype and device (assumes 2d, first dim batch)
if reverse:
ar = torch.arange(x.shape[1] - 1, -1, -1, dtype=x.dtype, device=x.device)
else:
ar = torch.arange(x.shape[1], dtype=x.dtype, device=x.device)
return ar.expand(x.shape[0], -1)
def _inv_permutation(permutation):
# returns inverse permutation of 'permutation'. (assumes 2d, first dim batch)
inv_permutation = torch.zeros_like(permutation)
inv_permutation.scatter_(1, permutation, _arange_like(permutation))
return inv_permutation
# The following is from google-research/fast-soft-sort with the following modifications:
# - replace numpy functions with torch equivalent
# - remove unnecessary operations
# - reimplement backward pass in C++
class SoftRank(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
ctx.scale = 1.0 / regularization_strength
ctx.regularization = regularization
w = _arange_like(tensor, reverse=True) + 1
theta = tensor * ctx.scale
s, permutation = torch.sort(theta, descending=True)
inv_permutation = _inv_permutation(permutation)
if ctx.regularization == "l2":
dual_sol = softrank.isotonic_l2(s - w)
ret = (s - dual_sol).gather(1, inv_permutation)
factor = torch.tensor(1.0, device=s.device)
else:
dual_sol = softrank.isotonic_kl(s, torch.log(w))
ret = torch.exp((s - dual_sol).gather(1, inv_permutation))
factor = ret
ctx.save_for_backward(factor, s, dual_sol, permutation, inv_permutation)
return ret
@staticmethod
def backward(ctx, grad_output):
factor, s, dual_sol, permutation, inv_permutation = ctx.saved_tensors
grad = (grad_output * factor).clone()
if ctx.regularization == "l2":
grad -= softrank.isotonic_l2_backward(
s, dual_sol, grad.gather(1, permutation)
).gather(1, inv_permutation)
else:
grad -= softrank.isotonic_kl_backward(
s, dual_sol, grad.gather(1, permutation)
).gather(1, inv_permutation)
return grad * ctx.scale, None, None
class SoftSort(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
ctx.sign = -1
ctx.regularization = regularization
w = (_arange_like(tensor, reverse=True) + 1) / regularization_strength
tensor = ctx.sign * tensor # for ascending
s, permutation = torch.sort(tensor, descending=True)
# note reverse order of args
if ctx.regularization == "l2":
sol = softrank.isotonic_l2(w - s)
else:
sol = softrank.isotonic_kl(w, s)
ctx.save_for_backward(s, sol, permutation)
return ctx.sign * (w - sol)
@staticmethod
def backward(ctx, grad_output):
s, sol, permutation = ctx.saved_tensors
inv_permutation = _inv_permutation(permutation)
if ctx.regularization == "l2":
grad = softrank.isotonic_l2_backward(s, sol, grad_output)
else:
grad = softrank.isotonic_kl_backward(s, sol, grad_output)
return grad.gather(1, inv_permutation), None, None
try:
import hfai.hfcuda.linear as hf_linear
except:
pass
def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
if has_torch_function((input,)):
return handle_torch_function(linear, (input,), input, weight=weight, bias=bias)
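    # The hf_linear kernel is only taken when a bias is present and TF32 matmul is allowed;
    # otherwise fall back to the stock torch.nn.functional.linear.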
if bias is None or not torch.backends.cuda.matmul.allow_tf32:
return torch_F.linear(input, weight, bias)
return hf_linear.linear(input, weight, bias)
try:
import hfai.hfcuda.swiglu as hf_swiglu
except:
pass
class SwiGLUFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, input, dim):
ctx.save_for_backward(input)
ctx.dim = dim
output = hf_swiglu.swiglu_forward(input, dim)
return output
@staticmethod
def backward(ctx, doutput):
input = ctx.saved_tensors[0]
dim = ctx.dim
dinput = hf_swiglu.swiglu_backward(doutput, input, dim)
return dinput, None
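# Semantics note (an assumption based on the common SwiGLU formulation): hf_swiglu presumably
# splits `input` into two halves along `dim` and returns one half gated by SiLU of the other;
# the exact half ordering is defined by the CUDA kernel.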
def swiglu(input: Tensor, dim=-1) -> Tensor:
return SwiGLUFunction.apply(input, dim)
class set_replace_torch(object):
    """
    Replace every torch.nn.functional and torch function that hfai has optimized with its
    hfai.nn.functional counterpart.
    For example: torch.nn.functional.relu -> hfai.nn.functional.relu, torch.max -> hfai.nn.functional.max,
    x.abs() -> hfai.nn.functional.abs(x)

    Args:
        mode (bool, optional): whether to enable the replacement. Default: True

    .. note::
        After ``hfai.nn.functional.set_replace_torch()`` is called, every call to
        ``torch.nn.functional.xxx`` executes ``hfai.nn.functional.xxx``, whether it is
        (1) made explicitly by the user or (2) made internally by PyTorch.
        For example, the ``log_softmax`` inside ``torch.nn.CrossEntropyLoss`` automatically
        executes ``hfai.nn.functional.log_softmax``.

    Examples:
        .. code-block:: python

            hfai.nn.functional.set_replace_torch()
            y = torch.nn.functional.softmax(x)  # softmax runs the hfai implementation
            hfai.nn.functional.set_replace_torch(False)

            with hfai.nn.functional.set_replace_torch():
                loss = torch.nn.functional.cross_entropy(input, target)  # the internal log_softmax runs the hfai implementation
    """
    # All functions that can transparently replace torch.nn.functional.xxx
replaceable_nn_functions = \
[dropout, relu, relu_, hardtanh, hardtanh_, relu6, softplus, softmin, softmax, log_softmax,
_threshold, threshold, threshold_, rrelu, rrelu_, leaky_relu, leaky_relu_, hardsigmoid, hardshrink, softshrink]
replaceable_nn_functions_torch = [getattr(torch_F, func.__name__) for func in replaceable_nn_functions]
    # All functions that can transparently replace torch.xxx
replaceable_torch_functions = \
[minimum, maximum, abs, abs_, min, max,
clip, clip_, clamp, clamp_, clamp_max, clamp_max_, clamp_min, clamp_min_, where,
masked_fill, masked_select, masked_scatter]
replaceable_torch_functions_torch = [getattr(torch, func.__name__) for func in replaceable_torch_functions]
    # All functions that can transparently replace torch.Tensor.xxx
replaceable_tensor_functions = \
[relu, relu_, softmax, log_softmax, hardshrink, minimum, maximum, abs, abs_, min, max,
clip, clip_, clamp, clamp_, clamp_max, clamp_max_, clamp_min, clamp_min_,
masked_fill, masked_fill_, masked_select, masked_scatter, masked_scatter_]
replaceable_tensor_functions_torch = [getattr(torch.Tensor, func.__name__) for func in replaceable_tensor_functions]
is_replacement_enabled = False
@staticmethod
def _set(cls, mode: bool):
cls.is_replacement_enabled = mode
for func in (cls.replaceable_nn_functions if mode is True else cls.replaceable_nn_functions_torch):
setattr(torch_F, func.__name__, func)
for func in (cls.replaceable_torch_functions if mode is True else cls.replaceable_torch_functions_torch):
setattr(torch, func.__name__, func)
for func in (cls.replaceable_tensor_functions if mode is True else cls.replaceable_tensor_functions_torch):
setattr(torch.Tensor, func.__name__, func)
def __init__(self, mode: bool = True):
self.prev = self.__class__.is_replacement_enabled
self._set(self.__class__, mode)
def __enter__(self):
pass
def __exit__(self, exc_type, exc_val, exc_tb):
self._set(self.__class__, self.prev)