HFAI X PyTorch Lightning¶
PyTorch Lightning 是 PyTorch 社区中备受欢迎的框架,很多开源 AI 项目都是基于此开发。对此 hfai 针对 PyTorch Lightning 做了适配,使之可以很轻便的融入萤火集群的各种优化特性当中。具体包括了如下几点:
集群环境适配:将
hfai.pl.HFAIEnvironment()
加入插件, 支持1.5.0 <= pytorch_lightning.__version__ <= 1.7.6
绑定
numa
:strategy
使用ddp_bind_numa
或者ddp_spawn_bind_numa
, 支持1.5.0 <= pytorch_lightning.__version__ <= 1.7.6
使用
hfreduce
:strategy
使用hfreduce_bind_numa
或者hfreduce_spawn_bind_numa
, 支持1.5.0 <= pytorch_lightning.__version__ <= 1.7.6
将算子转换为
hfai
算子: 在trainer.fit
之前调用nn_to_hfai
, 支持1.5.0 <= pytorch_lightning.__version__ <= 1.7.6
自动打断重启:将
hfai.pl.ModelCheckpointHF(dirpath)
加入回调, 支持1.7.0 <= pytorch_lightning.__version__ <= 1.7.6
1.7.0
以下版本不支持从 step 恢复训练,只支持从 epoch 恢复训练
hfai.pl
使用指引¶
幻方 AI 提供 hfai.pl
封装接口,供大家使用。具体使用案例如下:
from hfai.pl import HFAIEnvironment
from hfai.pl import ModelCheckpointHF
import pytorch_lightning as pl
output_dir = 'hfai_out'
cb = ModelCheckpointHF(dirpath=output_dir) # 初始化可以接收集群打断信号的回调类
trainer = pl.Trainer(
max_epochs=3,
gpus=8,
strategy="ddp_bind_numa", # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
plugins=[HFAIEnvironment()], # 自动适配 HFAI 环境
callbacks=[cb] # 自动处理集群打断信号
)
model_module = nn_to_hfai(ToyNetModule()) # 将算子转换为 hfai 算子
ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt'
ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None
trainer.fit(
model_module,
ckpt_path=ckpt_path # 自动恢复训练
)
训练数据读取¶
为了在 PyTorch Lightning 中使用 ffrecord 的 Dataloader,我们需要在 Dataloader 设置 skippable=False
:
from ffrecord.torch import Dataset, DataLoader
class MyDataset(Dataset)
...
dataset = MyDataset(...)
dataloader = DataLoader(dataset, batch_size, num_workers=num_workers, skippable=False)