Source code for hfai.datasets.cifar
from typing import Callable, Optional
import torch
import torchvision
from .base import (
BaseDataset,
get_data_dir,
register_dataset,
)
[docs]@register_dataset
class CIFAR10(torchvision.datasets.CIFAR10, BaseDataset):
"""
这是一个用于识别普适物体的小型数据集
该数据集一共包含 10 个类别的 RGB 彩色图片,每个图片的尺寸为 32 × 32 ,每个类别有 600 个图像,数据集中一共有 500 张训练图片和 100 张测试图片。更多信息参考官网:https://www.cs.toronto.edu/~kriz/cifar.html
Args:
split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``)
transform (Callable): transform 函数,对图片进行 transfrom,接受一张图片作为输入,输出 transform 之后的图片
target_transform (Callable): 对 target 进行 transfrom,接受一个 target 作为输入,输出 transform 之后的 target
Returns:
image, target (PIL.Image.Image, int): 返回的每条样本是一个元组,包含一个RGB格式的图片,及其对应的目标标签
Examples:
.. code-block:: python
from hfai.datasets import CIFAR10
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
dataset = CIFAR10('train', transform)
loader = dataset.loader(batch_size=64, num_workers=4)
for image, target in loader:
# training model
NOTE:
使用的时候所有数据会直接加载进内存,大小大约为 178 MiB。``CIFAR10`` 和 ``CIFAR100`` 的 ``loader()`` 方法返回的是一个 ``torch.utils.data.DataLoader`` ,而不是 ``ffrecord.torch.DataLoader`` 。
"""
def __init__(
self, split: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
) -> None:
assert split in ["train", "test"]
super().__init__(str(get_data_dir() / "CIFAR"), split == "train", transform, target_transform)
def loader(self, *args, **kwargs) -> torch.utils.data.DataLoader:
return torch.utils.data.DataLoader(self, *args, **kwargs)
[docs]@register_dataset
class CIFAR100(torchvision.datasets.CIFAR100, BaseDataset):
"""
这是一个用于识别普适物体的大型数据集
该数据集一共包含 100 个类别的 RGB 彩色图片,每个图片的尺寸为 32 × 32 ,每个类别有 600 个图像,数据集中一共有 500 张训练图片和 100 张测试图片。更多信息参考官网:https://www.cs.toronto.edu/~kriz/cifar.html
Args:
split (str): 数据集划分形式,包括:训练集(``train``)或者验证集(``val``)
transform (Callable): transform 函数,对图片进行 transfrom,接受一张图片作为输入,输出 transform 之后的图片
target_transform (Callable): 对 target 进行 transfrom,接受一个 target 作为输入,输出 transform 之后的 target
Returns:
image, target (PIL.Image.Image, int): 返回的每条样本是一个元组,包含一个 RGB 格式的图片,及其对应的目标标签
Examples:
.. code-block:: python
from hfai.datasets import CIFAR100
from torchvision import transforms
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
])
dataset = CIFAR100('train', transform)
loader = dataset.loader(batch_size=64, num_workers=4)
for image, target in loader:
# training model
NOTE:
使用的时候所有数据会直接加载进内存,大小大约为 178 MiB。``CIFAR10`` 和 ``CIFAR100`` 的 ``loader()`` 方法返回的是一个 ``torch.utils.data.DataLoader`` ,而不是 ``ffrecord.torch.DataLoader`` 。
"""
def __init__(
self, split: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None
) -> None:
assert split in ["train", "test"]
super().__init__(str(get_data_dir() / "CIFAR"), split == "train", transform, target_transform)
def loader(self, *args, **kwargs) -> torch.utils.data.DataLoader:
return torch.utils.data.DataLoader(self, *args, **kwargs)