MetNet-3 - Pytorch

Implementation of MetNet-3, the SOTA neural weather model out of Google DeepMind, in PyTorch

The model architecture is pretty unremarkable. It is basically a U-net with a specific, well-performing vision transformer. The most interesting thing about the paper may end up being the loss scaling in section 4.3.2.
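
To give a feel for what loss scaling of this kind can look like, here is a minimal sketch. It assumes the goal is to rescale, in the backward pass only, the gradient flowing into each output channel so that every channel contributes a gradient of comparable magnitude. This is an illustration under that assumption, not the paper's exact formulation and not this repository's code; the names `ScaleGradPerChannel`, `scale_grad_per_channel` and the `eps` parameter are hypothetical.

```python
import torch
from torch.autograd import Function

class ScaleGradPerChannel(Function):
    """Identity in the forward pass, per-channel gradient rescaling in the backward pass (hypothetical helper)."""

    @staticmethod
    def forward(ctx, x, eps):
        ctx.eps = eps
        # the forward pass leaves the values untouched
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad):
        # L2 norm of the incoming gradient over the spatial dims, one value per (batch, channel)
        norm = torch.linalg.vector_norm(grad, ord = 2, dim = (-2, -1), keepdim = True)
        # divide by the norm so each channel's gradient ends up with unit norm
        return grad / norm.clamp(min = ctx.eps), None

def scale_grad_per_channel(x, eps = 1e-5):
    # x is assumed to be shaped (batch, channels, height, width)
    return ScaleGradPerChannel.apply(x, eps)
```

In a sketch like this, the wrapper would be applied to a prediction tensor right before the loss is computed, so that targets with very different loss magnitudes do not dominate the shared trunk's gradients.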

Appreciation

Install

$ pip install metnet3-pytorch

Usage

import torch
from metnet3_pytorch import MetNet3

metnet3 = MetNet3(
    dim = 512,
    num_lead_times = 722,
    lead_time_embed_dim = 32,
    input_spatial_size = 624,
    attn_dim_head = 8,
    hrrr_channels = 617,
    input_2496_channels = 2 + 14 + 1 + 2 + 20,
    input_4996_channels = 16 + 1,
    precipitation_target_bins = dict(
        mrms_rate = 512,
        mrms_accumulation = 512,
    ),
    surface_target_bins = dict(
        omo_temperature = 256,
        omo_dew_point = 256,
        omo_wind_speed = 256,
        omo_wind_component_x = 256,
        omo_wind_component_y = 256,
        omo_wind_direction = 180
    ),
    hrrr_loss_weight = 10,
    hrrr_norm_strategy = 'sync_batchnorm',  # this would use a sync batchnorm to normalize the input hrrr and target, without having to precalculate the mean and variance of the hrrr dataset per channel
    hrrr_norm_statistics = None             # you can also set `hrrr_norm_strategy = "precalculated"` and pass in the mean and variance as shape `(2, 617)` through this keyword argument
)

# inputs

lead_times = torch.randint(0, 722, (2,))
hrrr_input_2496 = torch.randn((2, 617, 624, 624))
hrrr_stale_state = torch.randn((2, 1, 624, 624))
input_2496 = torch.randn((2, 39, 624, 624))
input_4996 = torch.randn((2, 17, 624, 624))

# targets

precipitation_targets = dict(
    mrms_rate = torch.randint(0, 512, (2, 512, 512)),
    mrms_accumulation = torch.randint(0, 512, (2, 512, 512)),
)

surface_targets = dict(
    omo_temperature = torch.randint(0, 256, (2, 128, 128)),
    omo_dew_point = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_speed = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_x = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_component_y = torch.randint(0, 256, (2, 128, 128)),
    omo_wind_direction = torch.randint(0, 180, (2, 128, 128))
)

hrrr_target = torch.randn(2, 617, 128, 128)

total_loss, loss_breakdown = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
    precipitation_targets = precipitation_targets,
    surface_targets = surface_targets,
    hrrr_target = hrrr_target
)

total_loss.backward()

# after much training from above, you can predict as follows

metnet3.eval()

surface_preds, hrrr_pred, precipitation_preds = metnet3(
    lead_times = lead_times,
    hrrr_input_2496 = hrrr_input_2496,
    hrrr_stale_state = hrrr_stale_state,
    input_2496 = input_2496,
    input_4996 = input_4996,
)


# Dict[str, Tensor], Tensor, Dict[str, Tensor]
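
The precipitation and surface targets in the example above are integer bin indices. The sketch below shows one way a continuous field could be turned into such indices and how per-bin logits could be mapped back to a value. Everything specific here is an assumption for illustration: the helper names `to_bins` and `from_logits`, the uniform bin spacing, the temperature range of -40 to 50, and the logits layout `(batch, num_bins, height, width)` are not taken from this repository.

```python
import torch

def to_bins(values, low, high, num_bins):
    # interior bin edges; values at or below `low` land in bin 0,
    # values above `high` land in bin num_bins - 1
    edges = torch.linspace(low, high, num_bins + 1)[1:-1]
    return torch.bucketize(values, edges)

def from_logits(logits, low, high, num_bins):
    # assumed logits layout: (batch, num_bins, height, width)
    bin_ids = logits.argmax(dim = 1)
    bin_width = (high - low) / num_bins
    # report the center of the most likely bin
    return low + (bin_ids.float() + 0.5) * bin_width

# hypothetical continuous temperature field
temperature = torch.empty(2, 128, 128).uniform_(-40., 50.)

# integer targets in [0, 256), the same dtype and shape as the random targets above
omo_temperature = to_bins(temperature, low = -40., high = 50., num_bins = 256)
```

Taking the argmax is the simplest decoding choice; an expectation over the softmax distribution would be another option if a smoother estimate is wanted.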

Todo

  • figure out all the cross entropy and MSE losses

  • auto-handle normalization across all the channels of the HRRR by tracking a running mean and variance of targets during training (using sync batchnorm as hack)

  • allow researcher to pass in their own normalization variables for HRRR

  • build all the inputs to spec, also make sure hrrr input is normalized, offer option to unnormalize hrrr predictions

  • make sure model can be easily saved and loaded, with different ways of handling hrrr norm

  • figure out the topological embedding, consult a neural weather researcher
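
For the normalization items above, the `precalculated` strategy expects the per-channel mean and variance as a tensor of shape `(2, 617)`. A minimal sketch of how such statistics could be computed by streaming over a dataset follows. The helper name `hrrr_channel_statistics` is hypothetical, the random `batches` list stands in for a real HRRR data loader, and the row order (mean first, variance second) is an assumption read off the comment in the usage example, so it is worth checking against the library before relying on it.

```python
import torch

def hrrr_channel_statistics(batches, num_channels = 617):
    # accumulate per-channel sums in float64 to limit rounding error
    count = 0
    total = torch.zeros(num_channels, dtype = torch.float64)
    total_sq = torch.zeros(num_channels, dtype = torch.float64)

    for hrrr in batches:  # each batch: (batch, channels, height, width)
        hrrr = hrrr.double()
        count += hrrr.numel() // num_channels
        total += hrrr.sum(dim = (0, 2, 3))
        total_sq += (hrrr ** 2).sum(dim = (0, 2, 3))

    mean = total / count
    variance = total_sq / count - mean ** 2  # population variance, E[x^2] - E[x]^2
    return torch.stack((mean, variance)).float()  # shape (2, num_channels)

# hypothetical stand-in for a real HRRR data loader
batches = [torch.randn(2, 617, 32, 32) for _ in range(3)]
hrrr_norm_statistics = hrrr_channel_statistics(batches)
```

The resulting tensor could then be passed as `hrrr_norm_statistics = hrrr_norm_statistics` together with `hrrr_norm_strategy = 'precalculated'` when constructing `MetNet3`.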

Citations

@article{Andrychowicz2023DeepLF,
    title   = {Deep Learning for Day Forecasts from Sparse Observations},
    author  = {Marcin Andrychowicz and Lasse Espeholt and Di Li and Samier Merchant and Alexander Merose and Fred Zyda and Shreya Agrawal and Nal Kalchbrenner},
    journal = {ArXiv},
    year    = {2023},
    volume  = {abs/2306.06079},
    url     = {https://api.semanticscholar.org/CorpusID:259129311}
}
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}