Shortcuts

Source code for hfai.datasets.cifar

from typing import Callable, Optional

import torch
import torchvision
from .base import (
    BaseDataset,
    get_data_dir,
    register_dataset,
)


[docs]@register_dataset class CIFAR10(torchvision.datasets.CIFAR10, BaseDataset): """ 这是一个用于识别普适物体的小型数据集 该数据集一共包含 10 个类别的 RGB 彩色图片,每个图片的尺寸为 32 × 32 ,每个类别有 600 个图像,数据集中一共有 500 张训练图片和 100 张测试图片。更多信息参考官网:https://www.cs.toronto.edu/~kriz/cifar.html Args: split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``) transform (Callable): transform 函数,对图片进行 transfrom,接受一张图片作为输入,输出 transform 之后的图片 target_transform (Callable): 对 target 进行 transfrom,接受一个 target 作为输入,输出 transform 之后的 target Returns: image, target (PIL.Image.Image, int): 返回的每条样本是一个元组,包含一个RGB格式的图片,及其对应的目标标签 Examples: .. code-block:: python from hfai.datasets import CIFAR10 from torchvision import transforms transform = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ]) dataset = CIFAR10('train', transform) loader = dataset.loader(batch_size=64, num_workers=4) for image, target in loader: # training model NOTE: 使用的时候所有数据会直接加载进内存,大小大约为 178 MiB。``CIFAR10`` 和 ``CIFAR100`` 的 ``loader()`` 方法返回的是一个 ``torch.utils.data.DataLoader`` ,而不是 ``ffrecord.torch.DataLoader`` 。 """ def __init__( self, split: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None ) -> None: assert split in ["train", "test"] super().__init__(str(get_data_dir() / "CIFAR"), split == "train", transform, target_transform) def loader(self, *args, **kwargs) -> torch.utils.data.DataLoader: return torch.utils.data.DataLoader(self, *args, **kwargs)
[docs]@register_dataset class CIFAR100(torchvision.datasets.CIFAR100, BaseDataset): """ 这是一个用于识别普适物体的大型数据集 该数据集一共包含 100 个类别的 RGB 彩色图片,每个图片的尺寸为 32 × 32 ,每个类别有 600 个图像,数据集中一共有 500 张训练图片和 100 张测试图片。更多信息参考官网:https://www.cs.toronto.edu/~kriz/cifar.html Args: split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``) transform (Callable): transform 函数,对图片进行 transfrom,接受一张图片作为输入,输出 transform 之后的图片 target_transform (Callable): 对 target 进行 transfrom,接受一个 target 作为输入,输出 transform 之后的 target Returns: image, target (PIL.Image.Image, int): 返回的每条样本是一个元组,包含一个 RGB 格式的图片,及其对应的目标标签 Examples: .. code-block:: python from hfai.datasets import CIFAR100 from torchvision import transforms transform = transforms.Compose([ transforms.Resize(224), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ]) dataset = CIFAR100('train', transform) loader = dataset.loader(batch_size=64, num_workers=4) for image, target in loader: # training model NOTE: 使用的时候所有数据会直接加载进内存,大小大约为 178 MiB。``CIFAR10`` 和 ``CIFAR100`` 的 ``loader()`` 方法返回的是一个 ``torch.utils.data.DataLoader`` ,而不是 ``ffrecord.torch.DataLoader`` 。 """ def __init__( self, split: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None ) -> None: assert split in ["train", "test"] super().__init__(str(get_data_dir() / "CIFAR"), split == "train", transform, target_transform) def loader(self, *args, **kwargs) -> torch.utils.data.DataLoader: return torch.utils.data.DataLoader(self, *args, **kwargs)