Spatial Modelling: Linear Regression vs. Gaussian Processes
In this notebook, we explore the benefits of combining structured mean functions with Gaussian Processes (GPs) for modelling spatial data. We will compare two approaches:

1. A Baseline Linear Model: a standard Bayesian linear regression that assumes the target variable is a linear combination of the inputs.
2. A Joint Linear + GP Model: a semi-parametric model that captures the global linear trend while using a GP to model the non-linear spatial residuals.
Crucially, this example demonstrates the seamless integration between GPJax and NumPyro.
We will show how GPJax defines the GP prior and likelihood, while NumPyro handles the
Hamiltonian Monte Carlo (HMC) inference for both the linear coefficients and the GP
hyperparameters simultaneously.
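In symbols (our notation, matching the code below), writing \(\mathbf{x}_i \in \mathbb{R}^2\) for the spatial inputs, the two models are

\[
\text{Baseline:} \quad y_i = \mathbf{w}^\top \mathbf{x}_i + b + \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2),
\]
\[
\text{Joint:} \quad y_i = \mathbf{w}^\top \mathbf{x}_i + b + f(\mathbf{x}_i) + \epsilon_i, \qquad f \sim \mathcal{GP}\left(0, k(\cdot, \cdot)\right),
\]

where \(k\) is an RBF kernel over both spatial dimensions, so the explicit linear trend plays the role of a structured mean function for the GP.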
1. Setup and Data Simulation
First, we import the necessary libraries. We enable 64-bit precision in JAX to ensure numerical stability during matrix decompositions.
We simulate a 2D spatial dataset (\(N = 200\)) on the domain \([0, 5] \times [0, 5]\). The true generating process consists of:

* A Linear Trend: \(y_{\text{lin}} = 2x_1 - x_2 + 1.5\)
* A Spatial Residual: \(y_{\text{res}} = \sin(x_1) \cos(x_2)\)
* Observation Noise: \(\epsilon \sim \mathcal{N}(0, 0.1^2)\)
This structure effectively masks the non-linear signal within a dominant linear trend, posing a challenge for simple linear models.
import jax
import jax.numpy as jnp
import jax.random as jr
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive
import gpjax as gpx
from gpjax.numpyro_extras import register_parameters
import matplotlib.pyplot as plt
# Enable x64 precision for better stability
jax.config.update("jax_enable_x64", True)
print("Spatial Linear GP Comparison Example")
# --- Step 2: Data Simulation ---
N = 200
key = jr.key(42)
key_x, key_noise = jr.split(key)
# Simulate X in [0, 5] x [0, 5]
X = jr.uniform(key_x, shape=(N, 2), minval=0.0, maxval=5.0)
# True Linear Trend
true_slope = jnp.array([2.0, -1.0])
true_intercept = 1.5
y_lin = X @ true_slope + true_intercept
# Non-linear Spatial Residual
y_res = jnp.sin(X[:, 0]) * jnp.cos(X[:, 1])
# Total Signal + Noise
y_clean = y_lin + y_res
noise_std = 0.1
y = y_clean + noise_std * jr.normal(key_noise, shape=y_clean.shape)
print(f"Generated {N} data points.")
Spatial Linear GP Comparison Example
Generated 200 data points.
2. Baseline Linear Model
We begin by defining a standard Bayesian linear regression model in NumPyro. This model assumes that the data can be fully explained by a hyperplane and Gaussian noise.
We use the No-U-Turn Sampler (NUTS) to estimate the posterior distributions of the slope \(\mathbf{w}\), intercept \(b\), and noise \(\sigma\).
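Concretely, matching the priors used in the code below, the baseline model is

\[
\mathbf{w} \sim \mathcal{N}(\mathbf{0}, 5^2 I), \qquad b \sim \mathcal{N}(0, 5^2), \qquad \sigma \sim \mathrm{LogNormal}(0, 1),
\]
\[
y_i \mid \mathbf{x}_i \sim \mathcal{N}(\mathbf{w}^\top \mathbf{x}_i + b, \sigma^2).
\]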
def linear_model(X, Y=None):
    # Priors
    slope = numpyro.sample("slope", dist.Normal(0.0, 5.0).expand([2]))
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 5.0))
    obs_noise = numpyro.sample("obs_noise", dist.LogNormal(0.0, 1.0))
    # Mean function
    mu = X @ slope + intercept
    numpyro.deterministic("mu", mu)
    # Likelihood
    numpyro.sample("obs", dist.Normal(mu, obs_noise), obs=Y)
# Run MCMC for Linear Model
print("\nRunning MCMC for Baseline Linear Model...")
nuts_kernel_lin = NUTS(linear_model)
mcmc_lin = MCMC(nuts_kernel_lin, num_warmup=500, num_samples=1000, num_chains=1)
mcmc_lin.run(key, X, y)
mcmc_lin.print_summary()
Running MCMC for Baseline Linear Model...
mean std median 5.0% 95.0% n_eff r_hat
intercept 1.43 0.10 1.43 1.27 1.58 643.95 1.00
obs_noise 0.51 0.03 0.51 0.46 0.55 708.85 1.00
slope[0] 2.06 0.02 2.06 2.02 2.10 701.62 1.00
slope[1] -1.05 0.02 -1.04 -1.09 -1.01 732.54 1.00
Number of divergences: 0
3. Joint Linear + GP Model
Now we define the joint model, in which a GP accounts for the residuals that the linear model cannot explain. Notice that the baseline's estimated observation noise (around 0.51) is far larger than the true value of 0.1: the unmodelled spatial signal has been absorbed into the noise term, which is exactly the structure the GP component is meant to capture.
GPJax and NumPyro Integration
This section highlights the interoperability between GPJax and NumPyro.
- GP Definition: We define the GP prior in GPJax using an RBF kernel and a zero mean function (since the linear trend is handled explicitly). We attach dist.LogNormal priors to the kernel's hyperparameters (lengthscale and variance) directly within the GPJax object.
- register_parameters: Inside the NumPyro model, we call gpx.numpyro_extras.register_parameters(gp_posterior). This function traverses the GPJax object, identifies parameters with attached priors, and registers them as NumPyro sample sites. It returns a new GPJax object where the parameters have been replaced by the values sampled by NumPyro.
- Conjugate Marginal Log-Likelihood: We compute the exact marginal log-likelihood (MLL) of the residuals under the GP prior using gpx.objectives.conjugate_mll. This term is added to the potential function using numpyro.factor, guiding the sampler; the formula is written out below.
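For reference, writing \(\mathbf{r} = \mathbf{y} - (X\mathbf{w} + b)\) for the residuals, \(K\) for the kernel Gram matrix at the training inputs, and \(\sigma^2\) for the observation variance, the conjugate MLL added via numpyro.factor is the standard Gaussian marginal likelihood

\[
\log p(\mathbf{r} \mid X, \theta) = -\tfrac{1}{2}\, \mathbf{r}^\top (K + \sigma^2 I)^{-1} \mathbf{r} - \tfrac{1}{2} \log \lvert K + \sigma^2 I \rvert - \tfrac{N}{2} \log 2\pi,
\]

where \(\theta\) collects the kernel hyperparameters. Because this quantity is available in closed form, the latent function values never need to be sampled explicitly.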
# GP Definition
lengthscale = gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0))
variance = gpx.parameters.PositiveReal(1.0, prior=dist.LogNormal(0.0, 1.0))
# active_dims=[0, 1] ensures the kernel operates on both spatial dimensions
kernel = gpx.kernels.RBF(active_dims=[0, 1], lengthscale=lengthscale, variance=variance)
meanf = gpx.mean_functions.Zero()
prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel)
obs_stddev = gpx.parameters.NonNegativeReal(0.1, prior=dist.LogNormal(0.0, 1.0))
likelihood = gpx.likelihoods.Gaussian(num_datapoints=N, obs_stddev=obs_stddev)
gp_posterior = prior * likelihood
def joint_model(X, Y, gp_posterior, X_new=None):
    # 1. Sample Linear Model Parameters
    slope = numpyro.sample("slope", dist.Normal(0.0, 5.0).expand([2]))
    intercept = numpyro.sample("intercept", dist.Normal(0.0, 5.0))
    trend = X @ slope + intercept
    # 2. Register GP Parameters with NumPyro
    # This draws samples for lengthscale, variance, and obs_stddev from their attached priors
    p_posterior = register_parameters(gp_posterior)
    if Y is not None:
        # Calculate residuals for the GP to model
        residuals = Y - trend
        # Reshape residuals to (N, 1) for GPJax Dataset
        residuals = residuals.reshape(-1, 1)
        D_resid = gpx.Dataset(X=X, y=residuals)
        # 3. Compute GP Marginal Log-Likelihood
        mll = gpx.objectives.conjugate_mll(p_posterior, D_resid)
        numpyro.factor("gp_log_lik", mll)
    if X_new is not None:
        # Prediction logic
        if Y is not None:
            residuals = Y - trend
            residuals = residuals.reshape(-1, 1)
            D_resid = gpx.Dataset(X=X, y=residuals)
        # Compute predictive distribution for the GP component
        latent_dist = p_posterior.predict(X_new, train_data=D_resid)
        f_new = numpyro.sample("f_new", latent_dist)
        f_new = f_new.reshape((-1, 1))
        # Combine Linear Trend + GP Residual
        total_prediction = (X_new @ slope + intercept).reshape(-1, 1) + f_new
        numpyro.deterministic("y_pred", total_prediction)
# Run MCMC for Joint Model
print("\nRunning MCMC for Joint Linear + GP Model...")
# Use a closure to pass the static gp_posterior object to the model
def joint_model_wrapper(X, Y, X_new=None):
    joint_model(X, Y, gp_posterior, X_new)
nuts_kernel_joint = NUTS(joint_model_wrapper)
mcmc_joint = MCMC(nuts_kernel_joint, num_warmup=500, num_samples=1000, num_chains=1)
mcmc_joint.run(key, X, y)
mcmc_joint.print_summary()
Running MCMC for Joint Linear + GP Model...
mean std median 5.0% 95.0% n_eff r_hat
intercept 1.52 0.75 1.51 0.41 2.89 602.21 1.00
likelihood.obs_stddev 0.11 0.01 0.11 0.10 0.12 802.34 1.00
prior.kernel.lengthscale 1.73 0.15 1.74 1.45 1.96 474.16 1.00
prior.kernel.variance 0.98 0.59 0.82 0.28 1.70 446.24 1.00
slope[0] 1.97 0.17 1.96 1.68 2.22 579.34 1.00
slope[1] -1.00 0.17 -0.99 -1.28 -0.73 601.57 1.00
Number of divergences: 0
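The sample sites created by register_parameters are named after their attribute paths in the GPJax object, as the summary above shows. If you want to inspect the hyperparameter posteriors directly, they can be pulled out of the MCMC object; the following is a minimal sketch, with site names taken from the printed summary:

# Posterior draws for the GP hyperparameters, keyed by their GPJax attribute paths
gp_hyper_samples = mcmc_joint.get_samples()
print(list(gp_hyper_samples.keys()))
lengthscale_draws = gp_hyper_samples["prior.kernel.lengthscale"]
obs_stddev_draws = gp_hyper_samples["likelihood.obs_stddev"]
print(f"Posterior mean lengthscale: {lengthscale_draws.mean():.2f}")
print(f"Posterior mean obs_stddev: {obs_stddev_draws.mean():.2f}")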
4. Comparison and Visualization
We evaluate both models by comparing their Root Mean Squared Error (RMSE) against the true noise-free signal. We also visualise the predictions over the 2D domain.
We expect the Joint Model to significantly outperform the linear baseline because it can capture the spatial correlations (\(\sin(x_1)\cos(x_2)\)) that the linear model ignores.
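For clarity, the RMSE reported below is computed against the noise-free signal \(y_{\text{clean}}\), with \(\hat{y}_i\) denoting a model's posterior-mean prediction at \(\mathbf{x}_i\):

\[
\mathrm{RMSE} = \sqrt{\frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}_i - y_{\text{clean}, i} \right)^2 }.
\]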
# --- Step 6: Comparison & Visualization ---
# 1. Prediction on Training Data (for RMSE)
# Linear Model
samples_lin = mcmc_lin.get_samples()
predictive_lin = Predictive(linear_model, samples_lin, return_sites=["mu"])
preds_lin = predictive_lin(jr.key(1), X=X)["mu"]
mean_pred_lin = jnp.mean(preds_lin, axis=0)
# Joint Model
samples_joint = mcmc_joint.get_samples()
predictive_joint = Predictive(joint_model_wrapper, samples_joint, return_sites=["y_pred"])
preds_joint = predictive_joint(jr.key(2), X=X, Y=y, X_new=X)["y_pred"]
mean_pred_joint = jnp.mean(preds_joint, axis=0)
# Calculate RMSE
rmse_lin = jnp.sqrt(jnp.mean((mean_pred_lin.flatten() - y_clean.flatten())**2))
rmse_joint = jnp.sqrt(jnp.mean((mean_pred_joint.flatten() - y_clean.flatten())**2))
print(f"\nRMSE Comparison (vs True Signal):")
print(f"Linear Model: {rmse_lin:.4f}")
print(f"Joint Model: {rmse_joint:.4f}")
# 2. Visualization on a Grid
n_grid = 30
x1 = jnp.linspace(0, 5, n_grid)
x2 = jnp.linspace(0, 5, n_grid)
X1, X2 = jnp.meshgrid(x1, x2)
X_grid = jnp.column_stack([X1.ravel(), X2.ravel()])
# True Signal on Grid
y_grid_true = (X_grid @ true_slope + true_intercept) + (jnp.sin(X_grid[:, 0]) * jnp.cos(X_grid[:, 1]))
# Linear Prediction on Grid
preds_lin_grid = predictive_lin(jr.key(3), X=X_grid)["mu"]
mean_pred_lin_grid = jnp.mean(preds_lin_grid, axis=0)
# Joint Prediction on Grid
preds_joint_grid = predictive_joint(jr.key(4), X=X, Y=y, X_new=X_grid)["y_pred"]
mean_pred_joint_grid = jnp.mean(preds_joint_grid, axis=0)
# Plotting
fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)
# Truth
c0 = axes[0].tricontourf(X_grid[:,0], X_grid[:,1], y_grid_true, levels=20, cmap='viridis')
axes[0].set_title("True Signal")
plt.colorbar(c0, ax=axes[0])
# Linear
c1 = axes[1].tricontourf(X_grid[:,0], X_grid[:,1], mean_pred_lin_grid.flatten(), levels=20, cmap='viridis')
axes[1].set_title(f"Linear Model (RMSE: {rmse_lin:.2f})")
plt.colorbar(c1, ax=axes[1])
# Joint
c2 = axes[2].tricontourf(X_grid[:,0], X_grid[:,1], mean_pred_joint_grid.flatten(), levels=20, cmap='viridis')
axes[2].set_title(f"Joint Model (RMSE: {rmse_joint:.2f})")
plt.colorbar(c2, ax=axes[2])
for ax in axes:
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    ax.scatter(X[:,0], X[:,1], c='k', s=10, alpha=0.3, label="Data")
plt.tight_layout()
RMSE Comparison (vs True Signal):
Linear Model: 0.4844
Joint Model: 0.0309
