Example: Weather prediction using Neural ODEs

If we wish to train a neural network to predict the weather, or any other time series dataset, we can use a Neural ODE. A Neural ODE replaces the right-hand side (rhs) of an ODE with a neural network:

$$ \frac{dy}{dt} = f(y, p) $$

Here, \(y(t)\) is the state of the system at time \(t\), and \(f\) is a neural network with parameters \(p\). The neural network is trained to predict the derivative of the state. The ODE solver is used both to integrate the state forward in time and to calculate gradients of the loss function with respect to the parameters of the neural network.
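To make the idea concrete, here is a minimal sketch in plain Python. It is not the code used in this example: the one-hidden-layer network, the helper names `init_params` and `integrate`, and the forward Euler scheme are all assumptions chosen for brevity, standing in for the real network and for the DiffSol solver.

```python
import math
import random


def init_params(dim, hidden, seed=0):
    """Small random weights for a one-hidden-layer network (hypothetical layout)."""
    rng = random.Random(seed)
    w1 = [[rng.uniform(-0.5, 0.5) for _ in range(dim)] for _ in range(hidden)]
    w2 = [[rng.uniform(-0.5, 0.5) for _ in range(hidden)] for _ in range(dim)]
    return w1, w2


def f(y, p):
    """Right-hand side f(y, p): a tiny network mapping the state to its derivative."""
    w1, w2 = p
    h = [math.tanh(sum(w * yi for w, yi in zip(row, y))) for row in w1]
    return [sum(w * hi for w, hi in zip(row, h)) for row in w2]


def integrate(y0, p, t_end, n_steps):
    """Forward Euler: y_{k+1} = y_k + dt * f(y_k, p). Returns the whole trajectory."""
    dt = t_end / n_steps
    ys = [list(y0)]
    for _ in range(n_steps):
        dy = f(ys[-1], p)
        ys.append([yi + dt * di for yi, di in zip(ys[-1], dy)])
    return ys


p = init_params(dim=2, hidden=4)
trajectory = integrate([1.0, 0.0], p, t_end=1.0, n_steps=100)
print(len(trajectory), trajectory[-1])
```

Training then amounts to adjusting `p` so that the integrated trajectory matches the observed data at the measurement times.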

In this example, we will reproduce the weather prediction example from the excellent blog post by Sebastian Callh, but using DiffSol as the solver. We'll skip over some of the details; you can read more about the problem setup in the original blog post, and see the full code in the DiffSol repository.

First we'll need a neural network model, and we'll use Equinox and JAX for this. We'll define a simple neural network with three linear layers:

import equinox as eqx
import jax


class NeuralNetwork(eqx.Module):
    layers: list

    def __init__(self, data_dim, key):
        key1, key2, key3 = jax.random.split(key, 3)
        self.layers = [
            eqx.nn.Linear(data_dim, 64, key=key1),
            eqx.nn.Linear(64, 32, key=key2),
            eqx.nn.Linear(32, data_dim, key=key3),
        ]

    def __call__(self, x):
        x = jax.nn.silu(self.layers[0](x))  # Swish = SiLU
        x = jax.nn.silu(self.layers[1](x))
        x = self.layers[2](x)
        return x

We will then create four JAX functions that will allow us to calculate:

  • the rhs function \(f(y, p)\) of the Neural ODE, where \(y\) is the state of the system and \(p\) are the parameters.
  • the Jacobian-vector product of the rhs function with respect to the state \(y\).
  • the negative vector-Jacobian product of the rhs function with respect to the state \(y\).
  • the negative vector-Jacobian product of the rhs function with respect to the parameters \(p\).

We will need all four of these to define the ODE problem and to solve it using DiffSol.
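As an illustration of what these four functions compute, the following sketch writes them out by hand for a toy rhs, \(f(y, p) = \tanh(W y)\), where the parameters are the matrix \(W\). The toy rhs and the hand-derived formulas are assumptions made for this sketch only; the function names mirror the ones used in this example, where JAX computes the products automatically.

```python
import math


def rhs(p, y):
    """Toy rhs f(y, p) = tanh(W y), with the parameters p holding the matrix W."""
    return [math.tanh(sum(w * yj for w, yj in zip(row, y))) for row in p]


def rhs_jac_mul(p, y, v):
    """Jacobian-vector product (df/dy) v."""
    h = rhs(p, y)
    return [(1 - hi**2) * sum(w * vj for w, vj in zip(row, v)) for hi, row in zip(h, p)]


def rhs_jac_transpose_mul(p, y, u):
    """Negative vector-Jacobian product -(df/dy)^T u."""
    h = rhs(p, y)
    s = [(1 - hi**2) * ui for hi, ui in zip(h, u)]
    return [-sum(p[i][j] * s[i] for i in range(len(p))) for j in range(len(y))]


def rhs_sens_transpose_mul(p, y, u):
    """Negative vector-Jacobian product -(df/dp)^T u, with the same shape as p."""
    h = rhs(p, y)
    return [[-(1 - hi**2) * ui * yj for yj in y] for hi, ui in zip(h, u)]


def dot(a, b):
    return sum(x * z for x, z in zip(a, b))


p = [[0.3, -0.2], [0.1, 0.4]]
y, v, u = [0.5, -1.0], [1.0, 2.0], [0.7, -0.3]

# Adjoint identity: u . (J v) == (J^T u) . v, so with the negated
# transpose product the two sides sum to zero (up to rounding).
lhs = dot(u, rhs_jac_mul(p, y, v))
rhs_side = dot(rhs_jac_transpose_mul(p, y, u), v)
print(abs(lhs + rhs_side) < 1e-12)
```

The same dot-product identity is a cheap sanity check for any pair of Jacobian-vector and vector-Jacobian products.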

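import functools as ft

import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# data_dim (the number of state variables) is assumed to be defined already
key = jax.random.PRNGKey(0)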
model = NeuralNetwork(data_dim=data_dim, key=key)
y = jnp.zeros((data_dim,))
v = jnp.zeros((data_dim,))
params, static = eqx.partition(model, eqx.is_array)
p, unravel_params = ravel_pytree(params)


def rhs(p, y):
    params = unravel_params(p)
    model = eqx.combine(params, static)
    return model(y)


def rhs_jac_mul(p, y, v):
    return jax.jvp(ft.partial(rhs, p), (y,), (v,))[1]


def rhs_jac_transpose_mul(p, y, v):
    return -jax.vjp(ft.partial(rhs, p), y)[1](v)[0]


def rhs_sens_transpose_mul(p, y, v):
    return -jax.vjp(ft.partial(rhs, y=y), p)[1](v)[0]

Next, we can export all four of these JAX functions to ONNX, which will allow us to use them from Rust.

import tensorflow as tf
import tf2onnx
from jax.experimental import jax2tf


def to_onnx(model, inputs, filename):
    sig = [tf.TensorSpec(inpt[0].shape, inpt[0].dtype, name=inpt[1]) for inpt in inputs]
    inference_tf = jax2tf.convert(model, enable_xla=False)
    inference_tf = tf.function(inference_tf, autograph=False)
    inference_onnx = tf2onnx.convert.from_function(inference_tf, input_signature=sig)
    model_proto, _external_tensor_storage = inference_onnx
    with open(filename, "wb") as f:
        f.write(model_proto.SerializeToString())
    return model_proto

Now, within Rust, we can define a DiffSol system of equations by creating a struct NeuralOde. We'll use the ort crate and the ONNX Runtime to load the ONNX models that we made in Python.

struct NeuralOde {
    rhs: Session,
    rhs_jac_mul: Session,
    rhs_jac_transpose_mul: Session,
    rhs_sens_transpose_mul: Session,
    input_y: RefCell<Array1<f32>>,
    input_v: RefCell<Array1<f32>>,
    input_p: Array1<f32>,
    y0: V,
}

impl NeuralOde {
    fn new_session(filename: &str) -> Result<Session> {
        let full_filename = format!("{}{}", BASE_MODEL_DIR, filename);
        let session = Session::builder()?
            .with_optimization_level(GraphOptimizationLevel::Level3)?
            .with_intra_threads(4)?
            .commit_from_file(full_filename.as_str())?;
        Ok(session)
    }
    fn new(y0: V) -> Result<Self> {
        let rhs = Self::new_session("rhs.onnx")?;
        let rhs_jac_mul = Self::new_session("rhs_jac_mul.onnx")?;
        let rhs_jac_transpose_mul = Self::new_session("rhs_jac_transpose_mul.onnx")?;
        let rhs_sens_transpose_mul = Self::new_session("rhs_sens_transpose_mul.onnx")?;
        let mut nparams = 0;
        for input in rhs.inputs.iter() {
            if input.name == "p" {
                nparams = input.input_type.tensor_dimensions().unwrap()[0] as usize;
                break;
            }
        }
        let mut rng = rand::rng();
        let elem = Uniform::<f32>::new(0.0, 1.0).unwrap();
        let params = Array1::from_shape_fn((nparams,), |_| elem.sample(&mut rng));
        let y0_ndarray = Array1::from_shape_fn((y0.len(),), |i| y0[i] as f32);

        Ok(Self {
            y0,
            rhs,
            rhs_jac_mul,
            rhs_jac_transpose_mul,
            rhs_sens_transpose_mul,
            input_p: params,
            input_v: RefCell::new(y0_ndarray.clone()),
            input_y: RefCell::new(y0_ndarray),
        })
    }

    fn data_dim(&self) -> usize {
        self.y0.len()
    }
}

We'll also implement the OdeSystemAdjoint trait for NeuralOde, which will allow us to use the adjoint method to calculate gradients of our loss function with respect to the parameters of the neural network. As an example, here is the implementation of the NonLinearOp trait for the rhs function:

impl NonLinearOp for Rhs<'_> {
    fn call_inplace(&self, x: &Self::V, _t: Self::T, y: &mut Self::V) {
        let mut y_input = self.0.input_y.borrow_mut();
        y_input
            .iter_mut()
            .zip(x.iter())
            .for_each(|(y, x)| *y = *x as f32);
        let outputs = self
            .0
            .rhs
            .run(
                inputs![
                    "p" => self.0.input_p.view(),
                    "y" => y_input.view(),
                ]
                .unwrap(),
            )
            .unwrap();
        let y_data = outputs["Identity_1:0"].try_extract_tensor::<f32>().unwrap();
        y.iter_mut()
            .zip(y_data.as_slice().unwrap())
            .for_each(|(y, x)| *y = *x as f64);
    }
}

We'll also need an optimiser, so we'll write an AdamW algorithm using the definition in the PyTorch documentation as a guide:

impl AdamW {
    fn new(nparams: usize) -> Self {
        let lr = 1e-2;
        let betas = (0.9, 0.999);
        let eps = 1e-8;
        let m = V::zeros(nparams);
        let m_hat = V::zeros(nparams);
        let v = V::zeros(nparams);
        let v_hat = V::zeros(nparams);
        let lambda = 1e-2;
        Self {
            lr,
            betas,
            eps,
            m,
            m_hat,
            v,
            v_hat,
            lambda,
            t: 0,
        }
    }

    fn step(&mut self, params: &mut V, grads: &V) {
        self.t += 1;
        params.mul_assign(1.0 - self.lr * self.lambda);
        self.m.axpy(1.0 - self.betas.0, grads, self.betas.0);
        self.v.axpy(
            1.0 - self.betas.1,
            &grads.component_mul(grads),
            self.betas.1,
        );
        self.m_hat = &self.m / (1.0 - self.betas.0.powi(self.t));
        self.v_hat = &self.v / (1.0 - self.betas.1.powi(self.t));
        params
            .iter_mut()
            .zip(self.v_hat.iter())
            .zip(self.m_hat.iter())
            .for_each(|((params_i, v_hat_i), m_hat_i)| {
                *params_i -= self.lr * m_hat_i / (v_hat_i.sqrt() + self.eps)
            });
    }
}
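For readers more comfortable with Python, the same update can be sketched as follows. This is a plain-Python transcription of the step above using the same default hyperparameters, applied to a toy quadratic objective; the objective and the `target` values are assumptions made purely for illustration.

```python
import math


class AdamW:
    def __init__(self, nparams, lr=1e-2, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2):
        self.lr, self.betas, self.eps, self.weight_decay = lr, betas, eps, weight_decay
        self.m = [0.0] * nparams  # first-moment estimate
        self.v = [0.0] * nparams  # second-moment estimate
        self.t = 0                # step counter, used for bias correction

    def step(self, params, grads):
        self.t += 1
        b1, b2 = self.betas
        for i, g in enumerate(grads):
            params[i] *= 1.0 - self.lr * self.weight_decay  # decoupled weight decay
            self.m[i] = b1 * self.m[i] + (1.0 - b1) * g
            self.v[i] = b2 * self.v[i] + (1.0 - b2) * g * g
            m_hat = self.m[i] / (1.0 - b1**self.t)
            v_hat = self.v[i] / (1.0 - b2**self.t)
            params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)


# Toy objective: sum((x - target)^2), whose gradient is 2 * (x - target).
params = [2.0, -3.0]
target = [1.0, 1.0]
opt = AdamW(len(params))
for _ in range(2000):
    grads = [2.0 * (x - c) for x, c in zip(params, target)]
    opt.step(params, grads)
print(params)
```

Because the weight decay is applied directly to the parameters rather than added to the gradient, the minimiser found is pulled very slightly towards zero relative to `target`.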

We'll then define our loss function, which will return the sum of squared errors between the solution and the data points, along with the gradients of the loss function with respect to the parameters. Since the size of the parameter vector is quite large (>2000), we'll use the adjoint method to calculate the gradients.

fn loss_fn(
    problem: &mut OdeSolverProblem<NeuralOde>,
    p: &V,
    ts_data: &[T],
    ys_data: &M,
    g_m: &mut M,
) -> Result<(T, V), DiffsolError> {
    problem.eqn.set_params(p);
    let (c, ys) = problem
        .bdf::<LS>()?
        .solve_dense_with_checkpointing(ts_data, None)?;
    let mut loss = 0.0;
    for j in 0..g_m.ncols() {
        let delta = ys.column(j) - ys_data.column(j);
        loss += delta.dot(&delta);
        let g_m_i = 2.0 * delta;
        g_m.column_mut(j).copy_from(&g_m_i);
    }
    let adjoint_solver = problem.bdf_solver_adjoint::<LS, _>(c, Some(1)).unwrap();
    let soln = adjoint_solver.solve_adjoint_backwards_pass(ts_data, &[g_m])?;
    Ok((loss, soln.into_common().sg.pop().unwrap()))
}
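The loss and the per-time-point gradients stored in g_m can be sketched in plain Python as below. The helper name `sum_squared_error` and the sample numbers are made up for illustration; only the formula, loss = sum of squared differences with gradient 2 * (y - data) for each solution value, is taken from the code above.

```python
def sum_squared_error(ys, ys_data):
    """Return the loss and its gradient with respect to each solution value.

    ys and ys_data are lists of columns: ys[j][i] is component i at time point j.
    """
    loss = 0.0
    g_m = []
    for y_col, d_col in zip(ys, ys_data):
        delta = [a - b for a, b in zip(y_col, d_col)]
        loss += sum(d * d for d in delta)
        g_m.append([2.0 * d for d in delta])  # d(loss)/d(y) = 2 * (y - data)
    return loss, g_m


ys = [[1.0, 2.0], [0.5, 1.5], [0.0, 1.0]]
ys_data = [[1.0, 1.0], [1.0, 1.5], [0.5, 0.0]]
loss, g_m = sum_squared_error(ys, ys_data)
print(loss, g_m)  # prints 2.5 [[0.0, 2.0], [-1.0, 0.0], [-1.0, 2.0]]
```

These gradients with respect to the solution are what the backwards adjoint pass turns into gradients with respect to the parameters.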

Finally, we can train the neural network to predict the weather. Following the example given in the blog post linked above, we'll train in stages, increasing the number of data points by four each time. In each stage we'll train for 150 steps using the AdamW optimiser.

fn train_one_round(
    problem: &mut OdeSolverProblem<NeuralOde>,
    ts_data: &[T],
    ys_data: &M,
    p: &mut V,
) {
    let mut gm = M::zeros(problem.eqn.nout(), ts_data.len());
    let mut adam = AdamW::new(problem.eqn.nparams());
    for _ in 0..150 {
        match loss_fn(problem, p, ts_data, ys_data, &mut gm) {
            Ok((loss, g)) => {
                println!("loss: {}", loss);
                adam.step(p, &g)
            }
            Err(e) => {
                panic!("{}", e);
            }
        };
    }
}
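The staged schedule that drives these rounds can be sketched as follows. The helper names `staged_windows` and `train_in_stages`, and the starting size of four data points, are assumptions for this sketch; the increment of four comes from the description above, and the actual driver loop is in the full code in the DiffSol repository.

```python
def staged_windows(n_total, n_start, n_increment):
    """Yield the number of leading data points to use in each training round."""
    n = n_start
    while n <= n_total:
        yield n
        n += n_increment


def train_in_stages(ts_data, ys_data, train_one_round, n_start=4, n_increment=4):
    """Call train_one_round on a growing prefix of the data (hypothetical driver)."""
    for n in staged_windows(len(ts_data), n_start, n_increment):
        train_one_round(ts_data[:n], ys_data[:n])


print(list(staged_windows(20, 4, 4)))  # prints [4, 8, 12, 16, 20]
```

Starting with a short window and growing it lets the parameters found in one round serve as the starting point for the next, longer fit.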

To give an indication of the results, we'll plot the model after it has been trained on the first 20 data points, and use it to predict the solution over the entire dataset.

This seems to work well: the model continues to match the data points a long way into the future. This has been a whirlwind description of both Neural ODEs and this particular analysis. For a more detailed explanation, please refer to the original blog post by Sebastian Callh. We've also skipped over the less interesting parts of the code; you can see the full code for this example in the DiffSol repository.