Source code for hfai.datasets.era5
import pickle
from ffrecord import FileReader
import torch
from .base import BaseDataset, register_dataset, get_data_dir
"""
Expected file organization:
[data_dir]
scaler.pkl
train.ffr
meta.pkl
PART_00000.ffr
PART_00001.ffr
...
val.ffr
meta.pkl
PART_00000.ffr
PART_00001.ffr
...
"""
class StandardScaler:

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

    def load(self, scaler_dir):
        with open(scaler_dir, "rb") as f:
            pkl = pickle.load(f)

        self.mean = pkl["mean"]
        self.std = pkl["std"]

    def inverse_transform(self, data):
        mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
        std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
        return (data * std) + mean
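
# A minimal standalone usage sketch for ``StandardScaler`` (the helper name, the pickle
# path argument and ``pred`` are illustrative assumptions, not part of the hfai API):
def _denormalize_example(scaler_path: str, pred: torch.Tensor) -> torch.Tensor:
    # ``scaler_path`` points at a pickle holding {"mean": ..., "std": ...}; for a torch
    # tensor the stored NumPy statistics are cast to ``pred``'s dtype and device before
    # the normalization is undone as ``pred * std + mean``.
    scaler = StandardScaler()
    scaler.load(scaler_path)
    return scaler.inverse_transform(pred)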

@register_dataset
class ERA5(BaseDataset):
"""
这是一个0.25°高分辨率的全球基础气象指标预报数据集
该数据集由欧洲中期天气预报中心(ECMWF)所构建并开源,从1979到2021年每6小时的全球气象数据。更多信息参考:https://www.ecmwf.int/en/forecasts/datasets/reanalysis-datasets/era5
数据包含20个气象指标和1个降水指标,按顺序依次是:``u10``, ``v10``, ``t2m``, ``z@1000``, ``z@50``, ``z@500``, ``z@850``, ``msl``, ``r@500``, ``r@850``, ``sp``,
``t@500``, ``t@850``, ``tcwv``, ``u@1000``, ``u@500``, ``u@850``, ``v@1000``, ``v@500``, ``v@850`` 和降水指标 ``tp``
Args:
split (str): 数据集划分形式,包括:训练集(``train``)、验证集(``val``)
mode (str): 具体的训练模式,包括:``pretrain``、``finetune``、``precipitation``
check_data (bool): 是否对每一条样本检验校验和(默认为 ``True``)
miniset (bool): 是否使用 mini 集合(默认为 ``False``)
Returns:
xt, xt1, xt2, pt1 (np.ndarray, np.ndarray, np.ndarray, np.ndarray): 返回的每个样本是一个四元组,包括t时刻、t+1时刻和t+2时刻的气象指标数据,t+1时刻的降水数据
Examples:
.. code-block:: python
from hfai.datasets import ERA5
dataset = ERA5(split, mode='pretrain', check_data=True)
loader = dataset.loader(batch_size=3, num_workers=4)
for xt, xt1 in loader:
# pretrain model
.. code-block:: python
dataset = ERA5(split, mode='finetune', check_data=True)
loader = dataset.loader(batch_size=1, num_workers=4)
for xt, xt1, xt2 in loader:
# finetune model
.. code-block:: python
dataset = ERA5(split, mode='precipitation', check_data=True)
loader = dataset.loader(batch_size=4, num_workers=4)
for xt, pt1 in loader:
# training precipitation model
"""
    def __init__(self, split: str, mode: str = 'pretrain', check_data: bool = True, miniset: bool = False) -> None:
        super(ERA5, self).__init__()
        assert mode in ["pretrain", "finetune", "precipitation"]

        data_dir = get_data_dir()
        if miniset:
            data_dir = data_dir / "mini"
        self.data_dir = data_dir / "ERA5"

        assert split in ["train", "val"]
        self.split = split
        self.mode = mode

        self.fname = str(self.data_dir / f"{split}.ffr")
        self.reader = FileReader(self.fname, check_data)

        self.scaler = StandardScaler()
        self.scaler.load(str(self.data_dir / "scaler.pkl"))
    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        seqs_bytes = self.reader.read(indices)
        samples = []
        for i, bytes_ in enumerate(seqs_bytes):
            xt, xt1, xt2, pt1 = pickle.loads(bytes_)
            if self.mode == 'pretrain':
                samples.append((xt[1:], xt1[1:]))
            elif self.mode == 'finetune':
                samples.append((xt[1:], xt1[1:], xt2[1:]))
            elif self.mode == 'precipitation':
                samples.append((xt[1:], pt1[1:]))
            else:
                raise KeyError(f'mode ({self.mode}) is invalid.')

        return samples
    def get_scaler(self):
        """Get the statistics of the WeatherBench data.

        Returns:
            A data-distribution statistics object (variables ordered as in the main class description), containing the
            mean (``mean``) and standard deviation (``std``) of the variable data.
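
        Examples:

            A brief sketch, assuming ``dataset`` is an ``ERA5`` instance and ``pred`` is a normalized
            model output with the same variable layout:

            .. code-block:: python

                scaler = dataset.get_scaler()
                origin = scaler.inverse_transform(pred)  # back to the original physical scale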
"""
return self.scaler
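
# A hedged end-to-end sketch; it runs only where the ERA5 dataset is available under the
# hfai data directory, and the split, mode and batch size below are illustrative choices.
if __name__ == "__main__":
    dataset = ERA5(split="val", mode="pretrain", check_data=True)
    print(f"{len(dataset)} samples in the val split")

    # __getitem__ expects a *list* of indices (ffrecord reads them as one batch) and
    # returns a list of samples; in pretrain mode each sample is an (xt, xt1) pair.
    (xt, xt1), = dataset[[0]]
    print(xt.shape, xt1.shape)

    # Iterate with the batched loader, as in the class docstring examples.
    loader = dataset.loader(batch_size=2, num_workers=2)
    for xt, xt1 in loader:
        break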