# Source code for haiscale.cpu_offload.cpu_offload

import torch
import logging

class TagInfo:
    """
    Bookkeeping record for one offload tag.

    Args:
        tag (str): Tag name
    """

    def __init__(self, tag: str):
        self.tag = tag
        # Whether this tag is still in its first iteration
        self.is_first_iter = True
        # Total memory of the intermediate tensors (activations) seen under this tag
        self.total_activations_memory = 0.0
        # Amount of memory already offloaded under this tag in the current iteration
        self.current_offload_memory = 0.0
        # CUDA stream for this tag; CPUOffload reads it as `_tag_infos[tag].stream`
        self.stream = torch.cuda.Stream()

    def reset_current_offload_memory(self):
        """
        Reset the per-iteration offload counter and mark the first iteration as done.

        On the first call the accumulated ``total_activations_memory`` is logged.
        """
        self.current_offload_memory = 0.0
        if self.is_first_iter:
            # NOTE(review): the call wrapping this message was lost in the
            # extracted source; `logging.info` is assumed -- confirm the level.
            logging.info(
                f"[{self.tag}] total_activations_memory: {self.total_activations_memory}"
            )
        self.is_first_iter = False

class CPUOffload():
    """
    Context manager that offloads the intermediate tensors of the model region
    marked with ``tag`` to CPU (D2H) and loads them back (H2D), in proportion
    to ``offload_ratio``, in order to reduce GPU memory usage.

    Args:
        offload_ratio (float): Ratio of intermediate-tensor memory to offload
        tag (str): Tag name

    Examples:

    .. code-block:: python

        from haiscale.cpu_offload import CPUOffload
        import torch.nn as nn
        import torchvision

        class Net(nn.Module):
            def __init__(self, nhiddens, nlayers):
                super(Net, self).__init__()
                self.model = torchvision.models.resnet50()

            def forward(self, x):
                with CPUOffload(offload_ratio=0.1, tag="resnet"):
                    x = self.model(x)
                return x

    """

    # Registry of all tag infos, shared by every CPUOffload instance
    _tag_infos = {}

    def __init__(self, offload_ratio: float, tag: str):
        super().__init__()
        self.offload_ratio = offload_ratio
        self.tag = tag

        # Old torch versions have no torch.autograd.graph: do not offload
        if getattr(torch.autograd, 'graph', None) is None:
            self.hook = None
            logging.warning(
                f'Current torch version is {torch.__version__}, not support torch.autograd.graph.saved_tensors_hooks. '
                f'Please upgrade your torch version.')
        else:
            if tag not in self._tag_infos:
                # NOTE(review): the call wrapping this message was lost in the
                # extracted source; `logging.info` is assumed -- confirm the level.
                logging.info(f"Regist {tag}, offload ratio is {offload_ratio}")
                self._tag_infos[tag] = TagInfo(tag)
            self.hook = torch.autograd.graph.saved_tensors_hooks(
                self.offload_hook, self.load_hook
            )
            # Stream used for H2D and D2H copies
            self.stream = self._tag_infos[tag].stream

    def __enter__(self):
        # Register the offload hooks for this tag
        if self.hook:
            self.hook.__enter__()

    def __exit__(self, exc_type, exc_value, traceback):
        # Reset the tag's per-iteration counter, then remove the offload hooks
        if self.hook:
            self._tag_infos[self.tag].reset_current_offload_memory()
            self.hook.__exit__(exc_type, exc_value, traceback)

    @classmethod
    def get_tensor_memory(cls, x: torch.Tensor):
        """
        Compute the memory footprint of a tensor.

        Args:
            x (torch.Tensor): Tensor whose memory is measured

        Returns:
            tensor_memory (int): ``x.element_size() * x.nelement()``
        """
        return x.element_size() * x.nelement()

    def offload_hook(self, x: torch.Tensor):
        """
        Pack hook: move an intermediate tensor to CPU according to
        ``offload_ratio``.

        Args:
            x (torch.Tensor): Intermediate tensor

        Returns:
            packed_tuple (tuple): ``(device, offload_x)`` with the original
            device and the offloaded data, or ``(None, x)`` when the tensor is
            not offloaded
        """
        tag_info = self._tag_infos[self.tag]
        tensor_memory = self.get_tensor_memory(x)

        # first forward
        if tag_info.is_first_iter:
            # calculate total activations gpu memory
            tag_info.total_activations_memory += tensor_memory
            packed_tuple = (x.device, x.to("cpu", non_blocking=True))
        # not first forward and should be offload
        # (the `> 0` guard avoids a ZeroDivisionError when the first iteration
        # recorded no tensors; such tensors are simply not offloaded)
        elif (
            tag_info.total_activations_memory > 0
            and (tag_info.current_offload_memory + tensor_memory)
            / tag_info.total_activations_memory
            <= self.offload_ratio
        ):
            tag_info.current_offload_memory += tensor_memory
            with torch.cuda.stream(self.stream):
                packed_tuple = (x.device, x.to("cpu", non_blocking=True))
        # not first forward and should not be offload
        else:
            packed_tuple = (None, x)
        return packed_tuple

    def load_hook(self, packed_tuple: tuple):
        """
        Unpack hook: move an offloaded tensor back to its original device.

        Args:
            packed_tuple (tuple): ``(device, offload_x)`` as produced by
                :meth:`offload_hook`

        Returns:
            x (torch.Tensor): Tensor on its original device
        """
        device, x = packed_tuple
        if device is not None and x.device != device:
            event = torch.cuda.Event()
            with torch.cuda.stream(self.stream):
                x = x.to(device, non_blocking=True)
                # NOTE(review): indentation was lost in the extracted source;
                # the record is assumed to be issued inside the stream context
                # so the wait below covers the copy -- confirm.
                event.record()
            event.wait()
        return x