# Source code for haiscale.timer
from collections import defaultdict
import functools
from tabulate import tabulate
import time
import torch
class CudaTimer:
    """
    Accumulate wall-clock time, call counts and transferred bytes for
    different parts of haiscale.

    Statistics are only collected between :meth:`start` and :meth:`stop`.
    Each timed call is bracketed by ``torch.cuda.synchronize()`` (when CUDA
    is available), so GPU work queued by the call is included in its time.

    Examples:

    .. code-block:: python

        from haiscale.pipeline import PipeDream
        from haiscale.timer import cuda_timer

        gpt = PipeDream(...)

        cuda_timer.start()
        for i in range(steps):
            loss, _ = gpt.forward_backward(x, criterion=criterion, labels=(x,))
        cuda_timer.stop()

        cuda_timer.print_statistics()

    The printed result looks like this (the ``send_bytes`` and
    ``recv_bytes`` columns are omitted here):

    .. code-block::

        name              ncalls    time (ms)
        --------------  --------  -----------
        backward_chunk      3200        18298
        forward_chunk       3200         7798
        recv                 100         8050
        recv_meta            100         8032
        send                 100           43
        send_meta            100           16
        sendrecv            3100        15425
        sendrecv_meta       3100        14474
    """

    def __init__(self) -> None:
        # Collection stays off until start() is called.
        self.enabled = False
        self.reset()

    def reset(self) -> None:
        """Discard all accumulated statistics."""
        self.times = defaultdict(float)     # name -> total elapsed milliseconds
        self.ncalls = defaultdict(int)      # name -> number of timed calls
        self.send_bytes = defaultdict(int)  # name -> bytes passed to record_send
        self.recv_bytes = defaultdict(int)  # name -> bytes passed to record_recv

    @staticmethod
    def _synchronize() -> None:
        """Wait for queued CUDA work; a no-op when CUDA is unavailable."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @staticmethod
    def _nbytes(tensors) -> int:
        """Return the summed element storage size, in bytes, of *tensors*."""
        return sum(x.numel() * x.element_size() for x in tensors)

    def record(self, name):
        """
        Return a decorator that times every call of the wrapped function
        under ``name`` while the timer is enabled.

        Args:
            name: key under which elapsed time and call count are accumulated.

        Returns:
            A decorator. The wrapped function keeps the original's metadata
            (``functools.wraps``) and return value. When the timer is
            disabled it calls straight through without synchronizing.
        """
        def time_decorator(f):
            @functools.wraps(f)
            def new_f(*args, **kwargs):
                # The flag is checked per call, so start()/stop() take effect
                # on functions that were decorated earlier.
                if not self.enabled:
                    return f(*args, **kwargs)
                # Drain previously queued GPU work so it is not billed to f.
                self._synchronize()
                t0 = time.perf_counter()
                out = f(*args, **kwargs)
                # Wait for GPU work launched by f before stopping the clock.
                self._synchronize()
                elapsed_ms = (time.perf_counter() - t0) * 1000
                self.times[name] += elapsed_ms
                self.ncalls[name] += 1
                return out
            return new_f
        return time_decorator

    def record_send(self, name, tensors) -> None:
        """Add the byte size of ``tensors`` to ``name``'s sent total (only while enabled)."""
        if self.enabled:
            self.send_bytes[name] += self._nbytes(tensors)

    def record_recv(self, name, tensors) -> None:
        """Add the byte size of ``tensors`` to ``name``'s received total (only while enabled)."""
        if self.enabled:
            self.recv_bytes[name] += self._nbytes(tensors)

    def start(self) -> None:
        """Clear previous statistics and begin collecting."""
        self.reset()
        self.enabled = True

    def stop(self) -> None:
        """Stop collecting; accumulated statistics are kept."""
        self.enabled = False

    def print_statistics(self) -> None:
        """Print the table produced by :meth:`format_statistics`."""
        tab = self.format_statistics()
        print(tab, flush=True)

    def format_statistics(self):
        """Return the statistics as a text table, sorted by name."""
        data = sorted(self.get_statistics(), key=lambda row: row[0])
        return tabulate(data, headers=["name", "ncalls", "time (ms)", "send_bytes", "recv_bytes"])

    def get_statistics(self):
        """
        Return one ``(name, ncalls, time_ms, send_bytes, recv_bytes)`` tuple
        per name that was timed via :meth:`record`.

        ``time_ms`` is truncated to an int. Rows are in first-recorded order.
        Reading the statistics does not modify the underlying counters.
        """
        # .get() instead of [] so reading does not insert zero entries
        # into the defaultdicts.
        return [
            (
                name,
                self.ncalls.get(name, 0),
                int(ms),
                self.send_bytes.get(name, 0),
                self.recv_bytes.get(name, 0),
            )
            for name, ms in self.times.items()
        ]
# Module-level shared timer; import it with
# ``from haiscale.timer import cuda_timer``.
cuda_timer: CudaTimer = CudaTimer()