HFAI X Detectron2¶
Detectron2 是 META AI 实验室 FAIR 研发的目标检测 & 语义分割的深度学习研究框架,该框架集成了 Mask-RCNN、RetinaNet、DETR 等主流目标检测/语义分割算法, 还支持了 COCO、Cityscapes 等流行的数据集。研究者可以通过 Detectron2 快速复现最新的研究工作,也支持自定义自己的模型,从而高效地验证、迭代自己的模型设计。
为了帮助用户更加便捷丝滑地使用我们的萤火二号训练平台,我们在 hfai 工具库里针对 Detectron2 做出了一些适配,提供了 FFRecord 的读取支持 和 优雅断点训练 工具, 接下来我们看看如何使用。
依赖¶
detectron2==0.6
ffrecord>=1.4.0
介绍¶
1. FFRecord 支持¶
Detectron2 使用了 iopath 来管理不同形式的文件存储方式,每种文件路径都有对应的 handler,比如以 detectron2://
开头的
文件会通过 Detectron2Handler 来读取。我们自定义了一个 FFRecordHandler 并注册到 detectron2 中,然后给数据集的路径添加
一个 ffrecord://
的前缀,这样读取数据的时候就会交由 FFRecordHandler 来处理,而 FFRecordHandler 会从打包好的 FFRecord 文件
中根据给定的路径读取对应的文件:
filename -> ffrecord://filename -> FFRecordHandler -> PackFolder.read_one(filename)
我们在 hfai 工具库里封装了一个 register_ffrecord_handler
接口,通过该接口可以在 Detectron2 里直接读取 FFRecord 格式文件。
用户只需要提前打包好数据集,然后调用该接口,就可以无缝使用 Detectron2,不需要做额外的代码修改。接下来我们以 COCO 数据集为例,介绍一下具体的使用方法。
假设我们已经有 COCO 的原始数据集,目录结构如下:
coco/
├── annotations
├── train2017
└── val2017
我们先打包整个数据集到
datasets/coco/coco.ffr
中:import ffrecord ffrecord.pack_folder("coco/", "datasets/coco/coco.ffr")
然后把 annotations 文件夹单独拷贝出来放到
datasets/coco/annotations
,现在目录结构如下:datasets └── coco ├── annotations └── coco.ffr
在训练代码中调用
register_ffrecord_handler
接口:from hfai.utils.detr2 import register_ffrecord_handler register_ffrecord_handler( ffr_file="datasets/coco/coco.ffr", ffr_prefix="coco", ) # NOTE: 必须在 import detectron2 之前调用 register_ffrecord_handler import detectron2
我们在萤火平台上提供已经打包好的 coco 数据集,放在 /public_dataset/1/ffdataset/mm_coco
路径下,用户可以链接该数据集到训练目录下:
mkdir datasets
ln -s /public_dataset/1/ffdataset/mm_coco datasets/coco
2. 优雅断点训练¶
萤火平台采用分时调度的方式管理任务,用户需要在收到打断信号后保存训练状态并挂起任务。为了帮助 Detectron2 用户更加方便地适配萤火平台分时调度的功能,
我们在 hfai 库里提供了一个 SuspendCheckpointer
钩子方法,该钩子会在每个训练 step 完成后检查是否收到了打断信号,如果收到了信号则保存训练状态并挂起任务。
使用上用户只需要添加三行代码即可:
from hfai.utils.detr2 import SuspendCheckpointer
trainer.register_hooks([SuspendCheckpointer(trainer.checkpointer)])
trainer.resume_or_load(resume=1)
3. 多机训练¶
多机训练需要通过环境变量设置机器数量、机器编号等:
import os
from detectron2.engine import launch
num_gpus = 8
machine_rank = int(os.getenv("RANK", "0"))
num_machines = int(os.getenv("WORLD_SIZE", "1"))
addr = os.getenv("MASTER_ADDR", "127.0.0.1")
port = os.getenv("MASTER_PORT", 2222)
dist_url = f"tcp://{addr}:{port}"
if __name__ == "__main__":
launch(
main,
num_gpus,
num_machines=num_machines,
machine_rank=machine_rank,
dist_url=dist_url,
args=(args,),
)