
# Source code for hfai.utils.profile

from contextlib import contextmanager, ExitStack
from tabulate import tabulate
from packaging import version
import torch
from collections import defaultdict
import time

@contextmanager
def _patched_funcs(model, func_names, make_wrapper):
    """Temporarily replace the methods ``func_names`` of every submodule with wrappers.

    ``make_wrapper(name, module, func)`` receives the qualified module name, the
    module and the currently bound function, and returns the callable installed in
    its place. On exit -- also when an exception escapes the ``with`` body -- every
    module is restored to exactly its prior state: an instance attribute that existed
    before is put back, and one that was only found on the class is deleted again so
    the class method is used (instead of pinning a bound method on the instance).
    """
    patched = []  # (module, func_name, had_instance_attr, original), in patch order
    try:
        for name, module in model.named_modules():
            for func_name in func_names:
                if not hasattr(module, func_name):
                    continue
                had_instance_attr = func_name in vars(module)
                original = getattr(module, func_name)
                setattr(module, func_name, make_wrapper(name, module, original))
                patched.append((module, func_name, had_instance_attr, original))
        yield
    finally:
        # Undo in reverse order so a name patched twice ends up at its true original.
        for module, func_name, had_instance_attr, original in reversed(patched):
            if had_instance_attr:
                setattr(module, func_name, original)
            else:
                delattr(module, func_name)


