# Source code for hfai.datasets.weatherBench
import pickle
import numpy as np
from ffrecord import FileReader
import torch
from .base import BaseDataset, register_dataset, get_data_dir
"""
Expected file organization:
[data_dir]
kernel.pkl
scaler.pkl
train.ffr
val.ffr
test.ffr
"""
class StandardScaler:
    """Mean/std normalizer whose statistics are loaded from a pickle file.

    Attributes:
        mean: per-feature mean. A Python float (0.0) until ``load`` is called,
            then whatever the pickle stored (typically a ``np.ndarray``).
        std: per-feature standard deviation, same lifecycle as ``mean``.
    """

    def __init__(self):
        self.mean = 0.0
        self.std = 1.0

    def load(self, scaler_dir):
        """Load ``mean`` and ``std`` from the pickle file at ``scaler_dir``.

        The file must contain a dict with keys ``"mean"`` and ``"std"``.
        """
        with open(scaler_dir, "rb") as f:
            pkl = pickle.load(f)
        self.mean = pkl["mean"]
        self.std = pkl["std"]

    def inverse_transform(self, data):
        """Undo normalization: return ``data * std + mean``.

        Accepts either a ``np.ndarray`` or a ``torch.Tensor``. For tensors the
        statistics are converted to tensors with the same dtype and device.
        """
        if torch.is_tensor(data):
            # torch.as_tensor accepts both ndarrays and plain Python floats.
            # The original used torch.from_numpy, which raised a TypeError
            # whenever load() had not been called (mean/std still floats).
            mean = torch.as_tensor(self.mean).type_as(data).to(data.device)
            std = torch.as_tensor(self.std).type_as(data).to(data.device)
        else:
            mean, std = self.mean, self.std
        return (data * std) + mean
@register_dataset
class WeatherBench(BaseDataset):
    """Benchmark dataset for weather forecasting.

    Sampled from the ECMWF ERA5 reference dataset: hourly global weather data
    from 1979 to 2019. See https://github.com/pangeo-data/WeatherBench for
    details.

    Args:
        data_name (str): which variable to forecast; one of
            ``2m_temperature``, ``relative_humidity``, ``component_of_wind``,
            ``total_cloud_cover``
        split (str): dataset split; one of ``train``, ``val``, ``test``
        include_context (bool): whether to concatenate the meteorological
            context features onto the input sequence
        check_data (bool): whether to verify the checksum of every sample
            (default ``True``)

    Returns:
        seq_x, seq_y (np.ndarray, np.ndarray): each sample is a pair of the
        historical input sequence and the future target sequence

    Examples:

    .. code-block:: python

        from hfai.datasets import WeatherBench
        dataset = WeatherBench(data_name, split, include_context, check_data)
        loader = dataset.loader(batch_size=64, num_workers=4)
        for seq_x, seq_y in loader:
            # training model
    """

    def __init__(self, data_name: str, split: str, include_context: bool = False, check_data: bool = True) -> None:
        super(WeatherBench, self).__init__()
        # NOTE(review): asserts are stripped under `python -O`; kept as-is so the
        # exception type seen by callers does not change.
        assert data_name in ["2m_temperature", "relative_humidity", "component_of_wind", "total_cloud_cover"]
        self.data_dir = get_data_dir() / "WeatherBench" / data_name
        assert split in ["train", "val", "test"]
        self.split = split
        self.include_context = include_context
        self.fname = str(self.data_dir / f"{split}.ffr")
        self.reader = FileReader(self.fname, check_data)
        self.scaler = StandardScaler()
        self.scaler.load(str(self.data_dir / "scaler.pkl"))
        with open(str(self.data_dir / "kernel.pkl"), "rb") as f:
            self.kernel_info = pickle.load(f)

    def __len__(self):
        return self.reader.n

    def __getitem__(self, indices):
        seqs_bytes = self.reader.read(indices)
        samples = []
        for bytes_ in seqs_bytes:
            x, y, context = pickle.loads(bytes_)
            if self.include_context:
                # Context features are appended along the last (feature) axis.
                x = np.concatenate([x, context], axis=-1)
            samples.append((x, y))
        return samples

    def get_scaler(self):
        """Return the statistics of the WeatherBench data.

        Returns:
            the data-distribution object holding the mean (``mean``) and
            standard deviation (``std``) of the variable
        """
        return self.scaler

    def get_kernel(self):
        """Return the geographical kernel information of the WeatherBench data.

        Returns:
            kernel (tuple): a 4-tuple of the sparse matrix (``sparse_idx``),
            the local-convolution MLP inputs (``MLP_inputs``), the geodesic
            distances (``geodesic``) and the tangent angle ratios
            (``angle_ratio``)
        """
        kernel_info = self.kernel_info
        return kernel_info["sparse_idx"], kernel_info["MLP_inputs"], kernel_info["geodesic"], kernel_info["angle_ratio"]