# Source code for hfai.datasets.nuscenes
from typing import Callable, List, Optional
import numpy as np
import torch
import pickle
from ffrecord import FileReader
from ffrecord.torch import DataLoader
from .base import (
BaseDataset,
get_data_dir,
register_dataset,
)
"""
Expected file organization:
[data_dir]
train.ffr
PART_00000.ffr
PART_00001.ffr
val.ffr
PART_00000.ffr
PART_00001.ffr
"""
class NuScenesReader:
    """Low-level reader for the NuScenes ffrecord dataset.

    Args:
        split (str): dataset split, one of ``train`` or ``val``
        check_data (bool): verify the checksum of every sample (default ``True``)
        miniset (bool): read from the mini subset instead of the full one (default ``False``)
    """

    def __init__(self, split: str, check_data: bool = True, miniset: bool = False):
        assert split in ("train", "val")
        root = get_data_dir()
        if miniset:
            root = root / "mini"
        self.data_dir = root / "NuScenes"
        self.fname = self.data_dir / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)
        # Keys present in every sample; this is the segmentation-task view.
        self.default_keys = [
            'imgs', 'trans', 'rots', 'intrins',
            'lidar_data', 'lidar_mask', 'car_trans',
            'yaw_pitch_roll', 'semantic_gt',
            'instance_gt', 'direction_gt'
        ]
        # Extra keys only consumed by the detection task.
        self.detection_keys = [
            "sweeps", "gt_boxes", "gt_names",
            "gt_velocity", "num_lidar_pts",
            "num_radar_pts", "valid_flag"
        ]
        # NOTE(review): attribute name keeps the historical misspelling
        # ("seletected") because other code in this module assigns to it directly.
        self.seletected_keys = self.default_keys

    def __len__(self):
        # The underlying FileReader exposes the sample count as ``n``.
        return self.reader.n

    def read(self, indices: List[int]):
        """Deserialize the samples at *indices*.

        Args:
            indices: sample indices to fetch

        Returns:
            list of dicts, one per index; each dict is filtered down to
            ``seletected_keys`` unless that attribute is ``None`` (= keep all).
        """
        raw = self.reader.read(indices)
        keep = self.seletected_keys
        samples = []
        for blob in raw:
            record = pickle.loads(blob)
            if keep is not None:
                record = {key: record[key] for key in record if key in keep}
            samples.append(record)
        return samples
@register_dataset
class NuScenes(BaseDataset):
    """NuScenes dataset for autonomous-driving tasks.

    See https://www.nuscenes.org/ for details about the underlying data.

    Args:
        split (str): dataset split, ``train`` or ``val``
        transform (Callable): optional function applied to each sample dict;
            it receives the dict and returns the transformed dict
        check_data (bool): verify the checksum of every sample (default ``True``)
        miniset (bool): use the mini subset (default ``False``)

    Returns:
        data_dict (dict): every sample is a dict (see the
        `official data description <https://www.nuscenes.org/nuscenes>`_):

        .. code-block:: python

            {
                "imgs": list of PIL.Image.Image,  # images from the six cameras
                "trans": torch.Tensor,        # per-camera translation, ego frame
                "rots": torch.Tensor,         # per-camera rotation, ego frame
                "intrins": torch.Tensor,      # intrinsics of the six cameras
                "lidar_data": torch.Tensor,   # lidar points, ego frame
                "lidar_mask": torch.Tensor,   # mask marking padded lidar points
                "car_trans": torch.Tensor,    # ego translation, world frame
                "yaw_pitch_roll": torch.Tensor,  # ego Euler angles, world frame
                "semantic_gt": torch.Tensor,  # semantic segmentation labels
                "instance_gt": torch.Tensor,  # instance segmentation labels
                "direction_gt": torch.Tensor, # direction labels, 360° in 36 bins
                "sweeps": list of dict,       # sensor data of unannotated intermediate frames
                "gt_boxes": torch.Tensor,     # 3D boxes with 7 DOF, an Nx7 array
                "gt_names": list of str,      # 3D box categories, a 1xN array
                "gt_velocity": torch.Tensor,  # 3D box velocities (no vertical component), an Nx2 array
                "num_lidar_pts": torch.Tensor,  # lidar points inside each 3D box
                "num_radar_pts": torch.Tensor,  # radar points inside each 3D box
                "valid_flag": torch.Tensor,   # box validity: at least one lidar or radar point
            }

    NOTE:
        The detection-related keys vary in length per sample, e.g. ``gt_boxes``
        may have length 10 in sample 0 and length 20 in sample 1.

    Examples:

    .. code-block:: python

        from hfai.datasets import NuScenes
        from torchvision import transforms
        import torch

        def transform(data_dict):
            trans = transforms.Compose([transforms.ToTensor()])
            data_dict['imgs'] = torch.stack([
                trans(img) for img in data_dict['imgs']
            ])
            return data_dict

        split = 'train'  # or val
        dataset = NuScenes(split, transform)

        # segmentation task
        loader = dataset.loader(batch_size=64, num_workers=4)
        for i, data_dict in enumerate(loader):
            print('keys', data_dict.keys())

        # detection task: selects all keys, including the detection ones
        dataset.set_task('detection')
        for i, data_dict in enumerate(loader):
            print('detection keys', data_dict.keys())
    """

    def __init__(
            self,
            split: str,
            transform: Optional[Callable] = None,
            check_data: bool = True,
            miniset: bool = False
    ) -> None:
        super(NuScenes, self).__init__()
        self.split = split
        self.reader = NuScenesReader(split, check_data, miniset)
        self.transform = transform
        # Default task; controls which keys the reader keeps (see set_task).
        self.task = 'segmentation'

    def set_selected_keys(self, keys: List):
        """Restrict the keys kept in every sample dict to *keys*.

        Raises:
            ValueError: if *keys* is empty (an ``assert`` would be stripped
                under ``python -O``, so validate explicitly).
        """
        if not keys:
            raise ValueError('plz input not empty keys')
        # NOTE: the reader attribute keeps its historical misspelling.
        self.reader.seletected_keys = keys

    def set_task(self, task: str):
        """Switch between the 'segmentation' and 'detection' key sets.

        Raises:
            ValueError: if *task* is not one of the two supported tasks.
        """
        if task not in ('segmentation', 'detection'):
            raise ValueError(f"task must be 'segmentation' or 'detection', got {task!r}")
        self.task = task
        if task == 'segmentation':
            self.reader.seletected_keys = self.reader.default_keys
        else:
            # None means "keep every key", which includes the detection keys.
            self.reader.seletected_keys = None

    def __len__(self):
        return len(self.reader)

    def __getitem__(self, indices):
        """Read the samples at *indices*, applying ``transform`` when set."""
        samples = self.reader.read(indices)
        if self.transform is None:
            return samples
        return [self.transform(sample) for sample in samples]

    def loader(self, *args, **kwargs) -> DataLoader:
        """Build a DataLoader over this dataset, defaulting to our collate_fn."""
        kwargs.setdefault('collate_fn', self.collate_fn)
        return DataLoader(self, *args, **kwargs)

    def collate_fn(self, data):
        """Merge a list of sample dicts into one batch dict.

        Values of the fixed-shape segmentation keys are stacked into a single
        tensor; detection keys have variable per-sample length and stay as
        plain lists.
        """
        batch = {}
        for sample in data:
            for key, value in sample.items():
                batch.setdefault(key, []).append(value)
        for key, values in batch.items():
            if key in self.reader.default_keys:
                batch[key] = torch.stack(values, dim=0)
        return batch