def profile_memory(model, input=(), input_kwargs=None, include_children=True, sort_by="name",
                   show_shapes=False, show_peakmem=False, show_forward_time=False,
                   forward_funcs=("forward",)):
    """Analyse the memory footprint of a model module by module and print a table.

    The printed table contains the following columns:

    1. ``parameter size``: total size of the module's parameters
    #. ``activation size``: total size of the tensors saved for backward (via
       ``save_for_backward``) during forward, parameters excluded
    #. ``#calls``: number of times the module was called
    #. ``input shape``: shapes of the input tensors
    #. ``output shape``: shapes of the output tensors
    #. ``peak mem``: peak memory; the peak allocated CUDA memory during the forward
       minus the memory that was already allocated before that forward
    #. ``forward time``: module time; the time between the start and the end of each
       module's forward, summed over its calls and averaged over the timed iterations

    NOTE: different operators may save some of the same intermediate tensors, so the
    summed ``activation size`` can be larger than the memory actually used; the
    ``total unique activations`` line counts every storage only once.

    NOTE: ``show_peakmem = True`` and ``include_children = False`` are mutually exclusive.

    NOTE: ``show_forward_time = True`` and ``include_children = False`` are mutually exclusive.

    NOTE: only PyTorch >= 1.10 is supported.

    The methods named in ``forward_funcs`` are temporarily replaced on every submodule
    by instrumenting wrappers and restored afterwards, even if profiling raises.

    Args:
        model (torch.nn.Module): the model to analyse
        input (tuple): positional inputs; the model is called as ``model(*input, **input_kwargs)``
        input_kwargs (dict): keyword inputs; the model is called as
            ``model(*input, **input_kwargs)``; ``None`` (default) means no keyword arguments
        include_children (bool): whether each module's numbers include those of its child
            modules (of type ``nn.Module``); default ``True``
        sort_by (str): ``name``, ``activation``, ``parameter``, ``peakmem`` or
            ``forward_time``; the column the output is sorted by (descending). ``name``
            keeps the ``model.named_modules()`` order; default ``name``
        show_shapes (bool): whether to print input and output shapes; default ``False``
        show_peakmem (bool): whether to print the peak memory; the model must be on the
            GPU; default ``False``
        show_forward_time (bool): whether to print per-module forward time, measured with
            CUDA events over 100 iterations after 10 warmup iterations; default ``False``
        forward_funcs (Sequence[str]): names of the methods instrumented on every
            submodule; default ``("forward",)``

    Examples:

        >>> import torch, hfai
        >>> from torchvision import models
        >>> model = models.alexnet().cuda()
        >>> x = torch.randn(64, 3, 224, 224, device="cuda")
        >>> hfai.utils.profile_memory(model, input=(x,))
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        HFAI Memory Profiler (include_children = True, sort_by = name)
        ==================== ================= ================ ================= ========
        module name          type                parameter size   activation size   #calls
        ==================== ================= ================ ================= ========
        AlexNet              AlexNet                233.081 MiB       345.312 MiB        1
        AlexNet.features     Sequential               9.421 MiB       336.000 MiB        1
        AlexNet.features.0   Conv2d                   0.089 MiB        36.750 MiB        1
        ......
        AlexNet.classifier.5 ReLU                     0.000 MiB         1.000 MiB        1
        AlexNet.classifier.6 Linear                  15.629 MiB         1.000 MiB        1
        ==================== ================= ================ ================= ========
        total params size: 233.081 MiB
        total unique activations: 225.906 MiB
        ==================================================================================
    """
    if version.parse(torch.__version__) < version.parse("1.10.0"):
        raise RuntimeError("hfai.utils.profile_memory 只支持 PyTorch >= 1.10")
    if input_kwargs is None:
        input_kwargs = {}
    assert isinstance(input, tuple)
    assert isinstance(input_kwargs, dict)
    assert sort_by in ["parameter", "activation", "peakmem", "name", "forward_time"]
    assert include_children or (not show_peakmem), "include_children = False 时不支持 show_peakmem = True"
    assert include_children or (not show_forward_time), "include_children = False 时不支持 show_forward_time = True"
    # A parameterless model has no device to inspect (next() would raise StopIteration);
    # only reject peak memory when the parameters are known to live on the CPU.
    first_param = next(model.parameters(), None)
    on_cpu = first_param is not None and first_param.device.type == "cpu"
    assert not show_peakmem or not on_cpu, "show_peakmem = True 不支持 CPU 上的模型"
    assert sort_by != "peakmem" or show_peakmem, "show_peakmem = False 时不支持 sort_by = peakmem"
    assert sort_by != "forward_time" or show_forward_time

    model_name = model.__class__.__name__
    iters = 100  # timed iterations the forward time is averaged over

    # Phase 1 (optional): per-module forward time from CUDA events.
    time_details = defaultdict(float)
    if show_forward_time:
        for _ in range(10):  # warmup
            model(*input, **input_kwargs)
        torch.cuda.synchronize()
        time_profiler = TimeProfiler()
        with _patched_funcs(model, forward_funcs, time_profiler.wrap_func):
            start_time = time.perf_counter()
            for _ in range(iters):
                model(*input, **input_kwargs)
            # Kernels run asynchronously: wait for them before reading the clock/events.
            torch.cuda.synchronize()
            end_time = time.perf_counter()
        for name, events in time_profiler.event_dict.items():
            full_name = f"{model_name}.{name}" if name else model_name
            time_details[full_name] += sum(start.elapsed_time(end) for start, end in events)

    # Phase 2: one forward pass with saved-tensor hooks to attribute activations.
    profiler = MemoryProfiler(model, include_children)
    with _patched_funcs(model, forward_funcs, profiler.wrap_func):
        with profiler.profile():
            model(*input, **input_kwargs)

    stats = []
    for name, (module, asize, ncalls, shapes, out_shapes, peak_mem) in profiler.module_stats.items():
        full_name = f"{model_name}.{name}" if name else model_name
        psize = sum(p.numel() * p.element_size() for p in module.parameters(recurse=include_children))
        forward_time = time_details.get(full_name, 0) / iters
        stats.append((full_name, type(module).__name__, psize, asize, ncalls, peak_mem,
                      shapes, out_shapes, forward_time))

    sort_keys = {
        "parameter": lambda s: (s[2], s[3], s[0]),
        "activation": lambda s: (s[3], s[2], s[0]),
        "peakmem": lambda s: (s[5], s[0]),
        "forward_time": lambda s: (s[8], s[0]),
    }
    if sort_by in sort_keys:  # "name" keeps the named_modules() order
        stats.sort(key=sort_keys[sort_by], reverse=True)

    # (header, alignment, shown?) for every cell a stats row produces, in row order.
    columns = [
        ("module name", "left", True),
        ("type", "left", True),
        ("parameter size", "right", True),
        ("activation size", "right", True),
        ("#calls", "right", True),
        ("peak mem", "right", show_peakmem),
        ("input shape", "left", show_shapes),
        ("output shape", "left", show_shapes),
        ("forward time", "right", show_forward_time),
    ]
    shown = [i for i, (_, _, visible) in enumerate(columns) if visible]
    rows = []
    for n, typename, psize, asize, ncalls, peak_mem, shapes, out_shapes, forward_time in stats:
        cells = [n, typename, format_size(psize), format_size(asize), ncalls,
                 format_size(peak_mem), shapes, out_shapes, f"{forward_time:.3f} ms"]
        rows.append([cells[i] for i in shown])
    table = str(tabulate(rows, headers=[columns[i][0] for i in shown],
                         colalign=[columns[i][1] for i in shown], tablefmt="rst"))

    line_width = len(table.split("\n")[0])
    n_parameters = sum(p.numel() * p.element_size() for p in model.parameters())
    lines = [
        line_width * "^",
        f"HFAI Memory Profiler (include_children = {include_children}, sort_by = {sort_by})",
        table,
        f"total params size: {format_size(n_parameters)}",
        f"total unique activations: {format_size(profiler.unique_activation_size)}",
    ]
    if show_forward_time:
        lines.append(f"total forward time per iter: {(end_time - start_time) / iters * 1000:.3f} ms")
    lines.append(line_width * "=")
    print("\n".join(lines) + "\n", flush=True)
