Source code for hfai.nn.modules.dropout
from torch import nn
import torch
from .. import functional as F
[docs]class Dropout(nn.Dropout):
"""
压位 Dropout 算子, 训练时的 mask 用 1bit 储存, 以节省训练时的内存
使用方式与 `PyTorch 的 Dropout <https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html>`_ 一致
Args:
p (float, optional): 元素为零的概率. 默认: 0.5
inplace (bool, optional): 如果是 ``True``, 进行原地操作. 默认: ``False``
Examples:
.. code-block:: python
m = hfai.nn.Dropout(p=0.2)
input = torch.randn(20, 16)
output = m(input)
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.dropout(input, self.p, self.training, self.inplace)