Shortcuts

Source code for hfai.datasets.base

from typing import Union
import os
from typing import Union
from pathlib import Path
from ffrecord.torch import Dataset, DataLoader

DATA_DIR = None
DEFAULT_DATA_DIR = Path("/public_dataset/1/ffdataset")


[docs]def set_data_dir(path: Union[str, os.PathLike]) -> None: """ 设置数据集存放的主目录 我们会优先使用通过 ``set_data_dir`` 设置的路径,如果没有则会去使用环境变量 ``HFAI_DATASETS_DIR`` 的值。 两者都没有设置的情况下,使用默认目录 ``/public_dataset/1/ffdataset``。 Args: path (str, os.PathLike): 数据集存放的主目录 Examples: >>> hfai.datasets.set_data_dir("/your/data/dir") >>> hfai.datasets.get_data_dir() PosixPath('/your/data/dir') """ global DATA_DIR DATA_DIR = Path(path).absolute()
[docs]def get_data_dir() -> Path: """ 返回当前数据集主目录 Returns: data_dir (Path): 当前数据集主目录 Examples: >>> hfai.datasets.set_data_dir("/your/data/dir") >>> hfai.datasets.get_data_dir() PosixPath('/your/data/dir') """ global DATA_DIR # 1. set_data_dir() 设置的路径 if DATA_DIR is not None: return DATA_DIR.absolute() # 2. 环境变量 HFAI_DATASETS_DIR 指定的路径 env = os.getenv("HFAI_DATASETS_DIR") if env is not None: return Path(env) # 3. 默认路径 return DEFAULT_DATA_DIR
[docs]class BaseDataset(Dataset): """hfai.dataset 基类 """ _repr_indent = 4 def __init__(self): """ """ pass def __len__(self): raise NotImplementedError def __repr__(self) -> str: head = "hfai.datasets." + self.__class__.__name__ body = [f"Number of datapoints: {self.__len__()}"] if hasattr(self, "split") and self.split is not None: body.append(f"Split: {self.split}") lines = [head] + [" " * self._repr_indent + line for line in body] return "\n".join(lines)
[docs] def loader(self, *args, **kwargs) -> DataLoader: """ 获取数据集的Dataloader 参数与 `PyTorch的Dataloader <https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader>`_ 保持一致 Returns: 数据集的Dataloader """ return DataLoader(self, *args, **kwargs)
DATASETS = {} def register_dataset(cls): if not issubclass(cls, BaseDataset): raise TypeError("Can only register classes inherited from BaseDataset") DATASETS[cls.__name__] = cls return cls