Source code for hfai.datasets.nuscenes

from typing import Callable, List, Optional

import numpy as np
import torch
import pickle
from ffrecord import FileReader
from ffrecord.torch import DataLoader
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset,
)

"""
Expected file organization:

    [data_dir]
        train.ffr
            PART_00000.ffr
            PART_00001.ffr
        val.ffr
            PART_00000.ffr
            PART_00001.ffr
"""


class NuScenesReader:
    """
    Reader interface for the NuScenes dataset.

    Args:
        split (str): dataset split, either ``train`` or ``val``
        check_data (bool): whether to verify the checksum of every sample (default: ``True``)
        miniset (bool): whether to use the mini subset (default: ``False``)
    """

    def __init__(self, split: str, check_data: bool = True, miniset: bool = False):
        assert split in ["train", "val"]
        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "NuScenes"
        self.fname = self.data_dir / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)
        self.default_keys = [
            'imgs', 'trans', 'rots', 'intrins',
            'lidar_data', 'lidar_mask', 'car_trans',
            'yaw_pitch_roll', 'semantic_gt',
            'instance_gt', 'direction_gt'
        ]
        self.detection_keys = [
            "sweeps", "gt_boxes", "gt_names",
            "gt_velocity", "num_lidar_pts",
            "num_radar_pts", "valid_flag"
        ]
        self.selected_keys = self.default_keys

    def __len__(self):
        return self.reader.n

    def read(self, indices: List[int]):
        """
        Read data dicts by index.

        Args:
            indices: indices of the samples to read

        Returns:
            a list of data dicts, one per index
        """
        # read raw bytes and deserialize each pickled data dict
        bytes_ = self.reader.read(indices)
        samples = []
        for x in bytes_:
            data_dict = pickle.loads(x)
            if self.selected_keys is not None:
                data_dict = {k: v for k, v in data_dict.items() if k in self.selected_keys}
            samples.append(data_dict)

        return samples
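

# A hedged usage sketch for the reader on its own (assumes the dataset files
# are present under the configured data directory):
#
#     reader = NuScenesReader(split="val", check_data=True)
#     print(len(reader), "samples")
#     samples = reader.read([0, 1])      # batched read by index
#     print(samples[0].keys())           # the default (segmentation) keys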


@register_dataset
class NuScenes(BaseDataset):
    """
    NuScenes dataset for autonomous driving tasks.

    See https://www.nuscenes.org/ for more information.

    Args:
        split (str): dataset split, either ``train`` or ``val``
        transform (Callable): transform function applied to each sample; it takes a data dict as input and returns the transformed data dict
        check_data (bool): whether to verify the checksum of every sample (default: ``True``)
        miniset (bool): whether to use the mini subset (default: ``False``)

    Returns:
        data_dict (dict): each returned sample is a data dict, as in the example below (see the `official data description <https://www.nuscenes.org/nuscenes>`_):

        .. code-block:: python

            {
                "imgs": list of PIL.Image.Image,  # images from the six cameras
                "trans": torch.Tensor,            # translation vectors of the six cameras in the ego frame
                "rots": torch.Tensor,             # rotation vectors of the six cameras in the ego frame
                "intrins": torch.Tensor,          # intrinsics of the six cameras
                "lidar_data": torch.Tensor,       # LiDAR data in the ego frame
                "lidar_mask": torch.Tensor,       # mask marking padded LiDAR data
                "car_trans": torch.Tensor,        # ego translation vector in the world frame
                "yaw_pitch_roll": torch.Tensor,   # ego rotation Euler angles in the world frame
                "semantic_gt": torch.Tensor,      # semantic segmentation labels
                "instance_gt": torch.Tensor,      # instance segmentation labels
                "direction_gt": torch.Tensor,     # direction segmentation labels, 360 degrees split into 36 bins
                "sweeps": list of dict,           # sensor data of the unannotated intermediate frames
                "gt_boxes": torch.Tensor,         # 3D bounding boxes with 7 degrees of freedom, an Nx7 array
                "gt_names": list of str,          # categories of the 3D boxes, a 1xN array
                "gt_velocity": torch.Tensor,      # velocities of the 3D boxes (no vertical component, as it is inaccurate), an Nx2 array
                "num_lidar_pts": torch.Tensor,    # number of LiDAR points inside each 3D box
                "num_radar_pts": torch.Tensor,    # number of radar points inside each 3D box
                "valid_flag": torch.Tensor,       # whether each box is valid; a box containing at least one LiDAR or radar point counts as valid
            }

    NOTE: for the detection-related keys, the data length varies from sample to sample. For example, 'gt_boxes' may have length 10 in sample 0 but length 20 in sample 1.

    Examples:

    .. code-block:: python

        from hfai.datasets import NuScenes
        from torchvision import transforms
        import torch

        def transform(data_dict):
            trans = transforms.Compose([transforms.ToTensor()])
            data_dict['imgs'] = torch.stack([trans(img) for img in data_dict['imgs']])
            return data_dict

        split = 'train'  # or 'val'
        dataset = NuScenes(split, transform)

        # segmentation task
        loader = dataset.loader(batch_size=64, num_workers=4)
        for i, data_dict in enumerate(loader):
            print('keys', data_dict.keys())
            '''
            output:
            keys dict_keys([
                'imgs', 'trans', 'rots', 'intrins',
                'lidar_data', 'lidar_mask', 'car_trans',
                'yaw_pitch_roll', 'semantic_gt',
                'instance_gt', 'direction_gt'
            ])
            '''

        # detection task
        dataset.set_task('detection')
        for i, data_dict in enumerate(loader):
            print('detection keys', data_dict.keys())
            '''
            output:
            detection keys dict_keys([
                'imgs', 'trans', 'rots', 'intrins',
                'lidar_data', 'lidar_mask', 'car_trans',
                'yaw_pitch_roll', 'semantic_gt',
                'instance_gt', 'direction_gt', 'sweeps',
                'gt_boxes', 'gt_names', 'gt_velocity',
                'num_lidar_pts', 'num_radar_pts', 'valid_flag'
            ])
            '''
    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super(NuScenes, self).__init__()
        self.split = split
        self.reader = NuScenesReader(split, check_data, miniset)
        self.transform = transform
        self.task = 'segmentation'

    def set_selected_keys(self, keys: List):
        assert len(keys) != 0, "keys must not be empty"
        self.reader.selected_keys = keys

    def set_task(self, task: str):
        assert task in ['segmentation', 'detection']
        self.task = task
        if task == 'segmentation':
            self.reader.selected_keys = self.reader.default_keys
        elif task == 'detection':
            # None means keep all keys
            self.reader.selected_keys = None

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, indices):
        samples = self.reader.read(indices)
        transformed_samples = []
        for data_dict in samples:
            if self.transform:
                data_dict = self.transform(data_dict)
            transformed_samples.append(data_dict)
        return transformed_samples

    def loader(self, *args, **kwargs) -> DataLoader:
        if 'collate_fn' not in kwargs:
            kwargs['collate_fn'] = self.collate_fn
        return DataLoader(self, *args, **kwargs)

    def collate_fn(self, data):
        # group values by key, then stack the fixed-size (segmentation) keys;
        # detection keys have variable length per sample and stay as lists
        ret_dict = {}
        for sample in data:
            for k, v in sample.items():
                if k not in ret_dict:
                    ret_dict[k] = list()
                ret_dict[k].append(v)
        for k, v in ret_dict.items():
            if k in self.reader.default_keys:
                ret_dict[k] = torch.stack(v, dim=0)
        return ret_dict
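

# A hedged sketch of narrowing the decoded keys before batching (key names are
# taken from `default_keys` above; only keys in `default_keys` are stacked by
# `collate_fn`, and these two are fixed-size tensors, so they batch cleanly
# without a transform):
#
#     ds = NuScenes("val")
#     ds.set_selected_keys(["car_trans", "yaw_pitch_roll"])
#     loader = ds.loader(batch_size=8, num_workers=2)
#     for batch in loader:
#         print(batch["car_trans"].shape)   # stacked along dim 0
#         break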