Source code for hfai.datasets.ogb
import torch
import numpy as np
from .base import BaseDataset, register_dataset, get_data_dir
"""
Expected file organization:
[data_dir]
ogbg_code2.pt
ogbg_molhiv.pt
ogbg_molpcba.pt
ogbg_ppa.pt
ogbl_biokg.pt
ogbl_citation2.pt
ogbl_collab.pt
ogbl_ddi.pt
ogbl_ppa.pt
ogbl_wikikg2.pt
ogbn_arxiv.pt
ogbn_mag.pt
ogbn_papers100M.pt
ogbn_products.pt
ogbn_proteins.pt
"""
@register_dataset
class OGB(BaseDataset):
"""
A benchmark dataset collection for machine learning on graphs.
Open Graph Benchmark (OGB) is a collection of realistic, large-scale, and diverse benchmark datasets for machine learning on graphs. See https://ogb.stanford.edu/ for more information.
The following graph datasets are included:
- Node Property Prediction: ``ogbn-products``, ``ogbn-proteins``, ``ogbn-arxiv``, ``ogbn-papers100M``, ``ogbn-mag``
- Link Property Prediction: ``ogbl-biokg``, ``ogbl-citation2``, ``ogbl-collab``, ``ogbl-ddi``, ``ogbl-ppa``, ``ogbl-wikikg2``
- Graph Property Prediction: ``ogbg-code2``, ``ogbg-molhiv``, ``ogbg-molpcba``, ``ogbg-ppa``
- Large-Scale Challenge: ``ogblsc-wikikg2``
Args:
data_name (str): name of the dataset
Returns:
data, split (torch_geometric.data.data.Data, dict): the graph data in torch_geometric format, and a dict of train/valid/test split indices
Examples:
Load an ``ogbn-`` series dataset
>>> from hfai.datasets import OGB
>>> dataset = OGB(data_name='ogbn-proteins')
>>> dataset.get_data()
Data(num_nodes=132534, edge_index=[2, 79122504], edge_attr=[79122504, 8], node_species=[132534, 1], y=[132534, 112])
Load an ``ogbl-`` series dataset
>>> from hfai.datasets import OGB
>>> dataset = OGB(data_name='ogbl-ppa')
>>> dataset.get_data()
Data(num_nodes=576289, edge_index=[2, 42463862], x=[576289, 58])
Load an ``ogbg-`` series dataset
>>> from hfai.datasets import OGB
>>> dataset = OGB(data_name='ogbg-molhiv')
>>> dataset.get_data()
[Data(edge_index=[2, 40], edge_attr=[40, 3], x=[19, 9], y=[1, 1], num_nodes=19),
Data(edge_index=[2, 88], edge_attr=[88, 3], x=[39, 9], y=[1, 1], num_nodes=39),
...
Load an ``ogblsc-`` series dataset
>>> from hfai.datasets import OGB
>>> dataset = OGB(data_name='ogblsc-wikikg2')
>>> dataset.get_data()
{'num_entities': 91230610, 'num_relations': 1387, 'num_feat_dims': 768, 'entity_feat': ..., 'realtion_feat': ..., 'train_hrt': ..., 'val_hr': ..., 'val_t': ...}
"""
def __init__(self, data_name: str) -> None:
super(OGB, self).__init__()
assert data_name in ["ogbg-code2", "ogbg-molhiv", "ogbg-molpcba", "ogbg-ppa",
"ogbl-biokg", "ogbl-citation2", "ogbl-collab", "ogbl-ddi", "ogbl-ppa", "ogbl-wikikg2",
"ogbn-arxiv", "ogbn-mag", "ogbn-papers100M", "ogbn-products", "ogbn-proteins",
"ogblsc-wikikg2"]
data_dir = get_data_dir()
self.data_dir = data_dir / "OGB"
self.reader = None
tags = data_name.split('-')
if tags[0] == 'ogbg':
self.mode = 'graph'
elif tags[0] == 'ogbl':
self.mode = 'link'
elif tags[0] == 'ogbn':
self.mode = 'node'
elif tags[0] == 'ogblsc':
self.mode = 'large'
else:
raise ValueError(f'{data_name} is invalid')
try:
if self.mode == 'large':
self.data = {
'num_entities': 91230610,
'num_relations': 1387,
'num_feat_dims': 768,
'entity_feat': np.load(f"{self.data_dir}/{'_'.join(tags)}/entity_feat.npy"),
'realtion_feat': np.load(f"{self.data_dir}/{'_'.join(tags)}/relation_feat.npy"),
'train_hrt': np.load(f"{self.data_dir}/{'_'.join(tags)}/train_hrt.npy"),
'val_hr': np.load(f"{self.data_dir}/{'_'.join(tags)}/val_hr.npy"),
'val_t': np.load(f"{self.data_dir}/{'_'.join(tags)}/val_t.npy")
}
self.split = {}
elif self.mode == 'graph':
ckpt = torch.load(self.data_dir / f"{'_'.join(tags)}.pt", map_location="cpu")
self.data = (ckpt['data'], ckpt['slices'])
self.split = ckpt['split']
else:
ckpt = torch.load(self.data_dir / f"{'_'.join(tags)}.pt", map_location="cpu")
self.data = ckpt['data']
self.split = ckpt['split']
except Exception as e:
print(f"OGB cannot use | Error: {e}")
self.data = None
self.split = {}
self.mode = 'none'
def __len__(self):
raise ValueError("OGB dataset 不支持该方法")
def __getitem__(self, indices):
raise ValueError("OGB dataset 不支持索引")
def get_data(self):
"""获取OGB图数据。
Returns:
节点/链接预测任务下返回torch_geometric格式的图数据,图属性预测任务下是返回元组,大规模图任务下返回具有ndarray格式数据的字典
"""
return self.data
def get_split(self):
"""获取OGB图数据的分割索引。
Returns:
返回字典,包含“训练、验证、测试”样本分割的索引。大规模图任务下该方法不可用
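Examples:
A minimal sketch; the 'train' / 'valid' / 'test' key names shown below follow the standard OGB split convention and are an assumption here, not something this module guarantees.
>>> from hfai.datasets import OGB
>>> dataset = OGB(data_name='ogbn-arxiv')
>>> split = dataset.get_split()
>>> sorted(split.keys())  # assumed key names
['test', 'train', 'valid']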
"""
if self.mode == "large":
raise NotImplementedError("大规模图任务数据集不能使用数据分割方法")
return self.split
def loader(self, *args, **kwargs):
"""OGB数据不支持直接Dataloader。
"""
raise RuntimeError("OGB dataset 不支持此方法")