Shortcuts

Source code for hfai.datasets

from tabulate import tabulate
import importlib
from pathlib import Path
import inspect
from typing import Union
import os
import sys

from .base import DATASETS, BaseDataset, set_data_dir, get_data_dir

from .imagenet import ImageNet
from .imagenet21k import ImageNet21K
from .imagenet1k import ImageNet1K
from .coco import CocoCaption, CocoDetection, CocoKeypoint, CocoPanoptic
from .ade20k import ADE20k
from .isc2021 import ISC2021
from .ljspeech import LJSpeech
from .cifar import CIFAR10, CIFAR100
from .kitti import KITTIObject2D
from .ltsf import LTSF
from .weatherBench import WeatherBench
from .googlecc import GoogleConceptualCaption
from .clue import CLUEForMLM, CLUEForCLS
from .nuscenes import NuScenes
from .era5 import ERA5
from .ogb import OGB
from .alphafold import AlphafoldData

from .dataset_downloader import DatasetDownloader
from .download_const import datasets2dirs

# ``list`` is redefined below as a public function that prints the dataset table,
# so keep a handle on the builtin ``list`` under another name for internal use
# (e.g. ``list_(DATASETS.keys())``).
list_ = list


def list() -> None:
    """
    Print a table of every registered dataset.

    Each row pairs the dataset's registry name with its fully qualified
    ``hfai.datasets.<Name>`` path; rows are sorted before printing.

    Examples::

        >>> import hfai.datasets
        >>> hfai.datasets.list()

    """
    rows = sorted(
        (dataset_name, "hfai.datasets." + dataset_name)
        for dataset_name in DATASETS
    )
    table = tabulate(rows, headers=["Name", "Class"], tablefmt="grid")
    print(table)


def create(name: str, *args, **kwargs) -> BaseDataset:
    """
    Instantiate a registered dataset by name.

    Args:
        name (str): registry name of the dataset (see ``hfai.datasets.list()``)
        *args, **kwargs: forwarded unchanged to the dataset class constructor

    Returns:
        dataset (BaseDataset): the newly constructed dataset instance

    Raises:
        ValueError: if ``name`` is not a registered dataset.

    Examples::

        >>> hfai.datasets.create("ImageNet", split="train", transform=None)

    """
    if name in DATASETS:
        dataset_cls = DATASETS[name]
        return dataset_cls(*args, **kwargs)

    raise ValueError(
        f"Unknown dataset '{name}'. You can use hfai.datasets.list() "
        "to get a list of all available datasets."
    )


def _supports_miniset(dataset_cls) -> bool:
    """Return True if ``dataset_cls.__init__`` declares a ``miniset`` parameter."""
    try:
        params = inspect.signature(dataset_cls.__init__).parameters
    except (TypeError, ValueError):
        # Signature cannot be introspected: treat as "no miniset support"
        # instead of crashing the whole listing.
        return False
    # Only real parameters count (``co_varnames`` would also include plain locals).
    return "miniset" in params


def download(name: str, miniset: bool = False, workers: int = 4) -> bool:
    """
    Download a dataset into the configured data directory.

    Args:
        name (str): name of the dataset to download, e.g. CocoDetection
        miniset (bool): whether to download the mini set (default ``False``)
        workers (int): number of worker processes (default 4)

    Returns:
        success (bool): whether the download succeeded

    Raises:
        AssertionError: if ``name`` is not a registered dataset, if ``miniset``
            is requested for a dataset whose constructor has no ``miniset``
            parameter, or if the data dir does not exist.

    Examples::

        >>> hfai.datasets.set_data_dir('/your/data/dir')
        >>> hfai.datasets.download("ImageNet", miniset=True)

    """
    datasets_names = list_(DATASETS.keys())
    # Validation raises explicitly rather than via ``assert`` so the checks still
    # run under ``python -O``; AssertionError is kept for existing callers.
    if name not in datasets_names:
        raise AssertionError(
            f"{name} is not in hfai.datasets. Current supported datasets are {datasets_names}."
        )
    if miniset:
        mini_datasets_names = [k for k, v in DATASETS.items() if _supports_miniset(v)]
        if name not in mini_datasets_names:
            raise AssertionError(
                f"{name} dataset does not support miniset. Only {mini_datasets_names} are supported."
            )

    # Optional per-dataset directory settings; default to a dir named after the dataset.
    dataset2dirs = datasets2dirs.get(name, {})
    dir_name = dataset2dirs.get("dir_name", name)
    dirs_or_files = dataset2dirs.get("dirs_or_files", [])

    data_dir = get_data_dir()
    if not data_dir.exists():
        raise AssertionError(f"data dir {data_dir} does not exist.")

    dataset_type = "mini" if miniset else "full"
    files = DatasetDownloader.list_bucket_files(dataset_type, dir_name, dirs_or_files)
    success = DatasetDownloader.download_files(files, data_dir, workers, dir_name, dataset_type)
    return success
def load(fname: Union[str, os.PathLike], name: Union[str, None] = None) -> None:
    """
    Dynamically load Dataset classes from a Python file into ``hfai.datasets``.

    Loaded classes become attributes of this module and their names are added
    to ``__all__``.

    Args:
        fname (str, os.PathLike): path to the Python file
        name (str): name of the class to load; if None, every class in the file
            that subclasses BaseDataset (other than BaseDataset itself) is
            loaded, skipping names already in ``__all__``. Defaults to None.

    Raises:
        RuntimeError: if ``fname`` is not a file.

    Examples::

        >>> hfai.datasets.load("/your/dataset/code/MyDataset.py")

    """
    py_file = Path(fname).absolute()
    if not py_file.is_file():
        raise RuntimeError(f"{py_file} is not a file")

    # Import the file as a module. Its directory goes at the FRONT of sys.path so
    # this file wins over any same-named module elsewhere on the path.
    module_dir = str(py_file.parent)
    sys.path.insert(0, module_dir)
    try:
        module = importlib.import_module(py_file.stem)
    finally:
        # Remove exactly the entry added above (even if the import failed) so the
        # environment is not polluted and unrelated sys.path entries survive.
        # TODO: the loaded module may fail if it later lazily imports other
        # modules from the same directory.
        try:
            sys.path.remove(module_dir)
        except ValueError:
            pass  # the imported module already removed it
    this_mod = sys.modules[__name__]

    def _export(cls_name):
        # Expose the class on this module and list it in __all__ (no duplicates).
        this_mod.__dict__[cls_name] = module.__dict__[cls_name]
        if cls_name not in __all__:
            __all__.append(cls_name)

    # Find the custom datasets
    if name is not None:
        _export(name)
    else:
        def _is_dataset_cls(obj):
            return inspect.isclass(obj) and issubclass(obj, BaseDataset) and obj is not BaseDataset

        for n, _ in inspect.getmembers(module, _is_dataset_cls):
            if n not in __all__:
                _export(n)


__all__ = [
    "DATASETS",
    "list",
    "load",
    "create",
    "download",
    "get_data_dir",
    "set_data_dir",
]
__all__ += list_(DATASETS.keys())

__version__ = r"{{version}}"