Shortcuts

Source code for haiscale.timer

from collections import defaultdict
import functools
from tabulate import tabulate
import time
import torch


class CudaTimer:
    """Accumulate the time spent in different parts of haiscale.

    Functions wrapped with :meth:`record` are timed while the timer is
    enabled (between :meth:`start` and :meth:`stop`). Pending CUDA work is
    synchronized before and after each timed call, so asynchronously launched
    kernels are included in the measured wall-clock time. Byte counts for
    communication can be attached to the same names via :meth:`record_send`
    and :meth:`record_recv`.

    Examples:

        .. code-block:: python

            from haiscale.pipeline import PipeDream
            from haiscale.timer import cuda_timer

            gpt = PipeDream(...)

            cuda_timer.start()
            for i in range(steps):
                loss, _ = gpt.forward_backward(x, criterion=criterion, labels=(x,))
            cuda_timer.stop()

            cuda_timer.print_statistics()

        The printed result looks like this (the ``send_bytes`` and
        ``recv_bytes`` columns are omitted here for brevity):

        .. code-block::

            name              ncalls    time (ms)
            --------------  --------  -----------
            backward_chunk      3200        18298
            forward_chunk       3200         7798
            recv                 100         8050
            recv_meta            100         8032
            send                 100           43
            send_meta            100           16
            sendrecv            3100        15425
            sendrecv_meta       3100        14474
    """

    def __init__(self) -> None:
        # Timing is off until start() is called; while disabled, decorated
        # functions run directly without any synchronization overhead.
        self.enabled = False
        self.reset()

    def reset(self) -> None:
        """Clear all accumulated statistics (``enabled`` is left unchanged)."""
        self.times = defaultdict(float)      # name -> total elapsed milliseconds
        self.ncalls = defaultdict(int)       # name -> number of timed calls
        self.send_bytes = defaultdict(int)   # name -> total bytes passed to record_send
        self.recv_bytes = defaultdict(int)   # name -> total bytes passed to record_recv

    @staticmethod
    def _synchronize() -> None:
        """Wait for pending CUDA work; a no-op when CUDA is not available."""
        # Without this guard torch.cuda.synchronize() raises on CPU-only
        # builds/machines; there is no device work to wait for in that case.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _nbytes(tensors) -> int:
        """Sum of ``numel() * element_size()`` over an iterable of tensors."""
        return sum(x.numel() * x.element_size() for x in tensors)

    def record(self, name):
        """Return a decorator that times the wrapped function under ``name``.

        While the timer is disabled the wrapped function is called directly.
        While enabled, each successful call adds its elapsed time in
        milliseconds to ``times[name]`` and increments ``ncalls[name]``.
        If the wrapped function raises, nothing is recorded and the exception
        propagates unchanged.

        Args:
            name: key under which the statistics are accumulated.
        """
        def time_decorator(f):
            @functools.wraps(f)
            def new_f(*args, **kwargs):
                if not self.enabled:
                    return f(*args, **kwargs)
                # Drain previously queued GPU work so it is not billed to f.
                self._synchronize()
                t0 = time.perf_counter()
                out = f(*args, **kwargs)
                # Wait for the work f launched so its cost is included.
                self._synchronize()
                elapsed_ms = (time.perf_counter() - t0) * 1000
                self.times[name] += elapsed_ms
                self.ncalls[name] += 1
                return out
            return new_f
        return time_decorator

    def record_send(self, name, tensors):
        """Add the byte size of ``tensors`` to ``send_bytes[name]`` if enabled."""
        if self.enabled:
            self.send_bytes[name] += self._nbytes(tensors)

    def record_recv(self, name, tensors):
        """Add the byte size of ``tensors`` to ``recv_bytes[name]`` if enabled."""
        if self.enabled:
            self.recv_bytes[name] += self._nbytes(tensors)

    def start(self):
        """Clear previous statistics and enable timing."""
        self.reset()
        self.enabled = True

    def stop(self):
        """Disable timing; accumulated statistics are kept."""
        self.enabled = False

    def print_statistics(self):
        """Print the statistics table produced by :meth:`format_statistics`."""
        print(self.format_statistics(), flush=True)

    def _rows(self):
        """Build one (name, ncalls, time_ms, send_bytes, recv_bytes) tuple per timed name.

        Only names that appear in ``times`` (i.e. were timed via
        :meth:`record`) produce a row; byte counts recorded under other names
        are not reported. ``time_ms`` is truncated to an int.
        """
        # .get() instead of [] so that reading statistics does not insert
        # empty entries into the defaultdicts.
        return [
            (
                k,
                self.ncalls.get(k, 0),
                int(v),
                self.send_bytes.get(k, 0),
                self.recv_bytes.get(k, 0),
            )
            for k, v in self.times.items()
        ]

    def format_statistics(self):
        """Return the statistics as a text table, sorted by name."""
        data = sorted(self._rows(), key=lambda row: row[0])
        return tabulate(
            data,
            headers=["name", "ncalls", "time (ms)", "send_bytes", "recv_bytes"],
        )

    def get_statistics(self):
        """Return the statistics rows in first-recorded order.

        Returns:
            list of (name, ncalls, time_ms, send_bytes, recv_bytes) tuples.
        """
        return self._rows()
# Module-level shared CudaTimer instance: all code that imports ``cuda_timer``
# from this module records into and reports from the same statistics.
# It starts disabled (enabled=False); call start() / stop() around the region
# to be measured, then print_statistics() or get_statistics().
cuda_timer = CudaTimer()