ML models in observables

Mechanistic models can be misspecified, or the mapping from model states to measurements may be only partially known. Both scenarios can be addressed by augmenting the observable formula in PEtabObservable with a neural network.

This tutorial shows how to include an ML model in the observable formula. It assumes familiarity with the SciML starter tutorial. As a running example, the Michaelis-Menten model from the mechanistic starting tutorial is used:

julia
using Catalyst
t = Catalyst.default_t()
sys = @reaction_network begin
    @parameters S0 c3=3.0
    @species begin
        S(t) = S0
        E(t) = 50.0
        SE(t) = 0.1
        P(t) = 0.1
    end
    c1, S + E --> SE
    c2, SE --> S + E
    c3, SE --> P + E
end

Defining ML models in observable formulas

An ML model can be embedded in the observable formula of a PEtabObservable by (1) defining a Lux.jl model and (2) wrapping it as an MLModel, where its inputs and an output variable are declared for use in observable formulas. For example, assume the ML model takes the states S and E as input:

julia
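using Lux, PEtab
# Small feed-forward network: 2 inputs (S, E) -> 5 hidden units -> 1 output
lux_model = Lux.Chain(
    Dense(2 => 5, Lux.swish),
    Dense(5 => 1),
)

@variables S(t) E(t)
# `false`: evaluate the network when observables are computed, not pre-simulation
ml_model = MLModel(
    :net1, lux_model, false; inputs = [S, E], outputs = [:output1]
)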
MLModel net1
  mode: simulation
  parameters: 21
  inputs: [S(t), E(t)]
  outputs: [output1]
  hint: see model structure in `ml_model.lux_model`

Here, false indicates that the ML model is not evaluated pre-simulation; instead it is evaluated when observables are computed. The output variable output1 can then be used in observable formulas:

julia
@variables P(t) output1(t)
@parameters sigma
observables = [
    PEtabObservable(:obs_p, P, 3.0),
    PEtabObservable(:obs_sum, output1, sigma),
]
2-element Vector{PEtabObservable}:
 PEtabObservable obs_p: data ~ Normal(μ=P(t), σ=3.0)
 PEtabObservable obs_sum: data ~ Normal(μ=output1(t), σ=sigma)

When an ML model appears in an observable, the observable formula should be defined in PEtabObservable (rather than in the model system). This allows PEtab.jl to compute gradients more efficiently. Also, while the ML inputs are states here, they can also be general expressions of model quantities.

Given the PEtabObservables, the rest of the PEtabODEProblem is created as usual:

julia
using DataFrames
pest = [
    PEtabParameter(:c1),
    PEtabParameter(:c2),
    PEtabParameter(:S0),
    PEtabParameter(:sigma),
    PEtabMLParameter(:net1), # ML parameters
]

measurements = DataFrame(
    obs_id = ["obs_p", "obs_sum", "obs_p", "obs_sum"],
    time = [1.0, 10.0, 1.0, 20.0],
    measurement = [0.7, 0.1, 1.0, 1.5],
)

petab_model = PEtabModel(
    sys, observables, measurements, pest; ml_models = ml_model
)
petab_prob = PEtabODEProblem(petab_model)
PEtabODEProblem ReactionSystemModel: 25 parameters to estimate
(for more statistics, call `describe(petab_prob)`)
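
The total of 25 estimated parameters reported above is the four mechanistic and noise parameters (c1, c2, S0, sigma) plus the 21 weights and biases of net1. As a sanity check, the following minimal sketch rebuilds the same network with Lux alone and counts its parameters. It does not involve PEtab.jl; the names `rng`, `ps`, `st`, `n_ml`, `n_total`, and `x` are illustrative, and the input values are placeholders standing in for S(t) and E(t) at a single measurement time.

```julia
using Lux, Random

# Same architecture as net1 in the tutorial
lux_model = Lux.Chain(
    Dense(2 => 5, Lux.swish),
    Dense(5 => 1),
)

# Illustrative setup: initialise parameters and states with a fixed seed
rng = Random.Xoshiro(1)
ps, st = Lux.setup(rng, lux_model)

# Weights and biases: (2*5 + 5) + (5*1 + 1)
n_ml = Lux.parameterlength(ps)

# Placeholder values standing in for [S(t), E(t)] at one time point
x = Float32[1.0, 50.0]
y, _ = lux_model(x, ps, st)

# Add the four non-ML parameters: c1, c2, S0, sigma
n_total = n_ml + 4
println("ML parameters: ", n_ml, ", total: ", n_total)
```

The network returns a single value per evaluation, which corresponds to the one declared output variable output1.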

Simulation condition-specific inputs

If additional informative non-time-series data (e.g. images or other covariates) are available per simulation condition, they can be included by giving the ML model multiple inputs: one based on model quantities (e.g. states) and one provided via PEtabCondition. The general approach for multiple input arguments is described in Pre-simulation ML models; this section focuses on the observable case.

Assume that, in the example above, the ML model takes the states [S, E] as well as static, simulation-condition-specific data provided via input2. Here input2 has four entries, so the first layer takes 2 + 4 = 6 inputs. The first step is to define the MLModel:

julia
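lux_model = @compact(
    layer1 = Dense(6 => 5, Lux.swish), # 2 states + 4 entries of input2
    layer2 = Dense(5 => 1),
) do (x1, x2)
    # x1: model-state input [S, E]; x2: condition-specific input2
    x = cat(x1, x2; dims = 1)
    h = layer1(x)
    out = layer2(h)
    @return out
end

# One entry in `inputs` per network argument, in the same order as (x1, x2)
ml_model = MLModel(
    :net1, lux_model, false; inputs = ([S, E], [:input2]), outputs = [:output1]
)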

The variable input2 can then be assigned in PEtabCondition (random data are used here for illustration):

julia
simulation_conditions = [
    PEtabCondition(:cond1, :input2 => rand(4)),
    PEtabCondition(:cond2, :input2 => rand(4)),
]

The PEtabODEProblem can then be built as usual:

julia
measurements = DataFrame(
    simulation_id = ["cond1", "cond1", "cond2", "cond2"],
    # obs_id must match the ids of the PEtabObservables defined above
    obs_id        = ["obs_p", "obs_sum", "obs_p", "obs_sum"],
    time          = [5.0, 10.0, 1.0, 20.0],
    measurement   = [0.7, 0.1, 1.0, 1.5],
)

petab_model = PEtabModel(
    sys, observables, measurements, pest; ml_models = ml_model,
    simulation_conditions = simulation_conditions,
)
petab_prob = PEtabODEProblem(petab_model)
PEtabODEProblem ReactionSystemModel: 45 parameters to estimate
(for more statistics, call `describe(petab_prob)`)
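
The increase from 25 to 45 estimated parameters follows from the wider first layer: two states plus four condition-specific values give six inputs, so net1 now has 41 parameters. The minimal sketch below, using Lux alone with illustrative names and placeholder inputs (PEtab.jl is not involved), rebuilds the two-input network, counts its parameters, and evaluates it on a tuple of inputs, which is how the `do (x1, x2)` block receives its arguments when the layer is called directly.

```julia
using Lux, Random

# Same two-input architecture as in the tutorial
lux_model = @compact(
    layer1 = Dense(6 => 5, Lux.swish),
    layer2 = Dense(5 => 1),
) do (x1, x2)
    x = cat(x1, x2; dims = 1)  # length-2 and length-4 inputs -> length 6
    h = layer1(x)
    out = layer2(h)
    @return out
end

# Illustrative setup with a fixed seed
rng = Random.Xoshiro(1)
ps, st = Lux.setup(rng, lux_model)

# Weights and biases: (6*5 + 5) + (5*1 + 1)
n_ml = Lux.parameterlength(ps)

# Placeholders: x1 stands in for [S(t), E(t)], x2 for one condition's input2
x1 = Float32[1.0, 50.0]
x2 = rand(rng, Float32, 4)
y, _ = lux_model((x1, x2), ps, st)

# Add the four non-ML parameters: c1, c2, S0, sigma
n_total = n_ml + 4
println("ML parameters: ", n_ml, ", total: ", n_total)
```

Only the first layer depends on the length of input2, so changing the size of the condition-specific data changes the parameter count through `layer1` alone.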