Source code for hfai.datasets.googlecc
import io
import pickle

from PIL import Image
from ffrecord import FileReader
from .base import (
BaseDataset,
get_data_dir,
register_dataset,
)
"""
Expected file organization:
[data_dir]
train/
Train_GCC-training_output_washed_mp_000.ffr
...
Train_GCC-training_output_washed_mp_095.ffr
val/
Validation_GCC-1.1.0-Validation_output_washed_mp_0.ffr
"""
@register_dataset
class GoogleConceptualCaption(BaseDataset):
"""
这是一个用于多模态训练的数据集
该数据集是一个子数据集,从 3318333 个 “图片-字幕” 对中随机采样了 2850879 个。更多信息参考:https://ai.google.com/research/ConceptualCaptions/
Args:
split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``)
transform (Callable): transform 函数,对图片和文本进行 transfrom,接受一张图片和一段文本作为输入,输出 transform 之后的结果
check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``)
miniset (bool): 是否使用 mini 集合(默认为 ``False``)
Returns:
pic, text (PIL.Image.Image, str): 返回的每个样本是一个元组,包含一个RGB格式的图片,一段字幕文本
Examples:
.. code-block:: python
from hfai.datasets import GoogleConceptualCaption
from torchvision import transforms
img_transform = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
tokenize = ...
def transform(pic, text):
pic = img_transform(pic)
text = tokenize(text)
return pic, text
dataset = GoogleConceptualCaption(split, transform)
loader = dataset.loader(batch_size=64, num_workers=4)
for pic, text in loader:
# training model
"""
    def __init__(self, split, transform=None, check_data=True, miniset=False):
        super().__init__()
        assert split in ["train", "val"]

        data_dir = get_data_dir()
        if miniset:
            # the mini dataset lives under [data_dir]/mini with the same layout
            data_dir = data_dir / "mini"
        ffr_dir = data_dir / "googlecc" / split
        self.reader = FileReader(ffr_dir, check_data=check_data)
        self.transform = transform
def __len__(self):
return self.reader.n
    def __getitem__(self, indices):
        # ffrecord reads a batch of samples at once, so `indices` is a list
        data = self.reader.read(indices)
        samples = []
        for bytes_ in data:
            sample = pickle.loads(bytes_)
            # decode the stored image bytes into an RGB PIL image, as documented
            image = Image.open(io.BytesIO(sample["image_bytes"])).convert("RGB")
            text = sample["caption"]
            sample = (image, text)
            if self.transform:
                sample = self.transform(*sample)
            samples.append(sample)
        return samples
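

# A minimal usage sketch (illustrative, not part of the library): it shows that
# __getitem__ follows the ffrecord batched convention (a list of indices in,
# a list of samples out) and that iteration normally goes through
# BaseDataset.loader. It assumes the googlecc ffrecord files are present
# under the configured hfai data directory.
if __name__ == "__main__":
    dataset = GoogleConceptualCaption(split="val", transform=None)
    print(len(dataset))

    # batched read: a list with one index yields a list with one (image, text) tuple
    image, text = dataset[[0]][0]
    print(image.size, text[:50])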