# Source code for hfai.datasets.kitti
# (extracted from the rendered documentation page; HTML navigation residue removed)
from typing import Callable, List, Optional

import io
from PIL import Image
import pickle
from ffrecord import FileReader
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset
)

"""
Expected file organization:

    [data_dir]
        train.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        test.ffr
            PART_00000.ffr
            ...
"""


@register_dataset
class KITTIObject2D(BaseDataset):
    """KITTI 2D object-detection dataset.

    KITTI was created jointly by the Karlsruhe Institute of Technology and the
    Toyota Technological Institute; it is currently one of the largest
    computer-vision benchmark suites for autonomous-driving scenes.
    See http://www.cvlibs.net/datasets/kitti/ for more information.

    Args:
        split (str): dataset split, either ``train`` or ``test``
        transform (Callable): optional transform applied to each sample; for
            the train split it receives ``(pic, target)`` and returns the
            transformed pair, for the test split it receives and returns a
            single image (the test split has no targets)
        check_data (bool): verify the checksum of every sample read from the
            ffrecord files (default: ``True``)
        miniset (bool): use the mini subset of the data (default: ``False``)

    Returns:
        pic, target (PIL.Image.Image, list): each sample is a tuple of an RGB
        image and the corresponding annotation dict, e.g.:

        .. code-block:: python

            {
                'type': 'Pedestrian',
                'truncated': 0.0,
                'occluded': 0,
                'alpha': -0.2,
                'bbox': [x1, y1, x2, y2],
                'dimensions': [1.89, 0.48, 1.2],
                'location': [1.84, 1.47, 8.41],
                'rotation_y': 0.01
            }

    Examples:

    .. code-block:: python

        from hfai.datasets import KITTIObject2D

        def transform(pic, target):
            ...

        dataset = KITTIObject2D('train', transform)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for pic, target in loader:
            # training model
            ...

    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super().__init__()
        assert split in ["train", "test"]
        self.split = split

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.fname = data_dir / "KITTI/Object2D" / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)
        self.transform = transform
        # NOTE: removed a leftover debug print of ``self.reader.n`` that wrote
        # to stdout on every construction.

    def __len__(self) -> int:
        # Number of samples stored in the ffrecord file.
        return self.reader.n

    def __getitem__(self, indices):
        """Read, decode and (optionally) transform the samples at ``indices``.

        Args:
            indices: a batch of sample indices understood by
                ``FileReader.read``

        Returns:
            list: for the train split, a list of ``(image, target)`` tuples;
            for the test split, a list of images
        """
        raw_records = self.reader.read(indices)

        if self.split == "train":
            # Train records are pickled (jpeg_bytes, target) pairs.
            # NOTE(review): pickle.loads assumes the ffrecord files are
            # trusted local data — never feed this path untrusted bytes.
            decoded = []
            for record in raw_records:
                img_bytes, target = pickle.loads(record)
                img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
                decoded.append((img, target))

            transformed = []
            for img, target in decoded:
                if self.transform:
                    img, target = self.transform(img, target)
                transformed.append((img, target))
            return transformed

        # Test records are raw image bytes with no annotations.
        decoded = [
            Image.open(io.BytesIO(record)).convert("RGB")
            for record in raw_records
        ]
        transformed = []
        for img in decoded:
            if self.transform:
                img = self.transform(img)
            transformed.append(img)
        return transformed