# HFAI X PyTorch Lightning PyTorch Lightning 是 PyTorch 社区中备受欢迎的框架,很多开源 AI 项目都是基于此开发。对此 hfai 针对 PyTorch Lightning 做了适配,使之可以很轻便的融入萤火集群的各种优化特性当中。具体包括了如下几点: - 集群环境适配:将 `hfai.pl.HFAIEnvironment()` 加入插件, 支持 `1.5.0 <= pytorch_lightning.__version__ <= 1.7.6` - 绑定 `numa`:`strategy` 使用 `ddp_bind_numa` 或者 `ddp_spawn_bind_numa`, 支持 `1.5.0 <= pytorch_lightning.__version__ <= 1.7.6` - 使用 `hfreduce`:`strategy` 使用 `hfreduce_bind_numa` 或者 `hfreduce_spawn_bind_numa`, 支持 `1.5.0 <= pytorch_lightning.__version__ <= 1.7.6` - 将算子转换为 `hfai` 算子: 在 `trainer.fit` 之前调用 `nn_to_hfai`, 支持 `1.5.0 <= pytorch_lightning.__version__ <= 1.7.6` - 自动打断重启:将 `hfai.pl.ModelCheckpointHF(dirpath)` 加入回调, 支持 `1.7.0 <= pytorch_lightning.__version__ <= 1.7.6` - `1.7.0` 以下版本不支持从 step 恢复训练,只支持从 epoch 恢复训练 ## `hfai.pl` 使用指引 幻方 AI 提供 [`hfai.pl` 封装接口](../api/pl.rst),供大家使用。具体使用案例如下: ```python from hfai.pl import HFAIEnvironment from hfai.pl import ModelCheckpointHF import pytorch_lightning as pl output_dir = 'hfai_out' cb = ModelCheckpointHF(dirpath=output_dir) # 初始化可以接收集群打断信号的回调类 trainer = pl.Trainer( max_epochs=3, gpus=8, strategy="ddp_bind_numa", # hfai 支持 ddp_bind_numa, ddp_spawn_bind_numa, hfreduce_bind_numa, hfreduce_spawn_bind_numa plugins=[HFAIEnvironment()], # 自动适配 HFAI 环境 callbacks=[cb] # 自动处理集群打断信号 ) model_module = nn_to_hfai(ToyNetModule()) # 将算子转换为 hfai 算子 ckpt_path = f'{output_dir}/{cb.CHECKPOINT_NAME_SUSPEND}.ckpt' ckpt_path = ckpt_path if os.path.exists(ckpt_path) else None trainer.fit( model_module, ckpt_path=ckpt_path # 自动恢复训练 ) ``` ## 训练数据读取 为了在 PyTorch Lightning 中使用 ffrecord 的 Dataloader,我们需要在 Dataloader 设置 `skippable=False`: ```python from ffrecord.torch import Dataset, DataLoader class MyDataset(Dataset) ... dataset = MyDataset(...) dataloader = DataLoader(dataset, batch_size, num_workers=num_workers, skippable=False) ```