Notes on Probabilistic Programming for Stochastic Differential Equations

Some notes on Bayesian inference for stochastic differential equations in Julia, specifically inference for θ and σ of the Ornstein–Uhlenbeck process. The explanation is quite terse since, in the end, I was unable to get this to work on larger problems.
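For reference, the drift and diffusion functions `f!` and `g!` defined in the code below correspond to the following SDE, reading `p[1]` as θ and `p[2]` as σ. The symbol X_t for the state is notation introduced here and is not used in the code.

```latex
\mathrm{d}X_t = -\theta X_t \,\mathrm{d}t + \sigma \,\mathrm{d}W_t, \qquad X_0 = 10
```

The data are generated with θ = 1 and σ = 3, matching `p = [1.0, 3.0]`.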

Generating the true data:

using DifferentialEquations
using DiffEqNoiseProcess
using Turing
using Distributions
using LinearAlgebra
using Plots, StatsPlots
using Random; Random.seed!(85455)

function f!(du, u, p, t)
    du[1] = -p[1]*u[1]
end
function g!(du, u, p, t)
    du[1] = p[2]
end
u0 = [10.0]
t0,tend = 0.0,1.0
Δt = 0.1
tspan = (t0,tend)
tmeasure = t0:Δt:tend
p = [1.0, 3.0]
W = WienerProcess(0.0, zeros(length(u0)))
prob = SDEProblem(f!,g!,u0,tspan,p,saveat=tmeasure,noise=W)

sol = solve(prob)
data = sol[1,:] # Turing needs the observations as a plain vector
plot(sol)

We see that this solve is indeed random:

plot(solve(prob))
plot()
for _ in 1:1000
    sol = solve(prob)
    plot!(sol)
end
plot!(legend=false)

The SDE solvers discretize the process noise (W) in time. By default, however, they are adaptive and therefore evaluate the process noise at different time points in each solve. Turing can only perform inference if the number of random variables and their meaning are the same in every iteration of the MCMC algorithm.

Luckily, DifferentialEquations.jl allows you to provide an already discretized version of the process noise to the solver, using NoiseGrid. The price you pay is that the process noise is interpolated linearly between the grid points, which is less accurate than if you let the solver choose its own discretization.
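To make the linear interpolation concrete, here is a minimal sketch in plain Julia of how a value between two grid points could be obtained. The helper `interp_noise` is hypothetical and written only for illustration; it is not the NoiseGrid implementation, and it assumes a sorted time grid with the query time inside it.

```julia
# Hypothetical helper: linearly interpolate a gridded noise path W (values at times ts).
# Assumes ts is sorted and ts[1] <= t <= ts[end].
function interp_noise(ts::AbstractVector, W::AbstractVector, t::Real)
    k = searchsortedlast(ts, t)          # index of the last grid point <= t
    k == length(ts) && return W[end]     # t coincides with the final grid point
    w = (t - ts[k]) / (ts[k+1] - ts[k])  # fractional position inside the interval
    return (1 - w) * W[k] + w * W[k+1]
end

ts = [0.0, 0.1, 0.2]
W_grid = [0.0, 0.3, -0.1]
mid_value  = interp_noise(ts, W_grid, 0.05)  # halfway between the first two grid values
last_value = interp_noise(ts, W_grid, 0.2)   # exactly the last grid value
```

Between grid points the interpolated path is a straight line, so any fluctuation of the true Brownian motion inside an interval is lost. That is the accuracy cost mentioned above.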

To generate the noise, we use the matrix normal distribution. Different columns in this matrix represent independent realizations of the process noise for different time-intervals.

noise_per_interval = MatrixNormal(zeros(length(u0),length(tmeasure)-1),
                                  diagm(Δt*ones(length(u0))),
                                  diagm(ones(length(tmeasure)-1)))
noise_per_interval = rand(noise_per_interval)
brownian_noise = hcat(zeros(length(u0)), cumsum(noise_per_interval,dims=2))
brownian_noise_aoa = [collect(c) for c in eachcol(brownian_noise)]
# needs to be ArrayOfArray for DifferentialEquations.jl
W = NoiseGrid(vcat(tmeasure,tend+Δt),vcat(brownian_noise_aoa,[rand(length(u0))]))
# For some reason NoiseGrid needs an additional time-point after simulation has already ended.
prob = remake(prob,noise=W)
sol = solve(prob)
plot(sol)
plot()
for _ in 1:1000
    noise_per_interval = MatrixNormal(zeros(length(u0),length(tmeasure)-1),
                                      diagm(Δt*ones(length(u0))),
                                      diagm(ones(length(tmeasure)-1)))
    noise_per_interval = rand(noise_per_interval)
    brownian_noise = hcat(zeros(length(u0)), cumsum(noise_per_interval,dims=2))
    brownian_noise_aoa = [collect(c) for c in eachcol(brownian_noise)]
    W = NoiseGrid(vcat(tmeasure,tend+Δt),vcat(brownian_noise_aoa,[rand(length(u0))]))
    global prob = remake(prob,noise=W)
    sol = solve(prob)
    plot!(sol)
end
plot!(legend=false)
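The key property of the construction above is that, once the noise is fixed, the trajectory is a deterministic function of the parameters and the noise. A minimal sketch of that idea in plain Julia follows. It swaps in a fixed-step Euler–Maruyama scheme, which is not necessarily what `solve(prob)` uses, and the helpers `brownian_path` and `euler_maruyama_ou` are hypothetical names introduced only for this illustration.

```julia
using Random

# Hypothetical helper: Brownian path on a fixed grid from N(0, Δt) increments.
function brownian_path(rng::AbstractRNG, nsteps::Integer, Δt::Real)
    increments = sqrt(Δt) .* randn(rng, nsteps)  # one increment per interval
    return vcat(0.0, cumsum(increments))         # the path starts at zero
end

# Hypothetical helper: Euler–Maruyama for dX = -θ X dt + σ dW on the same grid.
# This is an assumed stand-in for the solver, not the DifferentialEquations.jl algorithm.
function euler_maruyama_ou(θ, σ, x0, W::AbstractVector, Δt::Real)
    x = Vector{Float64}(undef, length(W))
    x[1] = x0
    for k in 1:length(W)-1
        x[k+1] = x[k] - θ * x[k] * Δt + σ * (W[k+1] - W[k])
    end
    return x
end

rng = MersenneTwister(1)
W_path = brownian_path(rng, 10, 0.1)
x_a = euler_maruyama_ou(1.0, 3.0, 10.0, W_path, 0.1)
x_b = euler_maruyama_ou(1.0, 3.0, 10.0, W_path, 0.1)  # same noise, same parameters
same_path = x_a == x_b                                # true: no hidden randomness left
```

This is what makes the noise increments usable as random variables in a probabilistic program: their number is fixed by the grid, and each one always refers to the same time interval.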

Now let us do inference using the Metropolis-Hastings algorithm.

@model function model(data, prob)
    # Prior distributions for parameters of interest
    θ ~ Uniform(0.1,10.0)
    σ ~ Uniform(0.1,10.0)
    p = [θ,σ]
    # Prior distribution for process noise at measurement times
    noise_per_interval ~ MatrixNormal(zeros(length(u0),length(tmeasure)-1),
                                      diagm(Δt*ones(length(u0))),
                                      diagm(ones(length(tmeasure)-1)))
    brownian_noise = hcat(zeros(length(u0)), cumsum(noise_per_interval,dims=2))
    brownian_noise_aoa = [collect(c) for c in eachcol(brownian_noise)]
    W = NoiseGrid(vcat(tmeasure,tend+Δt),vcat(brownian_noise_aoa,[rand(length(u0))]))
    # simulating the system
    prob = remake(prob,p=p,noise=W)
    sol = solve(prob)
    # likelihood
    # The extra observation noise (standard deviation sqrt(0.1)) keeps the chain from getting stuck.
    data ~ MvNormal(sol[1,:], fill(sqrt(0.1), length(data)))
    return nothing
end
chain = sample(model(data, prob), MH(), MCMCThreads(), 10_000, 8, progress=false);
plot(chain[:, [:θ, :σ], :])

We see that θ is recovered quite nicely, but that the chains do not completely agree on σ.
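One way to put a number on how well the chains agree is the Gelman–Rubin statistic, which compares the variance between chains with the variance within them. Below is a minimal sketch of the classic (non-split) version in plain Julia, applied to synthetic draws rather than to the chain above. The helper `gelman_rubin` is a hypothetical name, and the draws are assumed to be arranged as an iterations × chains matrix for a single parameter.

```julia
using Statistics, Random

# Hypothetical helper: classic Gelman–Rubin statistic for one parameter.
# `draws` is an (iterations × chains) matrix.
function gelman_rubin(draws::AbstractMatrix)
    n = size(draws, 1)
    chain_means = vec(mean(draws, dims=1))
    chain_vars  = vec(var(draws, dims=1))
    within  = mean(chain_vars)        # average within-chain variance
    between = n * var(chain_means)    # variance of the chain means, scaled by n
    pooled  = (n - 1) / n * within + between / n
    return sqrt(pooled / within)
end

rng = MersenneTwister(2)
agreeing = randn(rng, 1000, 4)               # four synthetic chains, same target
shifted  = agreeing .+ [0.0 0.0 0.0 3.0]     # one chain stuck elsewhere
rhat_agreeing = gelman_rubin(agreeing)       # close to 1
rhat_shifted  = gelman_rubin(shifted)        # clearly above 1
```

Values near 1 indicate that the chains are exploring the same distribution. Values well above 1, as one would expect for σ here, suggest that they have not mixed.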

I was unable to use more modern MCMC algorithms, such as NUTS, since these require derivatives of the likelihood with respect to both the model parameters, θ and σ, and the discretized process noise. Such derivatives cannot yet be calculated by StochasticDiffEq.jl. Thus, this method will not scale to more complicated problems with more states, parameters, and measurement times.