def format_size(size):
    """Format ``size`` (in bytes) as MiB with three decimals, e.g. ``"1.000 MiB"``."""
    return f"{size / (1 << 20):.3f} MiB"


def _tensor_storage(tensor):
    """Return the storage backing ``tensor``.

    Prefers ``Tensor.untyped_storage()`` (PyTorch >= 2.0, where ``Tensor.storage()``
    is deprecated) and falls back to ``Tensor.storage()`` on older versions. Views
    share their base's storage, so ``data_ptr()`` of the result identifies the
    underlying memory block.
    """
    if hasattr(tensor, "untyped_storage"):
        return tensor.untyped_storage()
    return tensor.storage()


def _storage_nbytes(storage):
    """Return the size of ``storage`` in bytes, for typed and untyped storages alike."""
    if hasattr(storage, "nbytes"):
        return storage.nbytes()
    return storage.size() * storage.element_size()


class CudaMemoryStats():
    """Peak CUDA memory tracker supporting nested measurement windows.

    ``torch.cuda.reset_peak_memory_stats()`` is global, so opening a nested window
    (a child module) would erase the peak already reached by the enclosing window
    (its parent). ``max_mem`` carries that running peak across resets and merges the
    nested peak back in when the nested window closes.

    While CUDA is not initialized (e.g. a CPU model) nothing is reset and all
    readings are 0; calling ``torch.cuda.reset_peak_memory_stats()`` there can fail
    on machines without a GPU.
    """

    def __init__(self) -> None:
        self.max_mem = 0  # running peak in bytes, carried across nested resets

    def memory_allocated(self):
        """Return the currently allocated CUDA memory in bytes (0 if CUDA is not initialized)."""
        if not torch.cuda.is_initialized():
            return 0
        return torch.cuda.memory_allocated()

    @contextmanager
    def reset_peak_memory_stats(self):
        """Open a window whose peak starts at the current allocation; merge peaks on exit."""
        if not torch.cuda.is_initialized():
            yield
            return
        self.max_mem = max(torch.cuda.max_memory_allocated(), self.max_mem)
        prev_max_mem = self.max_mem  # enclosing window's peak so far
        try:
            torch.cuda.reset_peak_memory_stats()
            self.max_mem = torch.cuda.max_memory_allocated()
            yield
        finally:
            # Hand the enclosing window the max of its own peak and this window's peak.
            self.max_mem = max(prev_max_mem, self.max_mem, torch.cuda.max_memory_allocated())

    def max_memory_allocated(self):
        """Return the peak (bytes) since the innermost open window started, nested windows included."""
        if torch.cuda.is_initialized():
            self.max_mem = max(torch.cuda.max_memory_allocated(), self.max_mem)
        return self.max_mem


class TimeProfiler():
    """Records a pair of CUDA timing events around every call of each wrapped function."""

    def __init__(self) -> None:
        # name -> [(start_event, end_event), ...], one pair per call
        self.event_dict = defaultdict(list)

    def wrap_func(self, name, module, func):
        """Return a wrapper of ``func`` recording its events under ``name`` (``module`` is unused)."""
        def wrapped_func(*args, **kwargs):
            start_event = torch.cuda.Event(enable_timing=True)
            end_event = torch.cuda.Event(enable_timing=True)
            start_event.record()
            outputs = func(*args, **kwargs)
            end_event.record()
            self.event_dict[name].append((start_event, end_event))
            return outputs
        return wrapped_func


