
Source code for hfai.datasets.isc2021

from typing import Callable, Optional

import io
import pandas as pd
from PIL import Image
from ffrecord import FileReader
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset
)

"""
Expected file organization:

    [data_dir]
        public_ground_truth_50K.csv

        training.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        reference.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        query.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
"""


@register_dataset
class ISC2021(BaseDataset):
    """
    An unsupervised-learning dataset.

    This is the million-scale dataset released by Facebook AI for the Image Similarity
    Challenge at NeurIPS 2021. For more information see: https://github.com/facebookresearch/isc2021

    Args:
        split (str): Dataset split: training set (``training``), reference set (``reference``), or query set (``query``)
        transform (Callable): Transform function applied to each image; it takes an image as input and returns the transformed image
        check_data (bool): Whether to verify the checksum of every sample (default: ``True``)
        miniset (bool): Whether to use the mini subset (default: ``False``)

    Returns:
        img (PIL.Image.Image): Each returned sample is an image in RGB format

    Examples:

    .. code-block:: python

        from hfai.datasets import ISC2021
        from torchvision import transforms

        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        dataset = ISC2021(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for imgs in loader:
            # train model
            ...

    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super(ISC2021, self).__init__()
        splits = ["training", "reference", "query"]
        assert split in splits, "Available splits " + str(splits)
        self.split = split
        self.transform = transform

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "ISC2021"
        self.fname = self.data_dir / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)

        # ground-truth pairs for the public query set
        data = pd.read_csv(self.data_dir / "public_ground_truth_50K.csv")
        self.gt = data.to_numpy()

    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        # read the raw bytes for a batch of indices and decode them into RGB images
        imgs_bytes = self.reader.read(indices)
        imgs = []
        for bytes_ in imgs_bytes:
            buf = io.BytesIO(bytes_)
            img = Image.open(buf).convert("RGB")
            imgs.append(img)

        # apply the optional transform to each decoded image
        transformed_imgs = []
        for img in imgs:
            if self.transform:
                img = self.transform(img)
            transformed_imgs.append(img)

        return transformed_imgs
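

# Usage sketch (not part of the original module): besides the DataLoader-style
# access shown in the class docstring, __getitem__ accepts a list of indices and
# returns the decoded (and optionally transformed) images directly. This is a
# minimal example that assumes the ISC2021 data is laid out under get_data_dir()
# as described above; the resize value is a placeholder, not an official
# preprocessing setting.
if __name__ == "__main__":
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # placeholder size
        transforms.ToTensor(),
    ])

    dataset = ISC2021("query", transform)
    print("query split size:", len(dataset))        # number of samples (reader.n)

    batch = dataset[[0, 1, 2]]                       # list of transformed images
    print("first image tensor shape:", batch[0].shape)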