Shortcuts

Source code for hfai.datasets.alphafold

from typing import Callable, Optional
import pickle
from ffrecord import FileReader
from hfai.datasets.base import (
    BaseDataset,
    get_data_dir,
    register_dataset
)


"""
Expected file organization:

    [data_dir]
        train
            ffrdata
                PART_00000.ffr
                PART_00001.ffr
                PART_00002.ffr
                ...
        pdb_mmcif
            mmcifs.ffr
            mmcifs_index.pk
        mmcif_cache_all.json
"""

@register_dataset
class AlphafoldData(BaseDataset):
    """
    Protein structure prediction dataset for AlphaFold training.

    The dataset contains 131,291 protein sequences and is normally used as a
    whole for training. The ``pdb_mmcif`` and ``mmcif_cache_all.json`` data
    are read directly during training and do not go through this dataset.

    Args:
        transform (Callable): transform applied to each sample; it receives
            the unpickled sample unpacked as positional arguments (the protein
            sequence and alignment data)
        check_data (bool): whether to verify the checksum of every sample
            (default: ``True``)

    Returns:
        pdb_code, mmcif_string, bfd_hits, mgnify_hits, pdb70_hits, uniref90_hits:
        each sample is a tuple holding the sequence information of one protein
        together with its alignment data. Indexing the dataset returns a list
        with one entry per requested index; when a transform is set, each
        entry is whatever the transform returns.

    Examples:

    .. code-block:: python

        from hfai.datasets import AlphafoldData

        # when using the alphafold-optimized repository
        # (AlphafoldDataTransform is not defined in this module)
        transform = AlphafoldDataTransform()
        dataset = AlphafoldData(transform)
        loader = dataset.loader(batch_size=1, num_workers=8)

        for data in loader:
            # training model

    """

    def __init__(
        self,
        transform: Optional[Callable] = None,
        check_data: bool = True,
    ) -> None:
        super(AlphafoldData, self).__init__()

        # Records live under <data_dir>/Alphafold/train/ffrdata.
        ffr_path = get_data_dir() / "Alphafold" / "train" / "ffrdata"
        self.reader = FileReader(ffr_path, check_data=check_data)
        self.transform = transform

    def __len__(self):
        """Return the number of samples reported by the underlying reader."""
        return self.reader.n

    def __getitem__(self, indices):
        """Read the records at ``indices`` and return them as a list.

        Each record is unpickled; when a transform is configured it is called
        with the unpickled sample unpacked as positional arguments.
        """
        # NOTE: pickle.loads must only ever see trusted dataset files.
        # ``map`` is lazy, so each record is unpickled right before it is
        # transformed, one sample at a time.
        unpickled = map(pickle.loads, self.reader.read(indices))
        return [
            self.transform(*sample) if self.transform else sample
            for sample in unpickled
        ]