Shortcuts

HFAI X Detectron2

Detectron2 是 META AI 实验室 FAIR 研发的目标检测 & 语义分割的深度学习研究框架,该框架集成了 Mask-RCNN、RetinaNet、DETR 等主流目标检测/语义分割算法, 还支持了 COCO、Cityscapes 等流行的数据集。研究者可以通过 Detectron2 快速复现最新的研究工作,也支持自定义自己的模型,从而高效地验证、迭代自己的模型设计。

为了帮助用户更加便捷丝滑地使用我们的萤火二号训练平台,我们在 hfai 工具库里针对 Detectron2 做出了一些适配,提供了 FFRecord 的读取支持优雅断点训练 工具, 接下来我们看看如何使用。

依赖

  • detectron2==0.6

  • ffrecord>=1.4.0

介绍

1. FFRecord 支持

Detectron2 使用了 iopath 来管理不同形式的文件存储方式,每种文件路径都有对应的 handler,比如以 detectron2:// 开头的 文件会通过 Detectron2Handler 来读取。我们自定义了一个 FFRecordHandler 并注册到 detectron2 中,然后给数据集的路径添加 一个 ffrecord:// 的前缀,这样读取数据的时候就会交由 FFRecordHandler 来处理,而 FFRecordHandler 会从打包好的 FFRecord 文件 中根据给定的路径读取对应的文件:

filename -> ffrecord://filename -> FFRecordHandler -> PackFolder.read_one(filename)

我们在 hfai 工具库里封装了一个 register_ffrecord_handler 接口,通过该接口可以在 Detectron2 里直接读取 FFRecord 格式文件。 用户只需要提前打包好数据集,然后调用该接口,就可以无缝使用 Detectron2,不需要做额外的代码修改。接下来我们以 COCO 数据集为例,介绍一下具体的使用方法。

假设我们已经有 COCO 的原始数据集,目录结构如下:

coco/
├── annotations
├── train2017
└── val2017
  1. 我们先打包整个数据集到 datasets/coco/coco.ffr 中:

    import ffrecord
    ffrecord.pack_folder("coco/", "datasets/coco/coco.ffr")
    
  2. 然后把 annotations 文件夹单独拷贝出来放到 datasets/coco/annotations,现在目录结构如下:

    datasets
    └── coco
        ├── annotations
        └── coco.ffr
    
  3. 在训练代码中调用 register_ffrecord_handler 接口:

    from hfai.utils.detr2 import register_ffrecord_handler
    register_ffrecord_handler(
        ffr_file="datasets/coco/coco.ffr",
        ffr_prefix="coco",
    )
    
    # NOTE: 必须在 import detectron2 之前调用 register_ffrecord_handler
    import detectron2
    

我们在萤火平台上提供已经打包好的 coco 数据集,放在 /public_dataset/1/ffdataset/mm_coco 路径下,用户可以链接该数据集到训练目录下:

mkdir datasets
ln -s /public_dataset/1/ffdataset/mm_coco datasets/coco

2. 优雅断点训练

萤火平台采用分时调度的方式管理任务,用户需要在收到打断信号后保存训练状态并挂起任务。为了帮助 Detectron2 用户更加方便地适配萤火平台分时调度的功能, 我们在 hfai 库里提供了一个 SuspendCheckpointer 钩子方法,该钩子会在每个训练 step 完成后检查是否收到了打断信号,如果收到了信号则保存训练状态并挂起任务。

使用上用户只需要添加三行代码即可:

from hfai.utils.detr2 import SuspendCheckpointer
trainer.register_hooks([SuspendCheckpointer(trainer.checkpointer)])
trainer.resume_or_load(resume=1)

3. 多机训练

多机训练需要通过环境变量设置机器数量、机器编号等:

import os
from detectron2.engine import launch

num_gpus = 8
machine_rank = int(os.getenv("RANK", "0"))
num_machines = int(os.getenv("WORLD_SIZE", "1"))
addr = os.getenv("MASTER_ADDR", "127.0.0.1")
port = os.getenv("MASTER_PORT", 2222)
dist_url = f"tcp://{addr}:{port}"

if __name__ == "__main__":
    launch(
        main,
        num_gpus,
        num_machines=num_machines,
        machine_rank=machine_rank,
        dist_url=dist_url,
        args=(args,),
    )