
Source code for hfai.datasets.googlecc

import pickle
from ffrecord import FileReader
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset,
)

"""
Expected file organization:

    [data_dir]
        train/
            Train_GCC-training_output_washed_mp_000.ffr
            ...
            Train_GCC-training_output_washed_mp_095.ffr
        val/
            Validation_GCC-1.1.0-Validation_output_washed_mp_0.ffr

"""


@register_dataset
class GoogleConceptualCaption(BaseDataset):
    """
    A dataset for multimodal training.

    This dataset is a subset of Google Conceptual Captions, obtained by randomly sampling
    2850879 of the 3318333 image-caption pairs. For more information, see:
    https://ai.google.com/research/ConceptualCaptions/

    Args:
        split (str): The dataset split, either the training set (``train``) or the validation set (``val``)
        transform (Callable): A transform applied to each sample; it takes an image and a piece of text
            as input and returns the transformed result
        check_data (bool): Whether to verify the checksum of every sample (default: ``True``)
        miniset (bool): Whether to use the mini subset (default: ``False``)

    Returns:
        pic, text (PIL.Image.Image, str): Each returned sample is a tuple containing an RGB image and its caption text

    Examples:

    .. code-block:: python

        from hfai.datasets import GoogleConceptualCaption
        from torchvision import transforms

        img_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        tokenize = ...

        def transform(pic, text):
            pic = img_transform(pic)
            text = tokenize(text)
            return pic, text

        dataset = GoogleConceptualCaption(split, transform)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for pic, text in loader:
            # training model
            ...

    """

    def __init__(self, split, transform=None, check_data=True, miniset=False):
        super().__init__()
        assert split in ["train", "val"]

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        ffr_file = data_dir / "googlecc" / split

        self.reader = FileReader(ffr_file, check_data=check_data)
        self.transform = transform

    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        # Read the raw records for the requested indices from the ffrecord shards.
        data = self.reader.read(indices)

        samples = []
        for bytes_ in data:
            # Each record is a pickled dict holding the image bytes and the caption.
            sample = pickle.loads(bytes_)
            image = sample["image_bytes"]
            text = sample["caption"]
            sample = (image, text)
            if self.transform:
                sample = self.transform(*sample)
            samples.append(sample)

        return samples
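

# Illustrative usage sketch (not part of the original module): since
# __getitem__ hands the raw image bytes straight to ``transform``, decoding is
# typically done inside the transform. Everything below is a hypothetical
# example; the decode-to-PIL step is an assumption about the stored bytes, and
# ``loader`` is the same BaseDataset helper used in the docstring example.
if __name__ == "__main__":
    import io
    from PIL import Image

    def decode(image_bytes, caption):
        # Assumption: the stored bytes are an encoded image that PIL can open.
        pic = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        return pic, caption

    dataset = GoogleConceptualCaption("val", transform=decode)
    loader = dataset.loader(batch_size=8, num_workers=2)
    for pic, text in loader:
        # pic/text are batched by the loader's collation; inspect one batch and stop.
        break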