# Source code for hfai.datasets.alphafold
from typing import Callable, Optional
import pickle
from ffrecord import FileReader
from hfai.datasets.base import (
BaseDataset,
get_data_dir,
register_dataset
)
"""
Expected file organization:
[data_dir]
train
ffrdata
PART_00000.ffr
PART_00001.ffr
PART_00002.ffr
...
pdb_mmcif
mmcifs.ffr
mmcifs_index.pk
mmcif_cache_all.json
"""
@register_dataset
class AlphafoldData(BaseDataset):
    """
    A protein prediction dataset for AlphaFold training.

    The dataset contains 131,291 protein sequences and is normally used as a
    whole as the training set. The ``pdb_mmcif`` and ``mmcif_cache_all.json``
    data are read directly during training, without going through this dataset.

    Args:
        transform (Callable): transform function that processes the protein
            sequence and alignment data. It is called with the fields of one
            unpickled record unpacked as positional arguments.
        check_data (bool): whether to verify the checksum of every sample
            (default: ``True``)

    Returns:
        pdb_code, mmcif_string, bfd_hits, mgnify_hits, pdb70_hits, uniref90_hits:
        a tuple holding the sequence information of one protein together with
        the corresponding alignment data.

    Examples:

    .. code-block:: python

        from hfai.datasets import AlphafoldData
        from torchvision import transforms

        # when using the alphafold-optimized repository
        transform = AlphafoldDataTransform()
        dataset = AlphafoldData(transform)
        loader = dataset.loader(batch_size=1, num_workers=8)

        for data in loader:
            # training model

    """

    def __init__(
        self,
        transform: Optional[Callable] = None,
        check_data: bool = True,
    ) -> None:
        """Open the ffrecord reader for the training split and store ``transform``."""
        # Zero-argument super() does not depend on the module-level name
        # ``AlphafoldData``, which the class decorator is free to rebind.
        super().__init__()

        ffr_dir = get_data_dir() / "Alphafold" / "train" / "ffrdata"
        self.reader = FileReader(ffr_dir, check_data=check_data)
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of samples reported by the underlying reader."""
        return self.reader.n

    def __getitem__(self, indices) -> list:
        """
        Read, unpickle and optionally transform the requested samples.

        Args:
            indices: sample indices, forwarded unchanged to ``FileReader.read``
                (presumably a list of ints -- confirm against the loader).

        Returns:
            list: one entry per record returned by the reader; each entry is
            the unpickled record, or the output of ``transform(*record)`` when
            a transform was supplied.
        """
        transform = self.transform
        samples = []
        for raw in self.reader.read(indices):
            # NOTE: pickle.loads can execute arbitrary code; the ffrecord files
            # under the data directory must come from a trusted source.
            record = pickle.loads(raw)
            # Compare against None explicitly: a valid transform object may be
            # falsy (e.g. it defines __len__ or __bool__) and must still run.
            if transform is not None:
                record = transform(*record)
            samples.append(record)
        return samples