• Docs >
  • PyTorch Lightning 适配 hfai
Shortcuts

PyTorch Lightning 适配 hfai

PyTorch Lightning 是 PyTorch 社区中备受欢迎的框架,很多开源 AI 项目都是基于此开发。对此 hfai 针对 PyTorch Lightning 做了适配,使之可以很轻便的融入萤火集群的各种优化特性当中。具体包括了如下几点:

  • 集群环境适配:将 hfai.pl.HFAIEnvironment() 加入插件, 支持 1.5.0 <= pytorch_lightning.__version__ <= 1.7.6

  • 绑定 numastrategy 使用 ddp_bind_numa 或者 ddp_spawn_bind_numa, 支持 1.5.0 <= pytorch_lightning.__version__ <= 1.7.6

  • 使用 hfreducestrategy 使用 hfreduce_bind_numa 或者 hfreduce_spawn_bind_numa, 支持 1.5.0 <= pytorch_lightning.__version__ <= 1.7.6

  • 将算子转换为 hfai 算子: 在 trainer.fit 之前调用 nn_to_hfai, 支持 1.5.0 <= pytorch_lightning.__version__ <= 1.7.6

  • 自动打断重启:将 hfai.pl.ModelCheckpointHF(dirpath) 加入回调, 支持 1.7.0 <= pytorch_lightning.__version__ <= 1.7.6

    • 1.7.0 以下版本不支持从 step 恢复训练,只支持从 epoch 恢复训练

hfai.pl 使用指引

幻方 AI 提供 hfai.pl 封装接口,供大家使用。具体使用案例如下:

from hfai.pl import HFAIEnvironment
from hfai.pl import ModelCheckpointHF
import pytorch_lightning as pl

output_dir = 'hfai_out'
cb = ModelCheckpointHF(dirpath=output_dir) # 初始化可以接收集群打断信号的回调类
trainer = pl.Trainer(
    max_epochs=3,
    gpus=8,
    strategy="ddp_bind_numa",  # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa
    plugins=[HFAIEnvironment()],  # 自动适配 HFAI 环境
    callbacks=[cb] # 自动处理集群打断信号
)
model_module = nn_to_hfai(ToyNetModule()) # 将算子转换为 hfai 算子

ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt'
ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None

trainer.fit(
    model_module,
    ckpt_path=ckpt_path # 自动恢复训练
)

训练数据读取

为了在 PyTorch Lightning 中使用 ffrecord 的 Dataloader,我们需要在 Dataloader 设置 skippable=False:

from ffrecord.torch import Dataset, DataLoader

class MyDataset(Dataset)
    ...

dataset = MyDataset(...)
dataloader = DataLoader(dataset, batch_size, num_workers=num_workers, skippable=False)