# Source code for hfai.datasets.ade20k

from typing import Callable, Optional

import io
from PIL import Image
import pickle
from ffrecord import FileReader
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset
)

"""
Expected file organization (``data_dir`` is ``get_data_dir()``, or
``get_data_dir() / "mini"`` when ``miniset=True``):

    [data_dir]
        ADE20k
            train.ffr
                PART_00000.ffr
                PART_00001.ffr
            val.ffr
                PART_00000.ffr
"""


@register_dataset
class ADE20k(BaseDataset):
    """
    ADE20K, a public semantic-segmentation dataset.

    The dataset contains more than 25,000 images, densely annotated with an
    open-dictionary label set. See the official site for details:
    https://groups.csail.mit.edu/vision/datasets/ADE20K/

    Args:
        split (str): dataset split, either the training set (``train``) or
            the validation set (``val``)
        transform (Callable): transform applied to each sample; it receives
            an image and a segmentation mask and returns the transformed
            image and segmentation mask
        check_data (bool): whether to verify the checksum of every sample
            (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        pic, seg_mask (PIL.Image.Image, PIL.Image.Image): each sample is a
        tuple of an RGB-mode image and an L-mode image (before ``transform``
        is applied)

    Examples:

    .. code-block:: python

        from hfai.datasets import ADE20k

        def transform(pic, seg_mask):
            ...

        dataset = ADE20k(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for pic, seg_mask in loader:
            # training model

    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super(ADE20k, self).__init__()
        # Checked explicitly instead of with ``assert`` so the validation is
        # not stripped under ``python -O``. AssertionError is kept so callers
        # see the same exception type as before.
        if split not in ("train", "val"):
            raise AssertionError(f"split must be 'train' or 'val', got {split!r}")
        self.split = split
        self.transform = transform

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.fname = data_dir / "ADE20k" / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)

    def __len__(self) -> int:
        """Return the sample count reported by the underlying ``FileReader`` (``reader.n``)."""
        return self.reader.n

    def __getitem__(self, indices):
        """
        Read, decode and (optionally) transform the samples at ``indices``.

        Args:
            indices: sample indices, passed unchanged to ``FileReader.read``

        Returns:
            list: one ``(img, seg)`` tuple per record read. Without a
            transform, ``img`` is an RGB PIL image and ``seg`` an L-mode PIL
            image; with a transform, the pair is whatever it returns.
        """
        samples = []
        for record in self.reader.read(indices):
            img, seg = self._decode(record)
            # ``is not None`` rather than truthiness: a callable object that
            # happens to be falsy must still be applied.
            if self.transform is not None:
                img, seg = self.transform(img, seg)
            samples.append((img, seg))
        return samples

    @staticmethod
    def _decode(record):
        """Unpickle one record into ``(image_bytes, mask_bytes)`` and decode both into PIL images."""
        # NOTE(security): pickle.loads can execute arbitrary code, so the .ffr
        # files must come from a trusted source.
        img_bytes, seg_bytes = pickle.loads(record)
        # ``convert`` loads the pixel data, so the source image can be closed
        # right after; the converted image is an independent object.
        with Image.open(io.BytesIO(img_bytes)) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(io.BytesIO(seg_bytes)) as raw_seg:
            seg = raw_seg.convert("L")
        return img, seg