Source code for hfai.datasets.isc2021
from typing import Callable, Optional
import io
import pandas as pd
from PIL import Image
from ffrecord import FileReader
from .base import (
BaseDataset,
get_data_dir,
register_dataset
)
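"""
Expected file organization (``[data_dir]`` is the directory returned by
``get_data_dir()``; with ``miniset=True`` the same layout lives under
``[data_dir]/mini/ISC2021``):

    [data_dir]/ISC2021
        public_ground_truth_50K.csv
        training.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        reference.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
        query.ffr
            PART_00000.ffr
            PART_00001.ffr
            ...
"""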
@register_dataset
class ISC2021(BaseDataset):
    """
    Unsupervised-learning dataset from the Image Similarity Challenge (ISC 2021).

    This is the million-scale dataset open-sourced by Facebook AI for the image
    similarity search challenge held at NeurIPS 2021. For more information see
    https://github.com/facebookresearch/isc2021

    Args:
        split (str): dataset split, one of ``training``, ``reference`` or ``query``
        transform (Callable): optional function applied to every image; it takes
            one image as input and returns the transformed image (default ``None``)
        check_data (bool): whether to verify the checksum of every sample
            (default ``True``)
        miniset (bool): whether to use the mini subset (default ``False``)

    Returns:
        img (PIL.Image.Image): every sample is an RGB image, or the output of
        ``transform`` for it when a transform is given. Indexing with a batch of
        indices returns a list of such samples.

    Raises:
        ValueError: if ``split`` is not one of the supported splits

    Examples:

    .. code-block:: python

        from hfai.datasets import ISC2021
        from torchvision import transforms

        # mean, std: per-channel normalization statistics of your choice
        transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        dataset = ISC2021(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for imgs in loader:
            # training model
            ...

    """

    def __init__(
        self,
        split: str,
        transform: Optional[Callable] = None,
        check_data: bool = True,
        miniset: bool = False,
    ) -> None:
        super().__init__()
        splits = ["training", "reference", "query"]
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if split not in splits:
            raise ValueError(f"Unknown split {split!r}. Available splits {splits}")
        self.split = split
        self.transform = transform

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "ISC2021"

        # Samples of the requested split are read from ``<split>.ffr``.
        self.fname = self.data_dir / f"{split}.ffr"
        self.reader = FileReader(self.fname, check_data)

        # The ground-truth table is loaded regardless of ``split``.
        # NOTE(review): presumably the query/reference match pairs of the public
        # phase; confirm the column layout against the ISC2021 release.
        data = pd.read_csv(self.data_dir / "public_ground_truth_50K.csv")
        self.gt = data.to_numpy()

    def __len__(self):
        """Return the number of samples reported by the split's FileReader."""
        return self.reader.n

    def __getitem__(self, indices):
        """
        Read and decode the samples at ``indices``.

        Args:
            indices: indices of the samples to read, passed unchanged to
                ``FileReader.read``

        Returns:
            list: one RGB ``PIL.Image.Image`` per index, or the result of
            ``transform`` on it when a transform is set
        """
        samples = []
        for raw in self.reader.read(indices):
            # ``convert`` returns an independent, fully loaded image, so the
            # source image (and its buffer) can be closed right away.
            with Image.open(io.BytesIO(raw)) as src:
                img = src.convert("RGB")
            # Decode and transform in one pass, so only one intermediate
            # PIL image is alive at a time.
            if self.transform is not None:
                img = self.transform(img)
            samples.append(img)
        return samples