Source code for hfai.datasets.kitti
from typing import Callable, List, Optional
import io
from PIL import Image
import pickle
from ffrecord import FileReader
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset,
)
"""
Expected file organization:
[data_dir]
train.ffr
PART_00000.ffr
PART_00001.ffr
...
test.ffr
PART_00000.ffr
...
"""
@register_dataset
class KITTIObject2D(BaseDataset):
"""
这是一个目标检测数据集
KITTI 数据集由德国卡尔斯鲁厄理工学院和丰田美国技术研究院联合创办,是目前国际上最大的自动驾驶场景下的计算机视觉算法评测数据集。更多信息参考:http://www.cvlibs.net/datasets/kitti/
Args:
split (str): 数据集划分形式,包括:训练集(``train``)或者测试集(``test``)
transform (Callable): transform 函数,对图片和标注进行 transfrom,接受一张图片和一个 target 作为输入,输出 transform 之后的图片和 target。(测试集没有 target)
check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``)
miniset (bool): 是否使用 mini 集合(默认为 ``False``)
Returns:
pic, target (PIL.Image.Image, list): 返回的每条样本是一个元组,包含一个RGB格式的图片,以及对应的包含位置信息的字典,例如:
.. code-block:: python
{
'type': 'Pedestrian',
'truncated': 0.0,
'occluded': 0,
'alpha': -0.2,
'bbox': [x1, y1, x2, y2],
'dimensions': [1.89, 0.48, 1.2],
'location': [1.84, 1.47, 8.41],
'rotation_y': 0.01
}
Examples:
.. code-block:: python
from hfai.datasets import KITTIObject2D
def transform(pic, target):
...
dataset = KITTIObject2D('train', transform)
loader = dataset.loader(batch_size=64, num_workers=4)
for pic, target in loader:
# training model
"""
    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super(KITTIObject2D, self).__init__()
        assert split in ["train", "test"]
        self.split = split

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.fname = data_dir / "KITTI/Object2D" / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)
        self.transform = transform
    def __len__(self):
        return self.reader.n
    def __getitem__(self, indices):
        bytes_ = self.reader.read(indices)

        samples = []
        for b in bytes_:
            if self.split == "train":
                # train samples are pickled (jpeg_bytes, target) tuples
                img_bytes, target = pickle.loads(b)
                buf = io.BytesIO(img_bytes)
                img = Image.open(buf).convert("RGB")
                sample = (img, target)
            else:
                # test samples are raw image bytes without annotations
                buf = io.BytesIO(b)
                img = Image.open(buf).convert("RGB")
                sample = img
            samples.append(sample)

        transformed_samples = []
        if self.split == "train":
            for img, target in samples:
                if self.transform:
                    img, target = self.transform(img, target)
                transformed_samples.append((img, target))
        else:
            for img in samples:
                if self.transform:
                    img = self.transform(img)
                transformed_samples.append(img)

        return transformed_samples
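

# A minimal usage sketch, assuming the data files are already present under the
# hfai data directory; `identity_transform` is a hypothetical placeholder, not
# part of the hfai API:

if __name__ == "__main__":
    def identity_transform(img, target):
        # a real transform would convert the PIL image to a tensor and adjust
        # the annotations accordingly; here both pass through unchanged
        return img, target

    dataset = KITTIObject2D("train", identity_transform)
    loader = dataset.loader(batch_size=4, num_workers=2)

    for pic, target in loader:
        # one batch of decoded (image, annotation) samples
        break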