Source code for hfai.datasets.weatherBench

import pickle
import numpy as np
from ffrecord import FileReader
import torch

from .base import BaseDataset, register_dataset, get_data_dir

"""
Expected file organization:

    [data_dir]
        kernel.pkl
        scaler.pkl
        train.ffr
        val.ffr
        test.ffr
"""


class StandardScaler:
    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

    def load(self, scaler_dir):
        with open(scaler_dir, "rb") as f:
            pkl = pickle.load(f)
            self.mean = pkl["mean"]
            self.std = pkl["std"]

    def inverse_transform(self, data):
        # After ``load``, ``mean`` and ``std`` are numpy arrays; convert them to
        # the same dtype and device as ``data`` when ``data`` is a torch tensor.
        mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
        std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
        return (data * std) + mean
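
# Usage sketch (illustrative only, not part of the module): after ``load``,
# ``mean`` and ``std`` are numpy arrays, so ``inverse_transform`` maps
# normalized values back to physical units for numpy arrays and torch tensors
# alike. The path and tensor shape below are hypothetical.
#
#     scaler = StandardScaler()
#     scaler.load("/path/to/WeatherBench/2m_temperature/scaler.pkl")
#     pred = torch.randn(8, 32, 64, 1)                # normalized model output
#     pred_physical = scaler.inverse_transform(pred)  # back to physical units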


@register_dataset
class WeatherBench(BaseDataset):
    """
    A benchmark dataset for weather forecasting.

    The dataset is sampled from ERA5, the standard reanalysis dataset of the
    European weather agency (ECMWF), and contains hourly global meteorological
    data from 1979 to 2019. For more information, see:
    https://github.com/pangeo-data/WeatherBench

    Args:
        data_name (str): the forecasting target, one of: ``2m_temperature``,
            ``relative_humidity``, ``component_of_wind``, ``total_cloud_cover``
        split (str): the dataset split: training (``train``), validation
            (``val``), or test (``test``)
        include_context (bool): whether to include meteorological context
            information as part of the input
        check_data (bool): whether to verify the checksum of every sample
            (defaults to ``True``)

    Returns:
        seq_x, seq_y (np.ndarray, np.ndarray): each sample is a tuple of the
        historical sequence and the future sequence of the target variable

    Examples:

    .. code-block:: python

        from hfai.datasets import WeatherBench

        dataset = WeatherBench(data_name, split, include_context, check_data)
        loader = dataset.loader(batch_size=64, num_workers=4)

        for seq_x, seq_y in loader:
            # training model
            ...

    """

    def __init__(self,
                 data_name: str,
                 split: str,
                 include_context: bool = False,
                 check_data: bool = True) -> None:
        super(WeatherBench, self).__init__()
        assert data_name in ["2m_temperature", "relative_humidity", "component_of_wind", "total_cloud_cover"]
        self.data_dir = get_data_dir() / "WeatherBench" / data_name

        assert split in ["train", "val", "test"]
        self.split = split
        self.include_context = include_context

        self.fname = str(self.data_dir / f"{split}.ffr")
        self.reader = FileReader(self.fname, check_data)

        self.scaler = StandardScaler()
        self.scaler.load(str(self.data_dir / "scaler.pkl"))

        with open(str(self.data_dir / "kernel.pkl"), "rb") as f:
            self.kernel_info = pickle.load(f)

    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        seqs_bytes = self.reader.read(indices)
        samples = []
        for bytes_ in seqs_bytes:
            x, y, context = pickle.loads(bytes_)
            if self.include_context:
                x = np.concatenate([x, context], axis=-1)
            samples.append((x, y))
        return samples
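
    # Access sketch (illustrative only): ``__getitem__`` expects a list of
    # indices and returns one ``(x, y)`` pair per index, matching how
    # ffrecord-backed loaders batch their reads. The indices below are
    # hypothetical.
    #
    #     dataset = WeatherBench("2m_temperature", "train", include_context=True)
    #     samples = dataset[[0, 1, 2]]   # list of (seq_x, seq_y) tuples
    #     seq_x, seq_y = samples[0]      # context, if requested, is concatenated
    #                                    # on the last axis of seq_x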
    def get_scaler(self):
        """Get the normalization statistics of the WeatherBench data.

        Returns:
            the data statistics object, holding the mean (``mean``) and the
            standard deviation (``std``) of the target data
        """
        return self.scaler
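
    # Usage sketch (illustrative only): de-normalize model predictions with
    # the dataset's scaler; ``model`` and ``seq_x`` are hypothetical.
    #
    #     scaler = dataset.get_scaler()
    #     pred = model(seq_x)                    # prediction in normalized space
    #     pred = scaler.inverse_transform(pred)  # back to physical units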
    def get_kernel(self):
        """Get the geographic kernel information of the WeatherBench data.

        Returns:
            kernel (tuple): a 4-tuple of kernel information: the sparse matrix
            (``sparse_idx``), the local convolution kernel inputs
            (``MLP_inputs``), the geodesics (``geodesic``), and the
            tangent-plane angle ratios (``angle_ratio``)
        """
        kernel_info = self.kernel_info
        return (kernel_info["sparse_idx"], kernel_info["MLP_inputs"],
                kernel_info["geodesic"], kernel_info["angle_ratio"])
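
# Usage sketch (illustrative only): unpack the geographic kernel for a
# spherical local-convolution model; the downstream use is hypothetical.
#
#     sparse_idx, mlp_inputs, geodesic, angle_ratio = dataset.get_kernel()
#     # sparse_idx pairs neighboring grid points; geodesic and angle_ratio
#     # describe their spherical geometry, which a model can combine with
#     # mlp_inputs to parameterize a learned local kernel.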