Shortcuts

Source code for hfai.datasets.era5

import pickle
from ffrecord import FileReader
import torch

from .base import BaseDataset, register_dataset, get_data_dir

"""
Expected file organization:

    [data_dir]
        scaler.pkl
        train.ffr
            meta.pkl
            PART_00000.ffr
            PART_00001.ffr
            ...
        val.ffr
            meta.pkl
            PART_00000.ffr
            PART_00001.ffr
            ...
"""


class StandardScaler:
    """Hold ``mean``/``std`` statistics and undo a standardization.

    A freshly constructed instance uses ``mean = 0.0`` and ``std = 1.0``, so
    :meth:`inverse_transform` is the identity until :meth:`load` is called.
    """

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

    def load(self, scaler_dir):
        """Load ``mean`` and ``std`` from a pickle file.

        Args:
            scaler_dir: Path handed to ``open``; despite the name it must
                point at the pickle *file* itself. The unpickled object is
                indexed with ``"mean"`` and ``"std"``.
        """
        # NOTE(security): ``pickle.load`` can execute arbitrary code. Only
        # point this at scaler files from a trusted source.
        with open(scaler_dir, "rb") as f:
            pkl = pickle.load(f)

        # Read both entries before assigning so a missing key does not leave
        # the scaler half-updated.
        mean, std = pkl["mean"], pkl["std"]
        self.mean = mean
        self.std = std

    def inverse_transform(self, data):
        """Map standardized values back to the original scale.

        Computes ``data * std + mean``.

        Args:
            data: A ``torch.Tensor`` or anything that supports arithmetic with
                the stored statistics (e.g. a NumPy array or a float).

        Returns:
            The rescaled data. For tensor input the statistics are first
            converted to a tensor with the dtype and device of ``data``.
        """
        if torch.is_tensor(data):
            # ``as_tensor`` accepts NumPy arrays as well as the plain float
            # defaults set in ``__init__``; ``torch.from_numpy`` would raise
            # ``TypeError`` on the latter.
            mean = torch.as_tensor(self.mean, dtype=data.dtype, device=data.device)
            std = torch.as_tensor(self.std, dtype=data.dtype, device=data.device)
        else:
            mean, std = self.mean, self.std
        return (data * std) + mean


@register_dataset
class ERA5(BaseDataset):
    """
    A global forecast dataset of basic meteorological variables at 0.25° resolution.

    The dataset was built and open-sourced by the European Centre for
    Medium-Range Weather Forecasts (ECMWF) and covers global weather data every
    6 hours from 1979 to 2021. More information:
    https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5

    The data contains 20 meteorological variables and 1 precipitation variable,
    in this order: ``u10``, ``v10``, ``t2m``, ``z@1000``, ``z@50``, ``z@500``,
    ``z@850``, ``msl``, ``r@500``, ``r@850``, ``sp``, ``t@500``, ``t@850``,
    ``tcwv``, ``u@1000``, ``u@500``, ``u@850``, ``v@1000``, ``v@500``,
    ``v@850`` and the precipitation variable ``tp``.

    Args:
        split (str): Dataset split, either ``train`` or ``val``.
        mode (str): Training mode, one of ``pretrain``, ``finetune`` or
            ``precipitation``.
        check_data (bool): Whether to verify the checksum of every sample
            (default ``True``).
        miniset (bool): Whether to use the mini subset (default ``False``).

    Returns:
        Each stored record holds ``xt``, ``xt1``, ``xt2`` (meteorological data
        at time t, t+1 and t+2) and ``pt1`` (precipitation at time t+1). What a
        sample contains depends on ``mode``:

        * ``pretrain``: ``(xt, xt1)``
        * ``finetune``: ``(xt, xt1, xt2)``
        * ``precipitation``: ``(xt, pt1)``

        The first entry along the leading axis of every returned array is
        dropped (``[1:]``).

    Examples:

    .. code-block:: python

        from hfai.datasets import ERA5

        dataset = ERA5(split, mode='pretrain', check_data=True)
        loader = dataset.loader(batch_size=3, num_workers=4)

        for xt, xt1 in loader:
            # pretrain model

    .. code-block:: python

        dataset = ERA5(split, mode='finetune', check_data=True)
        loader = dataset.loader(batch_size=1, num_workers=4)

        for xt, xt1, xt2 in loader:
            # finetune model

    .. code-block:: python

        dataset = ERA5(split, mode='precipitation', check_data=True)
        loader = dataset.loader(batch_size=4, num_workers=4)

        for xt, pt1 in loader:
            # training precipitation model

    """

    def __init__(self, split: str, mode: str = 'pretrain', check_data: bool = True, miniset: bool = False) -> None:
        super(ERA5, self).__init__()

        # Validate both arguments before any path is built.
        assert mode in ["pretrain", "finetune", "precipitation"], f"invalid mode: {mode!r}"
        assert split in ["train", "val"], f"invalid split: {split!r}"

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "ERA5"

        self.split = split
        self.mode = mode
        self.fname = str(self.data_dir / f"{split}.ffr")
        self.reader = FileReader(self.fname, check_data)

        self.scaler = StandardScaler()
        self.scaler.load(str(self.data_dir / "scaler.pkl"))

    def __len__(self):
        """Return the sample count reported by the underlying reader."""
        return self.reader.n

    def __getitem__(self, indices):
        """Read and decode the records selected by ``indices``.

        Args:
            indices: Passed unchanged to ``FileReader.read``.

        Returns:
            list: One tuple per record read, laid out according to
            ``self.mode`` (see the class docstring).

        Raises:
            KeyError: If ``self.mode`` is not a supported mode.
        """
        seqs_bytes = self.reader.read(indices)
        samples = []
        for bytes_ in seqs_bytes:
            # NOTE(security): pickle.loads can execute arbitrary code; the
            # dataset files must come from a trusted source.
            xt, xt1, xt2, pt1 = pickle.loads(bytes_)
            # NOTE(review): every array drops its first leading-axis entry;
            # what that entry holds is not visible from here - confirm.
            if self.mode == 'pretrain':
                samples.append((xt[1:], xt1[1:]))
            elif self.mode == 'finetune':
                samples.append((xt[1:], xt1[1:], xt2[1:]))
            elif self.mode == 'precipitation':
                samples.append((xt[1:], pt1[1:]))
            else:
                raise KeyError(f'mode ({self.mode}) is invalid.')
        return samples

    def get_scaler(self):
        """Return the statistics object of this dataset.

        Returns:
            StandardScaler: The scaler loaded from ``scaler.pkl``. It exposes
            the mean (``mean``) and the standard deviation (``std``) of the
            variables; the original documentation states they follow the
            variable order listed in the class description.
        """
        return self.scaler