Shortcuts

Source code for hfai.datasets.ogb

import torch
import numpy as np

from .base import BaseDataset, register_dataset, get_data_dir

"""
Expected file organization:

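    [data_dir]                  (as returned by get_data_dir())
        OGB
            ogbg_code2.pt
            ogbg_molhiv.pt
            ogbg_molpcba.pt
            ogbg_ppa.pt
            ogbl_biokg.pt
            ogbl_citation2.pt
            ogbl_collab.pt
            ogbl_ddi.pt
            ogbl_ppa.pt
            ogbl_wikikg2.pt
            ogbn_arxiv.pt
            ogbn_mag.pt
            ogbn_papers100M.pt
            ogbn_products.pt
            ogbn_proteins.pt
            ogblsc_wikikg2
                entity_feat.npy
                relation_feat.npy
                train_hrt.npy
                val_hr.npy
                val_t.npy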
"""


@register_dataset
class OGB(BaseDataset):
    """Open Graph Benchmark (OGB) datasets for machine learning on graphs.

    Open Graph Benchmark (OGB) is a collection of realistic, large-scale and
    diverse benchmark datasets for machine learning on graphs. More
    information: https://ogb.stanford.edu/

    The following datasets are available:

    - Node Property Prediction: ``ogbn-products``, ``ogbn-proteins``,
      ``ogbn-arxiv``, ``ogbn-papers100M``, ``ogbn-mag``
    - Link Property Prediction: ``ogbl-biokg``, ``ogbl-citation2``,
      ``ogbl-collab``, ``ogbl-ddi``, ``ogbl-ppa``, ``ogbl-wikikg2``
    - Graph Property Prediction: ``ogbg-code2``, ``ogbg-molhiv``,
      ``ogbg-molpcba``, ``ogbg-ppa``
    - Large-Scale Challenge: ``ogblsc-wikikg2``

    Files are read from ``get_data_dir() / "OGB"``. Each dataset is stored
    as ``<name>.pt``, with ``-`` replaced by ``_`` (e.g. ``ogbn_arxiv.pt``).
    The exception is ``ogblsc-wikikg2``, which is a directory of ``.npy``
    files.

    The data is loaded eagerly in the constructor. If loading fails for any
    reason, the error is printed and the instance is left empty:
    ``get_data()`` returns ``None`` and ``get_split()`` returns ``{}``.

    Args:
        data_name (str): name of the dataset, one of the names listed above.

    Raises:
        ValueError: if ``data_name`` is not a supported dataset name.

    Examples:
        Loading an ``ogbn-`` dataset:

        >>> from hfai.datasets import OGB
        >>> dataset = OGB(data_name='ogbn-proteins')
        >>> dataset.get_data()
        Data(num_nodes=132534, edge_index=[2, 79122504], edge_attr=[79122504, 8], node_species=[132534, 1], y=[132534, 112])

        Loading an ``ogbl-`` dataset:

        >>> from hfai.datasets import OGB
        >>> dataset = OGB(data_name='ogbl-ppa')
        >>> dataset.get_data()
        Data(num_nodes=576289, edge_index=[2, 42463862], x=[576289, 58])

        Loading an ``ogbg-`` dataset (returns a ``(data, slices)`` tuple):

        >>> from hfai.datasets import OGB
        >>> dataset = OGB(data_name='ogbg-molhiv')
        >>> data, slices = dataset.get_data()

        Loading an ``ogblsc-`` dataset:

        >>> from hfai.datasets import OGB
        >>> dataset = OGB(data_name='ogblsc-wikikg2')
        >>> dataset.get_data()
        {'num_entities': 91230610, 'num_relations': 1387, 'num_feat_dims': 768, 'entity_feat': ..., 'realtion_feat': ..., 'train_hrt': ..., 'val_hr': ..., 'val_t': ...}
    """

    # Every dataset name accepted by the constructor.
    _DATA_NAMES = (
        "ogbg-code2", "ogbg-molhiv", "ogbg-molpcba", "ogbg-ppa",
        "ogbl-biokg", "ogbl-citation2", "ogbl-collab", "ogbl-ddi",
        "ogbl-ppa", "ogbl-wikikg2",
        "ogbn-arxiv", "ogbn-mag", "ogbn-papers100M", "ogbn-products",
        "ogbn-proteins",
        "ogblsc-wikikg2",
    )
    # The part of the name before the first '-' selects how the data is loaded.
    _PREFIX_TO_MODE = {
        "ogbg": "graph",
        "ogbl": "link",
        "ogbn": "node",
        "ogblsc": "large",
    }

    def __init__(self, data_name: str) -> None:
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if data_name not in self._DATA_NAMES:
            raise ValueError(f'{data_name} is invalid')

        self.data_dir = get_data_dir() / "OGB"
        self.reader = None
        self.mode = self._PREFIX_TO_MODE[data_name.split('-')[0]]
        # On-disk name: the dataset name with every '-' replaced by '_'.
        stem = data_name.replace('-', '_')

        try:
            if self.mode == 'large':
                self.data = self._load_large(self.data_dir / stem)
                self.split = {}
            else:
                ckpt = torch.load(self.data_dir / f"{stem}.pt", map_location="cpu")
                if self.mode == 'graph':
                    self.data = (ckpt['data'], ckpt['slices'])
                else:
                    self.data = ckpt['data']
                self.split = ckpt['split']
        except Exception as e:
            # Deliberate best-effort: an unavailable dataset yields an empty
            # instance instead of raising from the constructor.
            print(f"OGB cannot use | Error: {e}")
            self.data = None
            self.split = {}
            self.mode = 'none'

    @staticmethod
    def _load_large(base):
        """Load the large-scale (``ogblsc-``) arrays from ``.npy`` files under ``base``."""
        return {
            # Counts are hard-coded, not derived from the loaded arrays.
            'num_entities': 91230610,
            'num_relations': 1387,
            'num_feat_dims': 768,
            'entity_feat': np.load(base / "entity_feat.npy"),
            # NOTE: key is misspelled ('realtion') but kept for backward compatibility.
            'realtion_feat': np.load(base / "relation_feat.npy"),
            'train_hrt': np.load(base / "train_hrt.npy"),
            'val_hr': np.load(base / "val_hr.npy"),
            'val_t': np.load(base / "val_t.npy"),
        }

    def __len__(self):
        raise ValueError("OGB dataset 不支持该方法")

    def __getitem__(self, indices):
        raise ValueError("OGB dataset 不支持索引")

    def get_data(self):
        """Return the loaded OGB data.

        Returns:
            - node/link tasks: the object stored under the checkpoint's
              ``'data'`` key (a torch_geometric graph, per the examples);
            - graph tasks: a ``(data, slices)`` tuple from the checkpoint;
            - the large-scale task: a dict of metadata and ndarrays;
            - ``None`` if loading failed in the constructor.
        """
        return self.data

    def get_split(self):
        """Return the train/valid/test split indices.

        Returns:
            The checkpoint's ``'split'`` entry, or ``{}`` if loading failed.

        Raises:
            NotImplementedError: for the large-scale task, which has no split.
        """
        if self.mode == "large":
            raise NotImplementedError("大规模图任务数据集不能使用数据分割方法")
        return self.split

    def loader(self, *args, **kwargs):
        """Not supported: OGB data cannot be wrapped in a DataLoader directly.

        Raises:
            RuntimeError: always.
        """
        raise RuntimeError("OGB dataset 不支持此方法")