class MemoryProfiler():
    """Attributes the tensors autograd saves for backward to the modules that saved them.

    ``module_stats`` maps each wrapped module name to a list
    ``[module, activation size, #calls, input shape, output shape, peak memory]``.
    Tensors whose storage belongs to a model parameter are never counted as
    activations; ``unique_activation_size`` counts every other storage once, however
    many modules saved it.
    """

    def __init__(self, model, include_children=True) -> None:
        self.module_stats = {}
        self.hooks = SavedTensorsHooks(include_children)
        self.include_children = include_children
        self.seen_params = set()  # data_ptr of parameter storages
        self.seen_acts = set()    # data_ptr of storages already counted (params pre-seeded)
        self.cuda_mem_stats = CudaMemoryStats()
        for p in model.parameters():
            ptr = _tensor_storage(p).data_ptr()
            self.seen_params.add(ptr)
            self.seen_acts.add(ptr)
        self.unique_activation_size = 0

    def wrap_func(self, name, module, func):
        """Return a wrapper of ``func`` that records stats for ``module`` under ``name``."""
        def pack_hook(tensor):
            storage = _tensor_storage(tensor)
            ptr = storage.data_ptr()
            if ptr not in self.seen_params:
                self.module_stats[name][1] += tensor.numel() * tensor.element_size()
                if ptr not in self.seen_acts:
                    self.seen_acts.add(ptr)
                    self.unique_activation_size += _storage_nbytes(storage)
            return tensor

        def unpack_hook(tensor):
            return tensor

        if name not in self.module_stats:
            # [module, activation size, #calls, input shape, output shape, peak memory]
            self.module_stats[name] = [module, 0, 0, '', '', 0]

        def wrapped_func(*args, **kwargs):
            mem = self.cuda_mem_stats.memory_allocated()
            with self.cuda_mem_stats.reset_peak_memory_stats():
                with self.hooks.enable_hook(pack_hook, unpack_hook):
                    outputs = func(*args, **kwargs)
                # Read the peak while this call's window is still open: after it closes the
                # value is merged with the enclosing (parent / earlier) peak.
                peak_mem = self.cuda_mem_stats.max_memory_allocated() - mem
            stats = self.module_stats[name]
            stats[2] += 1
            stats[3] = format_input_shape(args, kwargs)
            stats[4] = format_output_shape(outputs)
            stats[5] = max(stats[5], peak_mem)  # peak over all calls
            return outputs
        return wrapped_func

    @contextmanager
    def profile(self):
        """Install the saved-tensor hooks for the duration of the ``with`` block."""
        with self.hooks.saved_tensors_hooks():
            yield


def format_input_shape(args, kwargs):
    """Describe positional and keyword inputs as a comma-separated string of shapes."""
    shapes = [format_tensor(obj) for obj in args]
    shapes += [f"{k}={format_tensor(v)}" for k, v in kwargs.items()]
    return ", ".join(str(s) for s in shapes)


def format_output_shape(outputs):
    """Describe a module's outputs (tensor, tuple/list, or other) as a string of shapes."""
    if isinstance(outputs, torch.Tensor):
        shapes = [format_tensor(outputs)]
    elif isinstance(outputs, (tuple, list)):
        shapes = [format_tensor(out) for out in outputs]
    else:
        return str(format_tensor(outputs))
    return ", ".join(str(s) for s in shapes)


def format_tensor(obj):
    """Replace tensors in ``obj`` (recursing into tuples, lists and dicts) by their shapes."""
    if isinstance(obj, torch.Tensor):
        return tuple(obj.shape)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple: its constructor takes the fields as positional arguments
        return type(obj)(*(format_tensor(x) for x in obj))
    if isinstance(obj, (tuple, list)):
        return type(obj)(format_tensor(x) for x in obj)
    if isinstance(obj, dict):
        return {k: format_tensor(v) for k, v in obj.items()}
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    return "[UNKNOWN]"


class SavedTensorsHooks():
    """Dispatches autograd's saved-tensor pack/unpack callbacks to a stack of per-module hooks.

    With ``include_children=True`` every enclosing module's hook also sees tensors saved
    by nested modules, so parents include their children's activations. With
    ``include_children=False`` the enclosing hook is taken off the stack while a nested
    module runs, so each tensor is attributed only to the innermost module.
    """

    def __init__(self, include_children=True) -> None:
        self.hooks = []  # stack of (pack_hook, unpack_hook), innermost module last
        self.current_hook = None  # not read anywhere in this module
        self.include_children = include_children

    @contextmanager
    def enable_hook(self, pack_hook, unpack_hook):
        """Push a module's hooks for the ``with`` block (replacing the parent's if excluded)."""
        parent_hook = None
        try:
            if not self.include_children and self.hooks:
                parent_hook = self.hooks.pop()
            self.hooks.append((pack_hook, unpack_hook))
            yield
        finally:
            self.hooks.pop()
            if parent_hook is not None:
                self.hooks.append(parent_hook)

    def pack_hook(self, tensor):
        """Run the active pack hooks, innermost module first."""
        for pack, _ in reversed(self.hooks):
            tensor = pack(tensor)
        return tensor

    def unpack_hook(self, tensor):
        """Run the active unpack hooks, outermost module first."""
        for _, unpack in self.hooks:
            tensor = unpack(tensor)
        return tensor

    @contextmanager
    def saved_tensors_hooks(self):
        """Register this dispatcher with autograd for the duration of the ``with`` block."""
        with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
            yield