# Source code for hfai.datasets.ade20k
from typing import Callable, Optional
import io
from PIL import Image
import pickle
from ffrecord import FileReader
from .base import (
BaseDataset,
get_data_dir,
register_dataset
)
"""
Expected file organization:
[data_dir]
train.ffr
PART_00000.ffr
PART_00001.ffr
val.ffr
PART_00000.ffr
"""
@register_dataset
class ADE20k(BaseDataset):
    """
    ADE20K semantic-segmentation dataset.

    The dataset contains more than 25,000 images densely annotated with an
    open-dictionary label set. See the official site:
    https://groups.csail.mit.edu/vision/datasets/ADE20K/

    Args:
        split (str): dataset split, either ``train`` or ``val``
        transform (Callable): optional transform applied to each sample; it
            receives an image and a segmentation mask and returns the
            transformed image and mask
        check_data (bool): whether to verify the checksum of every sample
            (default: ``True``)
        miniset (bool): whether to use the mini subset (default: ``False``)

    Returns:
        pic, seg_mask (PIL.Image.Image, PIL.Image.Image): each sample is a
        tuple of an RGB image and an L-mode (grayscale) segmentation mask

    Examples:

    .. code-block:: python

        from hfai.datasets import ADE20k

        def transform(pic, seg_mask):
            ...

        dataset = ADE20k(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for pic, seg_mask in loader:
            # training model
    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False
    ) -> None:
        super().__init__()
        assert split in ["train", "val"]
        self.split = split
        self.transform = transform

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.fname = data_dir / "ADE20k" / f"{split}.ffr"
        # FileReader verifies per-sample checksums when check_data is True.
        self.reader = FileReader(self.fname, check_data)

    def __len__(self):
        # Total number of samples in the ffrecord file.
        return self.reader.n

    def __getitem__(self, indices):
        """
        Read and decode the samples at the given indices.

        Args:
            indices: sample indices, as accepted by ``FileReader.read``

        Returns:
            list of ``(image, seg_mask)`` tuples; if a transform was supplied,
            each pair is the transform's output
        """
        raw_records = self.reader.read(indices)

        samples = []
        for raw in raw_records:
            # Each record is a pickled (image_bytes, mask_bytes) pair.
            # NOTE(review): pickle.loads assumes the .ffr files are trusted;
            # do not point this reader at untrusted data.
            img_bytes, seg_bytes = pickle.loads(raw)
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
            seg = Image.open(io.BytesIO(seg_bytes)).convert("L")
            # Decode and transform in a single pass (the original built an
            # intermediate list and iterated it a second time).
            if self.transform:
                img, seg = self.transform(img, seg)
            samples.append((img, seg))

        return samples