
Source code for hfai.nn.functional

import torch
import warnings
from torch import Tensor
from typing import Optional, Tuple, List, Union, Callable, TypeVar
from torch.types import Number
from torch import _VF
from torch.overrides import has_torch_function, handle_torch_function
import torch.nn.functional as torch_F
from torch.nn.functional import *
from .sync_function import sync
from .functional_utils import _cuda_only, _disable_set_replace_torch
from . import _pytree as pytree
from functools import partial

try:
    import hfai.hfcuda.dropout as hf_dropout
except:
    pass


class DropoutFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, seeds = hf_dropout.dropout_forward(input, p, inplace)
        ctx.p = p
        ctx.seeds = seeds
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        dinput = hf_dropout.dropout_backward(doutput, ctx.p, ctx.seeds)
        ret = dinput, None, None
        return ret


@_cuda_only
def dropout(input: Tensor, p: float = 0.5, training: bool = True, inplace: bool = False) -> Tensor:
    """
    Dropout function, see :class:`~hfai.nn.Dropout`
    """
    if has_torch_function((input,)):
        return handle_torch_function(dropout, (input,), input, p=p, training=training, inplace=inplace)
    if p < 0. or p > 1.:
        raise ValueError("dropout probability has to be between 0 and 1, "
                         "but got {}".format(p))
    if not training:
        return input
    if p == 0:
        return input
    if p == 1:
        return input.zero_() if inplace else torch.zeros_like(input)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return _VF.dropout_(input, p, training) if inplace else _VF.dropout(input, p, training)
    if inplace and not input.is_contiguous():
        return _VF.dropout_(input, p, training)
    return DropoutFunction.apply(input, p, inplace)

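# Usage sketch (illustrative only; assumes a CUDA device, since the hfai kernels
# are CUDA-only): when gradients are enabled, dropout dispatches to DropoutFunction,
# which keeps only the RNG seeds returned by the forward kernel rather than a full
# dropout mask.
#
#     x = torch.randn(32, 1024, device="cuda", requires_grad=True)
#     y = dropout(x, p=0.1, training=True)
#     y.sum().backward()
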
try:
    import hfai.hfcuda.bitmask as hf_bitmask
except:
    pass


class ReLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.relu_forward(input, inplace)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.relu_backward(doutput, compressed_mask)
        return dinput, None

@_cuda_only
@_disable_set_replace_torch
def relu(input: Tensor, inplace: bool = False) -> Tensor:
    """
    ReLU function, see :class:`~hfai.nn.ReLU`
    """
    if has_torch_function((input,)):
        return handle_torch_function(relu, (input,), input, inplace=inplace)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.relu_() if inplace else input.relu()
    if inplace and not input.is_contiguous():
        return input.relu_()
    return ReLUFunction.apply(input, inplace)

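# Usage sketch (illustrative only; assumes a CUDA tensor): with gradients enabled,
# relu dispatches to ReLUFunction, which stores a 1-bit-per-element mask of which
# inputs were positive instead of saving the full activation for the backward pass.
#
#     x = torch.randn(8, 4096, device="cuda", requires_grad=True)
#     y = relu(x)
#     y.mean().backward()
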
@_cuda_only
@_disable_set_replace_torch
def relu_(input: Tensor) -> Tensor:
    """
    In-place ReLU function, see :class:`~hfai.nn.ReLU`
    """
    return relu(input, inplace=True)


def hardtanh(input: Tensor, min_val: float = -1., max_val: float = 1., inplace: bool = False) -> Tensor:
    """
    Hardtanh function, see :class:`~hfai.nn.Hardtanh`
    """
    if has_torch_function((input,)):
        return handle_torch_function(hardtanh, (input,), input, min_val=min_val, max_val=max_val, inplace=inplace)
    return clamp(input, min_val, max_val, inplace)


def hardtanh_(input: Tensor, min_val: float = -1., max_val: float = 1.) -> Tensor:
    """
    In-place hardtanh function, see :class:`~hfai.nn.Hardtanh`
    """
    if has_torch_function((input,)):
        return handle_torch_function(hardtanh_, (input,), input, min_val=min_val, max_val=max_val)
    return clamp_(input, min_val, max_val)


def relu6(input: Tensor, inplace: bool = False) -> Tensor:
    """
    ReLU6 function, see :class:`~hfai.nn.ReLU6`
    """
    return hardtanh(input, 0, 6, inplace)

try:
    import hfai.hfcuda.softplus as hf_softplus
except:
    pass


class SoftplusFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, beta, threshold):
        output = hf_softplus.softplus_forward(input, beta, threshold)
        ctx.beta = beta
        ctx.threshold = threshold
        ctx.save_for_backward(input)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        input = ctx.saved_tensors[0]
        dinput = hf_softplus.softplus_backward(doutput, input, ctx.beta, ctx.threshold)
        return dinput, None, None

@_cuda_only
def softplus(input: Tensor, beta: int = 1, threshold: int = 20) -> Tensor:
    """
    Softplus function, see :class:`~hfai.nn.Softplus`
    """
    if has_torch_function((input,)):
        return handle_torch_function(softplus, (input,), input, beta=beta, threshold=threshold)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return torch._C._nn.softplus(input, beta, threshold)
    return SoftplusFunction.apply(input, beta, threshold)

def _get_softmax_dim(name: str, ndim: int, stacklevel: int) -> int:
    warnings.warn("Implicit dimension choice for {} has been deprecated. "
                  "Change the call to include dim=X as an argument.".format(name), stacklevel=stacklevel)
    if ndim == 0 or ndim == 1 or ndim == 3:
        ret = 0
    else:
        ret = 1
    return ret


try:
    import hfai.hfcuda.softmax as hf_softmax
except:
    pass


class SoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim):
        output = hf_softmax.softmax_forward(input, dim)
        ctx.dim = dim
        ctx.input_dtype = input.dtype
        ctx.save_for_backward(output)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        output = ctx.saved_tensors[0]
        fake_input = torch.empty([1], dtype=ctx.input_dtype, device=output.device)
        dinput = hf_softmax.softmax_backward(doutput, output, ctx.dim, fake_input)
        return dinput, None

def softmin(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor:
    """
    Softmin function, see :class:`~hfai.nn.Softmin`
    """
    return softmax(-input, dim, _stacklevel, dtype)


@_disable_set_replace_torch
def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor:
    """
    Softmax function, see :class:`~hfai.nn.Softmax`
    """
    if has_torch_function((input,)):
        return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    if dim is None:
        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
    if dtype is not None:
        input = input.type(dtype)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.softmax(dim)
    return SoftmaxFunction.apply(input, dim)

class LogSoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim):
        output = hf_softmax.log_softmax_forward(input, dim)
        ctx.dim = dim
        ctx.input_dtype = input.dtype
        ctx.save_for_backward(output)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        output = ctx.saved_tensors[0]
        fake_input = torch.empty([1], dtype=ctx.input_dtype, device=output.device)
        dinput = hf_softmax.log_softmax_backward(doutput, output, ctx.dim, fake_input)
        return dinput, None

@_disable_set_replace_torch
def log_softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> Tensor:
    """
    Log-softmax function, see :class:`~hfai.nn.LogSoftmax`
    """
    if has_torch_function((input,)):
        return handle_torch_function(log_softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    if dim is None:
        dim = _get_softmax_dim("log_softmax", input.dim(), _stacklevel)
    if dtype is not None:
        input = input.type(dtype)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.log_softmax(dim)
    return LogSoftmaxFunction.apply(input, dim)

class ClampFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, min, max, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.clamp_forward(input, min, max, inplace)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.clamp_backward(doutput, compressed_mask)
        return dinput, None, None, None

def clamp_min(input: Tensor, min: float, inplace: bool = False) -> Tensor:
    """
    Bit-compressed clamp_min operator; during training the intermediate result `[x >= min]`
    is stored with 1 bit per element to save memory

    Args:
        input: input Tensor
        min (float): minimum value of the output
        inplace (bool, optional): if ``True``, perform the operation in place. Default: ``False``

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.clamp_min(x, min=-0.5)
        # same as: y = x.clamp_min(min=-0.5)
        # same as: y = torch.max(x, -0.5 * torch.ones_like(x))

    """
    return clamp(input, min=min, inplace=inplace)


def clamp_min_(input: Tensor, min: float) -> Tensor:
    """
    In-place clamp_min function
    """
    return clamp_min(input, min=min, inplace=True)


def clamp_max(input: Tensor, max: float, inplace: bool = False) -> Tensor:
    """
    Bit-compressed clamp_max operator; during training the intermediate result `[x <= max]`
    is stored with 1 bit per element to save memory

    Args:
        input: input Tensor
        max (float): maximum value of the output
        inplace (bool, optional): if ``True``, perform the operation in place. Default: ``False``

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.clamp_max(x, max=0.5)
        # same as: y = x.clamp_max(max=0.5)
        # same as: y = torch.min(x, 0.5 * torch.ones_like(x))

    """
    return clamp(input, max=max, inplace=inplace)


def clamp_max_(input: Tensor, max: float) -> Tensor:
    """
    In-place clamp_max function
    """
    return clamp(input, max=max, inplace=True)


@_cuda_only
@_disable_set_replace_torch
def clamp(input: Tensor, min: Optional[float] = None, max: Optional[float] = None, inplace: bool = False) -> Tensor:
    """
    Bit-compressed clamp operator; during training the intermediate result `[min <= x <= max]`
    is stored with 1 bit per element to save memory

    Args:
        input: input Tensor
        min (float): minimum value of the output. Default: ``None``
        max (float): maximum value of the output. Default: ``None``
        inplace (bool, optional): if ``True``, perform the operation in place. Default: ``False``

    .. code-block:: python

        import hfai.nn.functional as F

        y = F.clamp(x, min=-0.5, max=0.5)
        # same as: y = x.clamp(min=-0.5, max=0.5)

    """
    if has_torch_function((input,)):
        return handle_torch_function(clamp, (input,), input, min=min, max=max, inplace=inplace)
    if min is None and max is None:
        raise RuntimeError("At least one of 'min' or 'max' must not be None")
    if min is None:
        min = float('-inf')
    if max is None:
        max = float('inf')
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.clamp_(min, max) if inplace else input.clamp(min, max)
    if inplace and not input.is_contiguous():
        return input.clamp_(min, max)
    return ClampFunc.apply(input, min, max, inplace)


@_cuda_only
@_disable_set_replace_torch
def clamp_(input: Tensor, min: Optional[float] = None, max: Optional[float] = None) -> Tensor:
    """
    In-place clamp function
    """
    return clamp(input, min, max, inplace=True)


def clip(input: Tensor, min: Optional[float] = None, max: Optional[float] = None) -> Tensor:
    """
    Clip function, see :func:`hfai.nn.functional.clamp`
    """
    return clamp(input, min, max)


def clip_(input: Tensor, min: Optional[float] = None, max: Optional[float] = None) -> Tensor:
    """
    In-place clip function
    """
    return clamp_(input, min, max)

class ThresholdFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, threshold, value, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.threshold_forward(input, threshold, value, inplace)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.threshold_backward(doutput, compressed_mask)
        return dinput, None, None, None


@_cuda_only
def _threshold(input: Tensor, threshold: float, value: float, inplace: bool = False) -> Tensor:
    """
    Threshold function, see :class:`~hfai.nn.Threshold`
    """
    if has_torch_function((input,)):
        return handle_torch_function(_threshold, (input,), input, threshold=threshold, value=value, inplace=inplace)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return _VF.threshold_(input, threshold, value) if inplace else _VF.threshold(input, threshold, value)
    if inplace and not input.is_contiguous():
        return _VF.threshold_(input, threshold, value)
    return ThresholdFunc.apply(input, threshold, value, inplace)


# We define this function as _threshold because it takes an argument
# named threshold, which clobbers the recursive reference to the
# function needed for __torch_function__ support
threshold = _threshold


def threshold_(input: Tensor, threshold: float, value: float) -> Tensor:
    """
    In-place threshold function, see :class:`~hfai.nn.Threshold`
    """
    return _threshold(input, threshold, value, inplace=True)


@_disable_set_replace_torch
def max(input: Tensor, value: Optional[Union[torch.Tensor, int, float]] = None, dim: Optional[int] = None,
        keepdim: bool = False) -> Tensor:
    """
    Bit-compressed max operator, returning the elementwise maximum of ``input`` and ``value``

    If ``value`` is a Tensor, this calls ``hf_F.maximum(input, value)``;
    if ``value`` is a float, it calls ``hf_F.clamp(input, min=value)``;
    if ``value`` is an int, or ``dim`` / ``keepdim`` is given, it calls
    ``torch.max(input, dim=value or dim, keepdim=keepdim)``

    Args:
        input (Tensor): input Tensor
        value (Tensor or float or int): the value to compare with, or the dimension to reduce over
        dim (int): the dimension to reduce over
        keepdim (bool): whether to keep the reduced dimension
    """
    if value is None and dim is None:
        return input.max()
    if dim is not None:
        return input.max(dim=dim, keepdim=keepdim)
    if isinstance(value, int):
        return input.max(dim=value, keepdim=keepdim)
    if isinstance(value, torch.Tensor):
        return maximum(input, value)
    return clamp(input, min=value)


@_disable_set_replace_torch
def min(input: Tensor, value: Optional[Union[torch.Tensor, int, float]] = None, dim: Optional[int] = None,
        keepdim: bool = False) -> Tensor:
    """
    Bit-compressed min operator, returning the elementwise minimum of ``input`` and ``value``

    If ``value`` is a Tensor, this calls ``hf_F.minimum(input, value)``;
    if ``value`` is a float, it calls ``hf_F.clamp(input, max=value)``;
    if ``value`` is an int, or ``dim`` / ``keepdim`` is given, it calls
    ``torch.min(input, dim=value or dim, keepdim=keepdim)``

    Args:
        input (Tensor): input Tensor
        value (Tensor or float or int): the value to compare with, or the dimension to reduce over
        dim (int): the dimension to reduce over
        keepdim (bool): whether to keep the reduced dimension
    """
    if value is None and dim is None:
        return input.min()
    if dim is not None:
        return input.min(dim=dim, keepdim=keepdim)
    if isinstance(value, int):
        return input.min(dim=value, keepdim=keepdim)
    if isinstance(value, torch.Tensor):
        return minimum(input, value)
    return clamp(input, max=value)

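# Usage sketch (illustrative only): max / min mirror torch.max / torch.min but
# route elementwise comparisons to the bit-compressed kernels, as described in
# the docstrings above.
#
#     x = torch.randn(4, 8, device="cuda", requires_grad=True)
#     max(x, 0.0)                   # float value  -> clamp(x, min=0.0)
#     max(x, torch.zeros_like(x))   # Tensor value -> maximum(x, ...)
#     max(x, dim=1, keepdim=True)   # reduction    -> torch.max(x, dim=1, keepdim=True)
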
class RReluFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, lower, upper, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, noise, compressed_mask = hf_bitmask.rrelu_forward(input, lower, upper, inplace)
        ctx.save_for_backward(noise, compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        noise, compressed_mask = ctx.saved_tensors
        dinput = hf_bitmask.rrelu_backward(doutput, noise, compressed_mask)
        return dinput, None, None, None

@_cuda_only
def rrelu(
    input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False, inplace: bool = False
) -> Tensor:
    """
    RReLU function, see :class:`~hfai.nn.RReLU`
    """
    if has_torch_function((input,)):
        return handle_torch_function(rrelu, (input,), input, lower=lower, upper=upper, training=training,
                                     inplace=inplace)
    if not training or not input.requires_grad or not torch.is_grad_enabled():
        return torch._C._VariableFunctionsClass.rrelu_(input, lower, upper, training) if inplace \
            else torch._C._VariableFunctionsClass.rrelu(input, lower, upper, training)
    if inplace and not input.is_contiguous():
        return torch._C._VariableFunctionsClass.rrelu_(input, lower, upper, training)
    return RReluFunc.apply(input, lower, upper, inplace)


def rrelu_(
    input: Tensor, lower: float = 1.0 / 8, upper: float = 1.0 / 3, training: bool = False
) -> Tensor:
    """
    In-place rrelu function
    """
    return rrelu(input, lower, upper, training, inplace=True)

class LeakyReluFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, negative_slope, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.leaky_relu_forward(input, negative_slope, inplace)
        ctx.save_for_backward(compressed_mask)
        ctx.negative_slope = negative_slope
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        negative_slope = ctx.negative_slope
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.leaky_relu_backward(doutput, compressed_mask, negative_slope)
        return dinput, None, None, None

@_cuda_only
def leaky_relu(input: Tensor, negative_slope: float = 0.01, inplace: bool = False) -> Tensor:
    """
    Leaky ReLU function, see :class:`~hfai.nn.LeakyReLU`
    """
    if has_torch_function((input,)):
        return handle_torch_function(leaky_relu, (input,), input, negative_slope=negative_slope, inplace=inplace)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return torch._C._nn.leaky_relu_(input, negative_slope) if inplace \
            else torch._C._nn.leaky_relu(input, negative_slope)
    if inplace and not input.is_contiguous():
        return torch._C._nn.leaky_relu_(input, negative_slope)
    return LeakyReluFunc.apply(input, negative_slope, inplace)


def leaky_relu_(input: Tensor, negative_slope: float = 0.01) -> Tensor:
    """
    In-place leaky_relu function
    """
    return leaky_relu(input, negative_slope, inplace=True)

class HardSigmoid(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.hardsigmoid_forward(input, inplace)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.hardsigmoid_backward(doutput, compressed_mask)
        return dinput, None

@_cuda_only
def hardsigmoid(input: Tensor, inplace: bool = False) -> Tensor:
    """
    Hardsigmoid function, see :class:`~hfai.nn.Hardsigmoid`
    """
    if has_torch_function((input,)):
        return handle_torch_function(hardsigmoid, (input,), input, inplace=inplace)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return torch._C._nn.hardsigmoid_(input) if inplace else torch._C._nn.hardsigmoid(input)
    if inplace and not input.is_contiguous():
        return torch._C._nn.hardsigmoid_(input)
    return HardSigmoid.apply(input, inplace)

class HardshrinkFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, lambd):
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.hardshrink_forward(input, lambd)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.shrink_backward(doutput, compressed_mask)
        return dinput, None

@_cuda_only
@_disable_set_replace_torch
def hardshrink(input: Tensor, lambd: float = 0.5) -> Tensor:
    """
    Hardshrink function, see :class:`~hfai.nn.Hardshrink`
    """
    if has_torch_function((input,)):
        return handle_torch_function(hardshrink, (input,), input, lambd=lambd)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return torch.hardshrink(input, lambd)
    return HardshrinkFunc.apply(input, lambd)

class SoftshrinkFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, lambd):
        input = input.contiguous()
        output, compressed_mask = hf_bitmask.softshrink_forward(input, lambd)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask = ctx.saved_tensors[0]
        dinput = hf_bitmask.shrink_backward(doutput, compressed_mask)
        return dinput, None

@_cuda_only
def softshrink(input: Tensor, lambd: float = 0.5) -> Tensor:
    """
    Softshrink function, see :class:`~hfai.nn.Softshrink`
    """
    if has_torch_function((input,)):
        return handle_torch_function(softshrink, (input,), input, lambd=lambd)
    if not torch.is_grad_enabled() or not input.requires_grad:
        return torch._C._nn.softshrink(input, lambd)
    return SoftshrinkFunc.apply(input, lambd)

class AbsFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, inplace):
        if inplace:
            ctx.mark_dirty(input)
        input = input.contiguous()
        output, compressed_mask, zero_mask = hf_bitmask.abs_forward(input, inplace)
        ctx.save_for_backward(compressed_mask, zero_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        doutput = doutput.contiguous()
        compressed_mask, zero_mask = ctx.saved_tensors
        dinput = hf_bitmask.abs_backward(doutput, compressed_mask, zero_mask)
        return dinput, None

@_cuda_only
@_disable_set_replace_torch
def abs(input: Tensor) -> Tensor:
    """
    Bit-compressed abs function; usage is the same as :func:`torch.abs`
    """
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.abs()
    return AbsFunc.apply(input, False)


@_cuda_only
@_disable_set_replace_torch
def abs_(input: Tensor) -> Tensor:
    """
    In-place abs function
    """
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.abs_()
    if not input.is_contiguous():
        return input.abs_()
    return AbsFunc.apply(input, True)

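# Usage sketch (illustrative only): like the other bit-compressed elementwise ops
# above, AbsFunc saves compact masks for the backward pass (presumably the sign
# pattern and the zero positions) instead of a full-size intermediate tensor.
#
#     x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#     y = abs(x)
#     y.sum().backward()   # gradient is +/- doutput, reconstructed from the masks
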
class MaximumFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        output, compressed_mask1, compressed_mask2 = hf_bitmask.maximum_forward(input1, input2)
        ctx.save_for_backward(compressed_mask1, compressed_mask2)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        compressed_mask1, compressed_mask2 = ctx.saved_tensors
        dinput1, dinput2 = hf_bitmask.max_min_backward(doutput, compressed_mask1, compressed_mask2)
        return dinput1, dinput2

@_cuda_only
@_disable_set_replace_torch
def maximum(input1: Tensor, input2: Tensor) -> Tensor:
    """
    Bit-compressed maximum function; usage is the same as :func:`torch.maximum`
    """
    if (not torch.is_grad_enabled()) or (input1.dtype != input2.dtype) or (
            not input1.requires_grad and not input2.requires_grad):
        return input1.maximum(input2)
    x, y = torch.broadcast_tensors(input1, input2)
    expend_numel = x.numel()
    if input1.shape.numel() > input2.shape.numel():
        max_numel = input1.shape.numel()
    else:
        max_numel = input2.shape.numel()
    # fall back to torch if broadcasting expands the element count to more than
    # 32x the larger operand
    if (expend_numel > max_numel * 32):
        return input1.maximum(input2)
    return MaximumFunc.apply(x, y)

class MinimumFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input1, input2):
        output, compressed_mask1, compressed_mask2 = hf_bitmask.minimum_forward(input1, input2)
        ctx.save_for_backward(compressed_mask1, compressed_mask2)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        compressed_mask1, compressed_mask2 = ctx.saved_tensors
        dinput1, dinput2 = hf_bitmask.max_min_backward(doutput, compressed_mask1, compressed_mask2)
        return dinput1, dinput2

@_cuda_only
@_disable_set_replace_torch
def minimum(input1: Tensor, input2: Tensor) -> Tensor:
    """
    Bit-compressed minimum function; usage is the same as :func:`torch.minimum`
    """
    if (not torch.is_grad_enabled()) or (input1.dtype != input2.dtype) or (
            not input1.requires_grad and not input2.requires_grad):
        return input1.minimum(input2)
    x, y = torch.broadcast_tensors(input1, input2)
    expend_numel = x.numel()
    if input1.shape.numel() > input2.shape.numel():
        max_numel = input1.shape.numel()
    else:
        max_numel = input2.shape.numel()
    # fall back to torch if broadcasting expands the element count to more than
    # 32x the larger operand
    if (expend_numel > max_numel * 32):
        return input1.minimum(input2)
    return MinimumFunc.apply(x, y)

class WhereFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, condition, input1, input2):
        output = input1.where(condition, input2)
        compressed_mask = hf_bitmask.compress_mask(condition)
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        compressed_mask = ctx.saved_tensors[0]
        dinput1, dinput2 = hf_bitmask.where_backward(doutput, compressed_mask)
        return None, dinput1, dinput2

@_cuda_only
def where(condition: Tensor, input1: Optional[Union[Tensor, Number]] = None,
          input2: Optional[Union[Tensor, Number]] = None) -> Union[Tensor, Tuple[Tensor]]:
    """
    Bit-compressed where function; usage is the same as :func:`torch.where`
    """
    if input1 is None and input2 is None:
        return torch.nonzero(condition, as_tuple=True)
    if not isinstance(input1, Tensor):
        input1 = torch.tensor(input1, device=condition.device)
    if not isinstance(input2, Tensor):
        input2 = torch.tensor(input2, device=condition.device)
    if not torch.is_grad_enabled() or (not input1.requires_grad and not input2.requires_grad):
        return input1.where(condition, input2)
    x, y, z = torch.broadcast_tensors(condition, input1, input2)
    if x.shape.numel() > condition.shape.numel() * 32:
        return input1.where(condition, input2)
    return WhereFunc.apply(x, y, z)

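# Usage sketch (illustrative only): where stores the boolean condition as a
# compressed 1-bit mask for the backward pass.
#
#     cond = torch.rand(512, 512, device="cuda") > 0.5
#     a = torch.randn(512, 512, device="cuda", requires_grad=True)
#     b = torch.randn(512, 512, device="cuda", requires_grad=True)
#     out = where(cond, a, b)
#     out.sum().backward()
#
#     where(cond)   # single-argument form behaves like torch.nonzero(cond, as_tuple=True)
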
class MaskedFillFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mask, value, inplace):
        if inplace:
            ctx.mark_dirty(input)
            output = input.masked_fill_(mask, value)
        else:
            output = input.masked_fill(mask, value)
        compressed_mask = hf_bitmask.compress_mask(mask)
        value_need_grad = False
        if isinstance(value, Tensor) and value.requires_grad:
            value_need_grad = True
        ctx.value_need_grad = value_need_grad
        ctx.save_for_backward(compressed_mask)
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        compressed_mask = ctx.saved_tensors[0]
        value_need_grad = ctx.value_need_grad
        dinput = hf_bitmask.masked_fill_backward(doutput, compressed_mask)
        dvalue = None
        if value_need_grad:
            dvalue = doutput.sum() - dinput.sum()
        return dinput, None, dvalue, None

@_cuda_only
@_disable_set_replace_torch
def masked_fill(input: Tensor, mask: Tensor, value: Union[Number, Tensor]) -> Tensor:
    """
    Bit-compressed masked_fill function; usage is the same as :func:`torch.masked_fill`
    """
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.masked_fill(mask, value)
    x, y = torch.broadcast_tensors(input, mask)
    if y.shape.numel() > mask.shape.numel() * 32:
        return input.masked_fill(mask, value)
    return MaskedFillFunc.apply(input, y, value, False)


@_cuda_only
@_disable_set_replace_torch
def masked_fill_(input: Tensor, mask: Tensor, value: Union[Number, Tensor]) -> Tensor:
    """
    In-place bit-compressed masked_fill function
    """
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.masked_fill_(mask, value)
    x, y = torch.broadcast_tensors(input, mask)
    if y.shape.numel() > mask.shape.numel() * 32:
        return input.masked_fill_(mask, value)
    return MaskedFillFunc.apply(input, y, value, True)

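# Usage sketch (illustrative only): masked_fill with a compressed mask is handy for
# attention-style masking, where the same boolean mask is needed again in backward.
#
#     scores = torch.randn(8, 128, 128, device="cuda", requires_grad=True)
#     causal = torch.ones(8, 128, 128, dtype=torch.bool, device="cuda").triu(1)
#     scores = masked_fill(scores, causal, float("-inf"))
#     attn = softmax(scores, dim=-1)
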
class MaskedSelectFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mask):
        output = input.masked_select(mask)
        compressed_mask = hf_bitmask.compress_mask(mask)
        ctx.save_for_backward(compressed_mask)
        ctx.n = mask.numel()
        ctx.input_options = {'size': input.size(), 'dtype': input.dtype,
                             'layout': input.layout, 'device': input.device}
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        compressed_mask = ctx.saved_tensors[0]
        mask = hf_bitmask.decompress_mask(compressed_mask, ctx.n)
        input_options = ctx.input_options
        dinput = torch.zeros(size=input_options['size'], dtype=input_options['dtype'],
                             layout=input_options['layout'], device=input_options['device'])
        dinput.masked_scatter_(mask, doutput)
        return dinput, None

@_cuda_only
@_disable_set_replace_torch
def masked_select(input: Tensor, mask: Tensor) -> Tensor:
    """
    Memory-saving masked_select function; usage is the same as :func:`torch.masked_select`
    """
    if not torch.is_grad_enabled() or not input.requires_grad:
        return input.masked_select(mask)
    x, y = torch.broadcast_tensors(input, mask)
    if y.shape.numel() > mask.shape.numel() * 32:
        return input.masked_select(mask)
    return MaskedSelectFunc.apply(x, y)

class MaskedScatterFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, mask, source, inplace):
        if inplace:
            ctx.mark_dirty(input)
            output = input.masked_scatter_(mask, source)
        else:
            output = input.masked_scatter(mask, source)
        compressed_mask = hf_bitmask.compress_mask(mask)
        ctx.save_for_backward(compressed_mask)
        ctx.n = mask.numel()
        ctx.source_size = source.size()
        ctx.source_numel = source.numel()
        ctx.input_requires = input.requires_grad
        ctx.source_requires = source.requires_grad
        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, doutput):
        compressed_mask = ctx.saved_tensors[0]
        mask = hf_bitmask.decompress_mask(compressed_mask, ctx.n)
        if ctx.input_requires:
            dinput = doutput.masked_fill(mask, 0)
        else:
            dinput = None
        if ctx.source_requires:
            mask_selected = doutput.masked_select(mask)
            diff_nelem = ctx.source_numel - mask_selected.numel()
            if diff_nelem > 0:
                # because masked_select returns a 1-d tensor with one value per True mask element,
                # we need to fill out the rest with zeros, then reshape back to source's size.
                mask_selected = torch_F.pad(mask_selected, [0, diff_nelem], value=0)
            dsource = mask_selected.view(ctx.source_size)
        else:
            dsource = None
        return dinput, None, dsource, None

@_cuda_only
@_disable_set_replace_torch
def masked_scatter(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    """
    Memory-saving masked_scatter function; usage is the same as :func:`torch.masked_scatter`
    """
    if not torch.is_grad_enabled() or (not input.requires_grad and not source.requires_grad):
        return input.masked_scatter(mask, source)
    x, y = torch.broadcast_tensors(input, mask)
    if y.shape.numel() > mask.shape.numel() * 32:
        return input.masked_scatter(mask, source)
    return MaskedScatterFunc.apply(input, y, source, False)


@_cuda_only
@_disable_set_replace_torch
def masked_scatter_(input: Tensor, mask: Tensor, source: Tensor) -> Tensor:
    """
    In-place memory-saving masked_scatter function
    """
    if not torch.is_grad_enabled() or (not input.requires_grad and not source.requires_grad):
        return input.masked_scatter_(mask, source)
    x, y = torch.broadcast_tensors(input, mask)
    if y.shape.numel() > mask.shape.numel() * 32:
        return input.masked_scatter_(mask, source)
    return MaskedScatterFunc.apply(input, y, source, True)

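# Usage sketch (illustrative only): masked_select / masked_scatter keep only the
# compressed 1-bit mask for autograd instead of the full-size boolean tensor.
#
#     x = torch.randn(1024, 1024, device="cuda", requires_grad=True)
#     m = x.detach() > 0
#     picked = masked_select(x, m)           # 1-d tensor of the selected values
#     y = masked_scatter(x, m, picked * 2)   # write values back where m is True
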
def canonicalize_axis(dim: int, num_dims: int) -> int:
    """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
    if not -num_dims <= dim < num_dims:
        raise ValueError(f"Dim {dim} is out of bounds for array of dimension {num_dims}")
    if dim < 0:
        dim = dim + num_dims
    return dim


X = TypeVar('X')
Y = TypeVar('Y')
Hidden = TypeVar('Hidden')

def scan(f: Callable[[X, Hidden], Tuple[Y, Hidden]], xs: X, init: Hidden, dim: int = 0,
         reverse: bool = False) -> Tuple[Y, Hidden]:
    """
    Scan a function over the data (similar to an RNN), passing a hidden state ``hidden``
    between adjacent steps

    Args:
        f (Callable[Tuple[X, Hidden], Tuple[Y, Hidden]]): the function executed at each step, of the form
            ``f: (x, hidden_in) -> (y, hidden_out)``. ``X`` and ``Y`` are ``Tensor`` s or pytrees
            (nested tuple/list/dict) whose leaves are ``Tensor`` s. The ``hidden_out`` of step ``i``
            is the ``hidden_in`` of step ``i + 1``
        xs (X): the ``x`` inputs stacked along the scan dimension. The number of steps
            (invocations of ``f``) equals the size of that dimension
        init (Hidden): the ``hidden_in`` fed into the first step
        dim (int, optional): the scan dimension. Default: 0
        reverse (bool, optional): whether to scan ``xs`` in reverse order. Default: False

    Returns:
        A tuple ``(ys, h)``, where ``ys`` is the scan result and ``h`` is the ``hidden_out``
        produced by the last step

    Examples:

    .. code-block:: python

        f = lambda x, y: (x + y, x + y)
        a = torch.tensor([1, 2, 3, 4])
        sum = hfai.nn.functional.scan(f, a, torch.tensor(0))  # (tensor([1, 3, 6, 10]), tensor(10))
        sum = hfai.nn.functional.scan(f, a, torch.tensor(0), reverse=True)  # (tensor([10, 9, 7, 4]), tensor(10))

    .. code-block:: python

        f = lambda x, y: ((x[0] + y, x[1] + y), x[0] + x[1] + y)
        a = [torch.tensor([1, 2, 3, 4]), torch.tensor([4, 3, 2, 1])]
        sum = hfai.nn.functional.scan(f, a, torch.tensor(0))
        # ([tensor([1, 7, 13, 19]), tensor([4, 8, 12, 16])], tensor(20))
        sum = hfai.nn.functional.scan(f, a, torch.tensor(0), reverse=True)
        # ([tensor([16, 12, 8, 4]), tensor([19, 13, 7, 1])], tensor(20))

    """
    if not callable(f):
        raise TypeError("f must be a callable function")

    xs_flat, tree_x = pytree.tree_flatten(xs)
    if reverse:
        xs_flat = [torch.flip(x, [dim]) for x in xs_flat]

    # Check that all inputs have a consistent leading dimension `num_xs`.
    dim = canonicalize_axis(dim, xs_flat[0].ndim)
    num_xs = int(xs_flat[0].shape[dim])
    if not all(int(x.shape[dim]) == num_xs for x in xs_flat[1:]):
        raise ValueError('the scan dimension of xs must have the same size for all leaves, (got: {})'.format(
            [x.shape for x in xs_flat]))

    hidden = init
    is_ys_init = False
    for i in range(num_xs):
        x_flat = [torch.select(x_flat, dim=dim, index=i) for x_flat in xs_flat]
        y, hidden = f(pytree.tree_unflatten(x_flat, tree_x), hidden)
        y_flat, tree_y = pytree.tree_flatten(y)
        if not is_ys_init:
            is_ys_init = True
            ys_flat = [[y] for y in y_flat]
        else:
            for ys, y in zip(ys_flat, y_flat):
                ys.append(y)

    ys_flat = list(map(partial(torch.stack, dim=dim), ys_flat))
    if reverse:
        ys_flat = [torch.flip(ys, [dim]) for ys in ys_flat]

    return pytree.tree_unflatten(ys_flat, tree_y), hidden

def associative_scan(f: Callable[[X, X], X], xs: X, dim: int = 0, reverse: bool = False):
    """
    Scan the data with an associative binary operation (similar to a prefix sum), executed in parallel

    Args:
        f (Callable[Tuple[X, X], X]): the binary operation; it must be associative, i.e.
            ``f(f(a, b), c) = f(a, f(b, c))``. ``X`` is a ``Tensor`` or a pytree (nested tuple/list/dict)
            whose leaves are ``Tensor`` s. The inputs and output of ``f`` must have the same structure
            (if ``X`` is a ``Tensor``, the inputs and output must have the same ``shape``; if ``X`` is a
            pytree, the input and output pytrees must be isomorphic and the corresponding leaf
            ``Tensor`` s must have the same ``shape``)
        xs (X): the ``x`` inputs stacked along the scan dimension
        dim (int, optional): the scan dimension. Default: 0
        reverse (bool, optional): whether to scan ``xs`` in reverse order. Default: False

    Returns:
        ``ys``, the scan result, with the same structure as ``xs``

    Examples:

    .. code-block:: python

        f = lambda x, y: x + y
        a = torch.tensor([1, 2, 3, 4])
        sum = hfai.nn.functional.associative_scan(f, a)  # tensor([1, 3, 6, 10])
        sum = hfai.nn.functional.associative_scan(f, a, reverse=True)  # tensor([10, 9, 7, 4])

    .. code-block:: python

        f = lambda x, y: (y[0] * x[0], y[0] * x[1] + y[1])  # associative
        a = [torch.randn(5, 6, 7), torch.randn(5, 6, 7)]
        sum = hfai.nn.functional.associative_scan(f, a, dim=-1, reverse=True)

    """
    if not callable(f):
        raise TypeError("f must be a callable function")

    xs_flat, tree = pytree.tree_flatten(xs)
    if reverse:
        xs_flat = [torch.flip(x, [dim]) for x in xs_flat]

    # Check that all inputs have a consistent leading dimension `num_xs`.
    dim = canonicalize_axis(dim, xs_flat[0].ndim)
    num_xs = int(xs_flat[0].shape[dim])
    if not all(int(x.shape[dim]) == num_xs for x in xs_flat[1:]):
        raise ValueError('the scan dimension of xs must have the same size for all leaves, (got: {})'.format(
            [x.shape for x in xs_flat]))

    def _combine(a_flat, b_flat):
        a = pytree.tree_unflatten(a_flat, tree)
        b = pytree.tree_unflatten(b_flat, tree)
        c = f(a, b)
        c_flat, _ = pytree.tree_flatten(c)
        return c_flat

    def _slice_in_dim(start, limit, stride=1):
        slices = [slice(0, None)] * (dim + 1)
        slices[dim] = slice(start, limit, stride)
        return slices

    def _scan(xs):
        num_xs = xs[0].shape[dim]
        if num_xs <= 1:
            return xs

        # Combine adjacent pairs of elements.
        reduced_xs = _combine([x[_slice_in_dim(0, -1, 2)] for x in xs],
                              [x[_slice_in_dim(1, None, 2)] for x in xs])

        # Recursively compute scan for partially reduced tensors.
        odd_xs = _scan(reduced_xs)

        if num_xs % 2 == 0:
            even_xs = _combine([odd_x[_slice_in_dim(0, -1)] for odd_x in odd_xs],
                               [x[_slice_in_dim(2, None, 2)] for x in xs])
        else:
            even_xs = _combine(odd_xs, [x[_slice_in_dim(2, None, 2)] for x in xs])

        ys = [torch.empty_like(x) for x in xs]
        for (x, odd_x, even_x, y) in zip(xs, odd_xs, even_xs, ys):
            y[_slice_in_dim(0, 1)] = x[_slice_in_dim(0, 1)]
            y[_slice_in_dim(1, None, 2)] = odd_x
            y[_slice_in_dim(2, None, 2)] = even_x
        return ys

    ys_flat = _scan(xs_flat)
    if reverse:
        ys_flat = [torch.flip(y, [dim]) for y in ys_flat]

    return pytree.tree_unflatten(ys_flat, tree)

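# Worked illustration (a sketch of the recursion above, not extra functionality):
# for xs = [x0, x1, x2, x3] and an associative f, one level of _scan combines
# adjacent pairs into [f(x0, x1), f(x2, x3)], recursively scans them to get the
# odd-indexed prefixes [p1, p3], then fills in the even position as p2 = f(p1, x2),
# giving [x0, p1, p2, p3] -- the same values a left-to-right scan would produce,
# but in O(log n) combine levels.
#
#     add = lambda a, b: a + b
#     associative_scan(add, torch.tensor([1, 2, 3, 4]))  # tensor([1, 3, 6, 10])
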
# Copyright 2020 Google LLC
# Copyright 2021 Teddy Koker
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

no_softrank = False
try:
    import hfai.hfcuda.softrank as softrank
except:
    no_softrank = True


def soft_rank(values, regularization="l2", regularization_strength=1.0):
    if len(values.shape) != 2:
        raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
    if regularization not in ["l2", "kl"]:
        raise ValueError("'regularization' should be 'l2' or 'kl'")
    return SoftRank.apply(values, regularization, regularization_strength)


def soft_sort(values, regularization="l2", regularization_strength=1.0):
    if len(values.shape) != 2:
        raise ValueError(f"'values' should be a 2d-tensor but got {values.shape}")
    if regularization not in ["l2", "kl"]:
        raise ValueError("'regularization' should be 'l2' or 'kl'")
    return SoftSort.apply(values, regularization, regularization_strength)


def _arange_like(x, reverse=False):
    # returns arange with len of x of the same dtype and device (assumes 2d, first dim batch)
    if reverse:
        ar = torch.arange(x.shape[1] - 1, -1, -1, dtype=x.dtype, device=x.device)
    else:
        ar = torch.arange(x.shape[1], dtype=x.dtype, device=x.device)
    return ar.expand(x.shape[0], -1)


def _inv_permutation(permutation):
    # returns inverse permutation of 'permutation' (assumes 2d, first dim batch)
    inv_permutation = torch.zeros_like(permutation)
    inv_permutation.scatter_(1, permutation, _arange_like(permutation))
    return inv_permutation


# The following is from google-research/fast-soft-sort with the following modifications:
# - replace numpy functions with torch equivalent
# - remove unnecessary operations
# - reimplement backward pass in C++


class SoftRank(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
        ctx.scale = 1.0 / regularization_strength
        ctx.regularization = regularization
        w = _arange_like(tensor, reverse=True) + 1
        theta = tensor * ctx.scale
        s, permutation = torch.sort(theta, descending=True)
        inv_permutation = _inv_permutation(permutation)
        if ctx.regularization == "l2":
            dual_sol = softrank.isotonic_l2(s - w)
            ret = (s - dual_sol).gather(1, inv_permutation)
            factor = torch.tensor(1.0, device=s.device)
        else:
            dual_sol = softrank.isotonic_kl(s, torch.log(w))
            ret = torch.exp((s - dual_sol).gather(1, inv_permutation))
            factor = ret
        ctx.save_for_backward(factor, s, dual_sol, permutation, inv_permutation)
        return ret

    @staticmethod
    def backward(ctx, grad_output):
        factor, s, dual_sol, permutation, inv_permutation = ctx.saved_tensors
        grad = (grad_output * factor).clone()
        if ctx.regularization == "l2":
            grad -= softrank.isotonic_l2_backward(
                s, dual_sol, grad.gather(1, permutation)
            ).gather(1, inv_permutation)
        else:
            grad -= softrank.isotonic_kl_backward(
                s, dual_sol, grad.gather(1, permutation)
            ).gather(1, inv_permutation)
        return grad * ctx.scale, None, None


class SoftSort(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, regularization="l2", regularization_strength=1.0):
        ctx.sign = -1
        ctx.regularization = regularization
        w = (_arange_like(tensor, reverse=True) + 1) / regularization_strength
        tensor = ctx.sign * tensor  # for ascending
        s, permutation = torch.sort(tensor, descending=True)
        # note reverse order of args
        if ctx.regularization == "l2":
            sol = softrank.isotonic_l2(w - s)
        else:
            sol = softrank.isotonic_kl(w, s)
        ctx.save_for_backward(s, sol, permutation)
        return ctx.sign * (w - sol)

    @staticmethod
    def backward(ctx, grad_output):
        s, sol, permutation = ctx.saved_tensors
        inv_permutation = _inv_permutation(permutation)
        if ctx.regularization == "l2":
            grad = softrank.isotonic_l2_backward(s, sol, grad_output)
        else:
            grad = softrank.isotonic_kl_backward(s, sol, grad_output)
        return grad.gather(1, inv_permutation), None, None


try:
    import hfai.hfcuda.linear as hf_linear
except:
    pass


def linear(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    if has_torch_function((input,)):
        return handle_torch_function(linear, (input,), input, weight=weight, bias=bias)
    if bias is None or not torch.backends.cuda.matmul.allow_tf32:
        return torch_F.linear(input, weight, bias)
    return hf_linear.linear(input, weight, bias)


try:
    import hfai.hfcuda.swiglu as hf_swiglu
except:
    pass


class SwiGLUFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, dim):
        ctx.save_for_backward(input)
        ctx.dim = dim
        output = hf_swiglu.swiglu_forward(input, dim)
        return output

    @staticmethod
    def backward(ctx, doutput):
        input = ctx.saved_tensors[0]
        dim = ctx.dim
        dinput = hf_swiglu.swiglu_backward(doutput, input, dim)
        return dinput, None


def swiglu(input: Tensor, dim=-1) -> Tensor:
    return SwiGLUFunction.apply(input, dim)

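# Usage sketch (illustrative only; assumes the hfai.hfcuda.softrank extension is
# available): soft_rank / soft_sort are differentiable relaxations of ranking and
# sorting, ported from google-research/fast-soft-sort; a larger
# regularization_strength gives a smoother, less exact result.
#
#     values = torch.randn(4, 8, device="cuda", requires_grad=True)
#     r = soft_rank(values, regularization_strength=0.1)
#     s = soft_sort(values, regularization="l2")
#     (r.sum() + s.sum()).backward()
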
class set_replace_torch(object):
    """
    Replace every function in torch.nn.functional and torch that hfai has optimized with the
    corresponding hfai.nn.functional implementation, e.g.
    torch.nn.functional.relu -> hfai.nn.functional.relu,
    torch.max -> hfai.nn.functional.max,
    x.abs() -> hfai.nn.functional.abs(x)

    Args:
        mode (bool, optional): whether to enable the replacement. Default: True

    .. note::

        After calling ``hfai.nn.functional.set_replace_torch()``, every call to
        ``torch.nn.functional.xxx``, whether (1) made explicitly by the user or (2) made
        internally by PyTorch, executes ``hfai.nn.functional.xxx``. For example, the
        ``log_softmax`` inside ``torch.nn.CrossEntropyLoss`` automatically executes
        ``hfai.nn.functional.log_softmax``

    Examples:

    .. code-block:: python

        hfai.nn.functional.set_replace_torch()
        y = torch.nn.functional.softmax(x)  # softmax runs the hfai implementation
        hfai.nn.functional.set_replace_torch(False)

        with hfai.nn.functional.set_replace_torch():
            loss = torch.nn.functional.cross_entropy(input, target)  # the internal log_softmax runs the hfai implementation

    """
    # all functions that can transparently replace torch.nn.functional.xxx
    replaceable_nn_functions = \
        [dropout, relu, relu_, hardtanh, hardtanh_, relu6, softplus, softmin, softmax,
         log_softmax, _threshold, threshold, threshold_, rrelu, rrelu_, leaky_relu, leaky_relu_,
         hardsigmoid, hardshrink, softshrink]
    replaceable_nn_functions_torch = [getattr(torch_F, func.__name__) for func in replaceable_nn_functions]

    # all functions that can transparently replace torch.xxx
    replaceable_torch_functions = \
        [minimum, maximum, abs, abs_, min, max, clip, clip_, clamp, clamp_,
         clamp_max, clamp_max_, clamp_min, clamp_min_, where, masked_fill, masked_select, masked_scatter]
    replaceable_torch_functions_torch = [getattr(torch, func.__name__) for func in replaceable_torch_functions]

    # all functions that can transparently replace torch.Tensor.xxx
    replaceable_tensor_functions = \
        [relu, relu_, softmax, log_softmax, hardshrink, minimum, maximum, abs, abs_, min, max,
         clip, clip_, clamp, clamp_, clamp_max, clamp_max_, clamp_min, clamp_min_,
         masked_fill, masked_fill_, masked_select, masked_scatter, masked_scatter_]
    replaceable_tensor_functions_torch = [getattr(torch.Tensor, func.__name__) for func in replaceable_tensor_functions]

    is_replacement_enabled = False

    @staticmethod
    def _set(cls, mode: bool):
        cls.is_replacement_enabled = mode
        for func in (cls.replaceable_nn_functions if mode is True else cls.replaceable_nn_functions_torch):
            setattr(torch_F, func.__name__, func)
        for func in (cls.replaceable_torch_functions if mode is True else cls.replaceable_torch_functions_torch):
            setattr(torch, func.__name__, func)
        for func in (cls.replaceable_tensor_functions if mode is True else cls.replaceable_tensor_functions_torch):
            setattr(torch.Tensor, func.__name__, func)

    def __init__(self, mode: bool = True):
        self.prev = self.__class__.is_replacement_enabled
        self._set(self.__class__, mode)

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._set(self.__class__, self.prev)