from contextlib import contextmanager, ExitStack
from tabulate import tabulate
from packaging import version
import torch
from collections import defaultdict
import time
def profile_memory(model, input=(), input_kwargs={}, include_children=True, sort_by="name",
show_shapes=False, show_peakmem=False, show_forward_time=False, forward_funcs=["forward"]):
"""
Profile the GPU memory usage of a model.

The printed report contains the following fields:

1. ``parameter size``: total size of the parameters
#. ``activation size``: size of the tensors saved via ``save_for_backward`` during the forward pass (parameters excluded)
#. ``#calls``: number of times the module was called
#. ``input shape``: shapes of the input tensors
#. ``output shape``: shapes of the output tensors
#. ``peak mem``: peak GPU memory, i.e. the peak memory during the forward pass minus the memory already allocated before the forward pass
#. ``forward time``: per-module time, i.e. the time elapsed around each module's forward pass, accumulated over repeated calls

NOTE:
    Different operators may save some intermediate tensors more than once, so the total ``activation size`` can be larger than the actual memory usage.
NOTE:
    ``show_peakmem = True`` and ``include_children = False`` are mutually exclusive.
NOTE:
    ``show_forward_time = True`` and ``include_children = False`` are mutually exclusive.
NOTE:
    Only PyTorch >= 1.10 is supported.

Args:
    model (torch.nn.Module): the model to be profiled
    input (tuple): positional inputs of the model, invoked as ``model(*input, **input_kwargs)``
    input_kwargs (dict): keyword arguments of the model, invoked as ``model(*input, **input_kwargs)``
    include_children (bool): whether each module's memory usage includes that of its submodules (instances of ``nn.Module``); defaults to ``True``
    sort_by (str): field by which the output is sorted, one of ``name``, ``activation``, ``parameter``, ``peakmem`` or ``forward_time``; defaults to ``name``
    show_shapes (bool): whether to print input and output shapes; defaults to ``False``
    show_peakmem (bool): whether to print peak GPU memory (the model must be on a GPU); defaults to ``False``
    show_forward_time (bool): whether to print per-module forward time; defaults to ``False``
    forward_funcs (list): names of the forward methods to be wrapped; defaults to ``["forward"]``
Examples:
>>> import torch, hfai
>>> from torchvision import models
>>> model = models.alexnet().cuda()
>>> x = torch.randn(64, 3, 224, 224, device="cuda")
>>> hfai.utils.profile_memory(model, input=(x,))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
HFAI Memory Profiler (include_children = True, sort_by = name)
==================== ================= ================ ================= ========
module name type parameter size activation size #calls
==================== ================= ================ ================= ========
AlexNet AlexNet 233.081 MiB 345.312 MiB 1
AlexNet.features Sequential 9.421 MiB 336.000 MiB 1
AlexNet.features.0 Conv2d 0.089 MiB 36.750 MiB 1
......
AlexNet.classifier.5 ReLU 0.000 MiB 1.000 MiB 1
AlexNet.classifier.6 Linear 15.629 MiB 1.000 MiB 1
==================== ================= ================ ================= ========
total unique activations: 225.906 MiB
======================================================================================
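To also report peak GPU memory and per-module forward time, sorted by peak memory
(an illustrative call using the documented options; output omitted):
>>> hfai.utils.profile_memory(model, input=(x,), show_peakmem=True,
...                           show_forward_time=True, show_shapes=True, sort_by="peakmem")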
"""
if version.parse(torch.__version__) < version.parse("1.10.0"):
raise RuntimeError("hfai.utils.profile_memory only supports PyTorch >= 1.10")
assert isinstance(input, tuple)
assert isinstance(input_kwargs, dict)
assert sort_by in ["parameter", "activation", "peakmem", "name", "forward_time"]
assert include_children or (not show_peakmem), "show_peakmem = True is not supported when include_children = False"
assert include_children or (not show_forward_time), "show_forward_time = True is not supported when include_children = False"
device = next(model.parameters()).device
assert not show_peakmem or device != torch.device("cpu"), "show_peakmem = True is not supported for models on the CPU"
assert sort_by != "peakmem" or show_peakmem, "sort_by = 'peakmem' requires show_peakmem = True"
assert sort_by != "forward_time" or show_forward_time, "sort_by = 'forward_time' requires show_forward_time = True"
stats = []
model_name = model.__class__.__name__
time_details = defaultdict(float)
iters = 100
if show_forward_time:
# warmup
for i in range(10):
model(*input, **input_kwargs)
torch.cuda.synchronize()
time_profiler = TimeProfiler()
# backup forward, replace forward using wrap_func
for name, module in model.named_modules():
for func_name in forward_funcs:
if hasattr(module, func_name):
func = getattr(module, func_name)
setattr(module, "_hfai_orig_" + func_name, func)
setattr(module, func_name, time_profiler.wrap_func(name, module, func))
start_time = time.time()
for i in range(iters):
model(*input, **input_kwargs)
torch.cuda.synchronize()
end_time = time.time()
for name in time_profiler.event_dict:
full_name = (model_name + '.' + name) if name else model_name
for start, end in time_profiler.event_dict[name]:
time_details[full_name] += start.elapsed_time(end)
# recover forward
for name, module in model.named_modules():
for func_name in forward_funcs:
if hasattr(module, func_name):
func = getattr(module, "_hfai_orig_" + func_name)
setattr(module, func_name, func)
profiler = MemoryProfiler(model, include_children)
# backup forward, replace it using wrap_func
for name, module in model.named_modules():
for func_name in forward_funcs:
if hasattr(module, func_name):
func = getattr(module, func_name)
setattr(module, "_hfai_orig_" + func_name, func)
setattr(module, func_name, profiler.wrap_func(name, module, func))
with profiler.profile():
model(*input, **input_kwargs)
for name, (module, asize, ncalls, shapes, out_shapes, peak_mem) in profiler.module_stats.items():
full_name = (model_name + '.' + name) if name else model_name
psize = sum(p.numel() * p.element_size() for p in module.parameters(recurse=include_children))
typename = type(module).__name__
stats.append(
(full_name, typename, psize, asize, ncalls,
peak_mem, shapes, out_shapes,
time_details.get(full_name, 0) / iters)
)
# sort
if sort_by == "parameter":
stats.sort(key=lambda x: (x[2], x[3], x[0]), reverse=True)
elif sort_by == "activation":
stats.sort(key=lambda x: (x[3], x[2], x[0]), reverse=True)
elif sort_by == "peakmem":
stats.sort(key=lambda x: (x[5], x[0]), reverse=True)
elif sort_by == "forward_time":
stats.sort(key=lambda x: (x[8], x[0]), reverse=True)
table = []
for n, typename, psize, asize, ncalls, peak_mem, shapes, out_shapes, forward_time in stats:
row = [n, typename, format_size(psize),
format_size(asize), ncalls,
format_size(peak_mem), shapes,
out_shapes, f'{forward_time:.3f} ms']
table.append(row)
headers = ["module name", "type", "parameter size", "activation size", "#calls", "peak mem",
"input shape", "ouptut shape", "forward time"]
colalign = ["left", "left", "right", "right", "right", "right", "left", "left", "right"]
def pop(index):
for row in table:
row.pop(index)
headers.pop(index)
colalign.pop(index)
if not show_forward_time:
pop(8)
if not show_shapes:
pop(7), pop(6)
if not show_peakmem:
pop(5)
table = tabulate(table, headers=headers, colalign=colalign, tablefmt="rst")
total_unique_activation_size = format_size(profiler.unique_activation_size)
msg = f"HFAI Memory Profiler (include_children = {include_children}, sort_by = {sort_by})\n"
msg += str(table) + "\n"
line_width = len(str(table).split("\n")[0])
msg = line_width * "^" + "\n" + msg
n_parameters = sum(p.numel() * p.element_size() for p in model.parameters())
msg += f"total params size: {format_size(n_parameters)}\n"
msg += f"total unique activations: {total_unique_activation_size}\n"
if show_forward_time:
msg += f"total forward time per iter: {(end_time - start_time) / iters * 1000:.3f} ms\n"
msg += line_width * "=" + "\n"
print(msg, flush=True)
for name, module in model.named_modules():
for func_name in forward_funcs:
if hasattr(module, func_name):
func = getattr(module, "_hfai_orig_" + func_name)
setattr(module, func_name, func)
def format_size(size):
return f"{size / (1 << 20):.3f} MiB"
class CudaMemoryStats():
def __init__(self) -> None:
self.max_mem = 0
@contextmanager
def reset_peak_memory_stats(self):
self.max_mem = max(torch.cuda.max_memory_allocated(), self.max_mem)
prev_max_mem = self.max_mem
try:
torch.cuda.reset_peak_memory_stats()
self.max_mem = torch.cuda.max_memory_allocated()
yield
finally:
self.max_mem = max(prev_max_mem, self.max_mem, torch.cuda.max_memory_allocated())
def max_memory_allocated(self):
self.max_mem = max(torch.cuda.max_memory_allocated(), self.max_mem)
return self.max_mem
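# Wraps a module's forward function with a pair of CUDA events; the per-module
# forward time is read from the events after torch.cuda.synchronize().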
class TimeProfiler():
def __init__(self) -> None:
self.event_dict = defaultdict(list)
def wrap_func(self, name, module, func):
def wrapped_func(*args, **kwargs):
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
# record forward time using event
start_event.record()
outputs = func(*args, **kwargs)
end_event.record()
self.event_dict[name].append((start_event, end_event))
return outputs
return wrapped_func
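# Wraps a module's forward function with saved-tensor hooks to measure, per module,
# the size of the activations kept for backward, the peak memory and the call count.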
class MemoryProfiler():
def __init__(self, model, include_children=True) -> None:
self.module_stats = {}
self.hooks = SavedTensorsHooks(include_children)
self.include_children = include_children
self.seen_params = set()
self.seen_acts = set()
self.cuda_mem_stats = CudaMemoryStats()
for p in model.parameters():
storage = p.storage()
if storage.data_ptr() not in self.seen_params:
self.seen_params.add(storage.data_ptr())
self.seen_acts.add(storage.data_ptr())
self.unique_activation_size = 0
def wrap_func(self, name, module, func):
def pack_hook(tensor):
storage = tensor.storage()
if storage.data_ptr() not in self.seen_params:
nbytes = tensor.numel() * tensor.element_size()
self.module_stats[name][1] += nbytes
if storage.data_ptr() not in self.seen_acts:
self.seen_acts.add(storage.data_ptr())
self.unique_activation_size += storage.size() * storage.element_size()
return tensor
def unpack_hook(tensor):
return tensor
if name not in self.module_stats:
# [module, activation size, #calls, input shape, output shape, peak memory]
self.module_stats[name] = [module, 0, 0, '', '', 0]
def wrapped_func(*args, **kwargs):
mem = torch.cuda.memory_allocated()
with self.cuda_mem_stats.reset_peak_memory_stats():
with self.hooks.enable_hook(pack_hook, unpack_hook):
outputs = func(*args, **kwargs)
peak_mem = self.cuda_mem_stats.max_memory_allocated() - mem
self.module_stats[name][2] += 1
self.module_stats[name][3] = format_input_shape(args, kwargs)
self.module_stats[name][4] = format_output_shape(outputs)
self.module_stats[name][5] = peak_mem
return outputs
return wrapped_func
@contextmanager
def profile(self):
with ExitStack() as stack:
stack.enter_context(self.hooks.saved_tensors_hooks())
yield
return
def format_input_shape(args, kwargs):
shapes = []
for obj in args:
shapes.append(format_tensor(obj))
for k, v in kwargs.items():
s = format_tensor(v)
shapes.append(f"{k}={s}")
msg = ", ".join([str(s) for s in shapes])
return msg
def format_output_shape(outputs):
shapes = []
if isinstance(outputs, torch.Tensor):
shapes.append(format_tensor(outputs))
elif isinstance(outputs, (tuple, list)):
for out in outputs:
shapes.append(format_tensor(out))
else:
return str(format_tensor(outputs))
msg = ", ".join([str(s) for s in shapes])
return msg
def format_tensor(obj):
if isinstance(obj, torch.Tensor):
return tuple(obj.shape)
if isinstance(obj, (tuple, list)):
return type(obj)(format_tensor(x) for x in obj)
if isinstance(obj, dict):
return {k: format_tensor(v) for k, v in obj.items()}
if isinstance(obj, (int, float, str, bool, type(None))):
return obj
return "[UNKOWN]"
class SavedTensorsHooks():
def __init__(self, include_children=True) -> None:
self.hooks = []
self.current_hook = None
self.include_children = include_children
@contextmanager
def enable_hook(self, pack_hook, unpack_hook):
parent_hook = None
try:
if not self.include_children and len(self.hooks) > 0:
parent_hook = self.hooks.pop()
self.hooks.append((pack_hook, unpack_hook))
yield
finally:
self.hooks.pop()
if parent_hook:
self.hooks.append(parent_hook)
def pack_hook(self, tensor):
for hook in reversed(self.hooks):
tensor = hook[0](tensor)
return tensor
def unpack_hook(self, tensor):
for hook in self.hooks:
tensor = hook[1](tensor)
return tensor
@contextmanager
def saved_tensors_hooks(self):
with ExitStack() as stack:
context = torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook)
stack.enter_context(context)
yield
return