Latent Function Inference with Pyro + QPyTorch (Low-Level Interface)

Overview

In this example, we give an overview of the low-level Pyro-QPyTorch integration. The low-level interface makes it possible to write QEP models in a Pyro style, i.e. by defining your own model and guide functions.

These are the key differences between the high-level and low-level interface:

High-level interface

  • Base class is qpytorch.models.PyroQEP.

  • QPyTorch automatically defines the model and guide functions for Pyro.

  • Best used when prediction is the primary goal

Low-level interface

  • Base class is qpytorch.models.ApproximateQEP.

  • User defines the model and guide functions for Pyro.

  • Best used when inference is the primary goal

[1]:
import math
import torch
import pyro
import tqdm.notebook  # makes sure the tqdm.notebook submodule used below is loaded
import qpytorch
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

This example uses a QEP to infer a latent function \(\lambda(x)\), which parameterises the exponential distribution:

\[y \sim \text{Exponential} (\lambda),\]

where:

\[\lambda = \exp(f) \in (0,+\infty)\]

is a link function that transforms the latent QEP variable:

\[f \sim QEP \in (-\infty,+\infty).\]

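In other words, given inputs \(X\) and observations \(Y\) drawn from an exponential distribution with \(\lambda = \lambda(X)\), we want to recover \(\lambda(X)\). Throughout the code, \(\lambda\) is the scale (mean) of the distribution, so the rate is \(1/\lambda\).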
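The link function and the scale/rate convention can be checked in isolation. The following is a minimal standard-library sketch, not part of the notebook: the names `link` and `sample_y` and the latent value `f = -0.5` are illustrative assumptions. It relies on the fact that Python's `random.expovariate` takes the rate, so the scale \(\lambda = \exp(f)\) has to be inverted first.

```python
import math
import random


def link(f):
    """Exponential link: maps any real latent value f to a positive scale."""
    return math.exp(f)


def sample_y(f, rng):
    """Draw y from an exponential distribution with scale exp(f)."""
    lam = link(f)
    # random.expovariate expects the rate, i.e. 1 / scale
    return rng.expovariate(1.0 / lam)


rng = random.Random(0)
f = -0.5  # arbitrary latent value, chosen only for illustration
draws = [sample_y(f, rng) for _ in range(20000)]
sample_mean = sum(draws) / len(draws)

# The sample mean should be close to the scale exp(f)
print(link(f), sample_mean)
```

Because the mean of an exponential distribution equals its scale, the two printed numbers should roughly agree.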

[2]:
# Here we specify a 'true' latent function lambda
scale = lambda x: np.sin(2 * math.pi * x) + 1

# Generate synthetic data
NSamp = 100

X = np.linspace(0, 1, NSamp)

fig, (lambdaf, samples) = plt.subplots(1, 2, figsize=(10, 3))

lambdaf.plot(X, scale(X))
lambdaf.set_xlabel('x')
lambdaf.set_ylabel(r'$\lambda$')  # raw string, so '\l' is not read as an escape sequence
lambdaf.set_title('Latent function')

# Draw one sample per input; np.random.exponential takes the scale (1 / rate)
Y = np.random.exponential(scale(X))
samples.scatter(X, Y)
samples.set_xlabel('x')
samples.set_ylabel('y')
samples.set_title('Samples from exp. distrib.')
[2]:
Text(0.5, 1.0, 'Samples from exp. distrib.')
../../_images/examples_07_Pyro_Integration_Pyro_QPyTorch_Low_Level_3_2.png
[3]:
train_x = torch.tensor(X).float()
train_y = torch.tensor(Y).float()

Using the low-level Pyro/QPyTorch interface

The low-level interface should look familiar if you’ve written Pyro models/guides before. We’ll use a qpytorch.models.ApproximateQEP object to model the QEP. To use the low-level interface, this object needs to define 3 functions:

  • forward(x) - which computes the prior QEP mean and covariance at the supplied inputs.

  • guide(x) - which defines the approximate QEP posterior.

  • model(x) - which does the following 3 things

    • Computes the QEP prior at x

    • Converts QEP function samples into scale function samples, using the link function defined above.

    • Samples from the observed distribution p(y | f). (This takes the place of a qpytorch Likelihood that we would’ve used in the high-level interface).
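For reference, the observation density \(p(y \mid f)\) sampled in the last step has a simple closed form. With scale \(\exp(f)\), the rate is \(\exp(-f)\), so \(\log p(y \mid f) = -f - y\,e^{-f}\). The sketch below is a standard-library illustration of that identity; the function names are hypothetical and not part of QPyTorch or Pyro.

```python
import math


def exponential_log_lik(y, f):
    """log p(y | f) for an exponential distribution with scale exp(f)."""
    return -f - y * math.exp(-f)


def direct_log_density(y, f):
    """Same quantity, written via the rate parameterization log(rate) - rate * y."""
    rate = 1.0 / math.exp(f)
    return math.log(rate) - rate * y


y = 0.8
for f in (-1.0, 0.0, 1.5):
    print(f, exponential_log_lik(y, f), direct_log_density(y, f))
```

Setting the derivative \(-1 + y\,e^{-f}\) to zero shows that, for a single observation, the likelihood is maximized at \(f = \log y\).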

[4]:
POWER = 2.0
class PVQEPRegressionModel(qpytorch.models.ApproximateQEP):
    def __init__(self, num_inducing=64, name_prefix="mixture_qep"):
        self.power = torch.tensor(POWER)
        self.name_prefix = name_prefix

        # Define all the variational stuff
        inducing_points = torch.linspace(0, 1, num_inducing)
        variational_strategy = qpytorch.variational.VariationalStrategy(
            self, inducing_points,
            qpytorch.variational.CholeskyVariationalDistribution(num_inducing_points=num_inducing, power=self.power)
        )

        # Standard initialization
        super().__init__(variational_strategy)

        # Mean and covariance modules
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean, covar, power=self.power)

    def guide(self, x, y):
        # Get q(f) - variational (guide) distribution of latent function
        function_dist = self.pyro_guide(x)

        # Use a plate here to mark conditional independencies
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            # Sample from latent function distribution
            pyro.sample(self.name_prefix + ".f(x)", function_dist)

    def model(self, x, y):
        pyro.module(self.name_prefix + ".qep", self)

        # Get p(f) - prior distribution of latent function
        function_dist = self.pyro_model(x)

        # Use a plate here to mark conditional independencies
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            # Sample from latent function distribution
            function_samples = pyro.sample(self.name_prefix + ".f(x)", function_dist)

            # Use the link function to convert QEP samples into scale samples
            scale_samples = function_samples.exp()

            # Sample from observed distribution
            return pyro.sample(
                self.name_prefix + ".y",
                pyro.distributions.Exponential(scale_samples.reciprocal()),  # rate = 1 / scale
                obs=y
            )
[5]:
model = PVQEPRegressionModel()

Performing inference with Pyro

Unlike the other examples in this library, models that use the Pyro integration are trained with Pyro’s inference and optimization classes (rather than the classes provided by PyTorch).

If you are unfamiliar with Pyro’s inference tools, we recommend checking out the Pyro SVI tutorial.
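As a rough orientation, SVI maximizes a Monte Carlo estimate of the ELBO, \(\mathbb{E}_{q(f)}[\log p(y, f) - \log q(f)]\), averaged over a number of particles. The toy sketch below shows that estimator for a single observation. It is not the Pyro implementation: it assumes a scalar latent with a standard normal prior and a normal guide, and all function names are illustrative.

```python
import math
import random


def log_normal_pdf(x, mean, std):
    """Log density of a univariate normal distribution."""
    return -0.5 * math.log(2 * math.pi) - math.log(std) - 0.5 * ((x - mean) / std) ** 2


def log_lik(y, f):
    """log p(y | f) for an exponential distribution with scale exp(f)."""
    return -f - y * math.exp(-f)


def elbo_estimate(y, guide_mean, guide_std, num_particles, rng):
    """Average of log p(y, f) - log q(f) over samples f ~ q(f)."""
    total = 0.0
    for _ in range(num_particles):
        f = rng.gauss(guide_mean, guide_std)                     # f ~ q(f)
        log_joint = log_normal_pdf(f, 0.0, 1.0) + log_lik(y, f)  # log p(f) + log p(y | f)
        log_q = log_normal_pdf(f, guide_mean, guide_std)
        total += log_joint - log_q
    return total / num_particles


rng = random.Random(0)
good = elbo_estimate(1.0, 0.0, 0.7, 256, rng)  # guide centred near the posterior mode
bad = elbo_estimate(1.0, 3.0, 0.7, 256, rng)   # guide centred far from it
print(good, bad)
```

A guide that sits close to the posterior yields a higher ELBO, which is what the optimizer exploits. Using more particles reduces the noise of each estimate at the cost of more computation per step.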

[6]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
num_iter = 2 if smoke_test else 200
num_particles = 1 if smoke_test else 256


def train():
    optimizer = pyro.optim.Adam({"lr": 0.01})
    elbo = pyro.infer.Trace_ELBO(num_particles=num_particles, vectorize_particles=True, retain_graph=True)
    svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)

    model.train()
    iterator = tqdm.notebook.tqdm(range(num_iter))
    for i in iterator:
        model.zero_grad()
        loss = svi.step(train_x, train_y)
        iterator.set_postfix(loss=loss, lengthscale=model.covar_module.base_kernel.lengthscale.item())

%time train()
CPU times: user 2min 33s, sys: 40.6 s, total: 3min 13s
Wall time: 33.8 s

In this example, we are only performing inference over the QEP latent function (and its associated hyperparameters). In later examples, we will see that this basic loop also performs inference over any additional latent variables that we define.

Making predictions

For some problems, we simply want to use Pyro to perform inference over latent variables. However, we can also use the model’s (approximate) predictive posterior distribution. Making predictions with a PyroQEP model is exactly the same as for standard QPyTorch models.

[7]:
# Here's a quick helper function for getting smoothed percentile values from samples
def percentiles_from_samples(samples, percentiles=(0.05, 0.5, 0.95)):
    num_samples = samples.size(0)
    samples = samples.sort(dim=0)[0]

    # Get samples corresponding to percentile
    percentile_samples = [samples[int(num_samples * percentile)] for percentile in percentiles]

    # Smooth the samples with a length-5 moving average
    kernel = torch.full((1, 1, 5), fill_value=0.2)
    smoothed_samples = []
    for percentile_sample in percentile_samples:
        # Replicate the edge values so the average is not pulled toward zero at the ends
        padded = torch.nn.functional.pad(percentile_sample.view(1, 1, -1), (2, 2), mode="replicate")
        smoothed_samples.append(torch.nn.functional.conv1d(padded, kernel).view(-1))

    return smoothed_samples
[8]:
# define test set (optionally on GPU)
denser = 2 # make test set 2 times denser than the training set
test_x = torch.linspace(0, 1, denser * NSamp).float()#.cuda()

model.eval()
with torch.no_grad():
    output = model(test_x)

# Estimate percentiles of exp(f) from samples f_i ~ QEP, i = 1, ..., n.
# The 0.5 percentile (stored in `mean`) is the median; we also get the 5th and 95th percentiles
samples = output(torch.Size([1000])).exp()
lower, mean, upper = percentiles_from_samples(samples)

# Draw some simulated y values
scale_sim = model(train_x)().exp()
y_sim = pyro.distributions.Exponential(scale_sim.reciprocal())()
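Note that the middle value requested from percentiles_from_samples is the 0.5 percentile, i.e. a median. For a skewed quantity such as \(\exp(f)\), this differs from the Monte Carlo mean \(\frac{1}{n}\sum_i \exp(f_i)\). The sketch below illustrates the gap under the simplifying assumption of a scalar Gaussian latent \(f \sim \mathcal{N}(\mu, \sigma^2)\), for which the mean of \(\exp(f)\) is \(\exp(\mu + \sigma^2/2)\) and the median is \(\exp(\mu)\). The values of `mu` and `sigma` are arbitrary and chosen only for illustration.

```python
import math
import random
import statistics

rng = random.Random(0)
mu, sigma = 0.2, 0.8  # illustrative parameters of a Gaussian latent f

# Push Gaussian samples through the exponential link
samples = [math.exp(rng.gauss(mu, sigma)) for _ in range(50000)]

mc_mean = statistics.fmean(samples)     # Monte Carlo estimate of E[exp(f)]
mc_median = statistics.median(samples)  # Monte Carlo estimate of the median of exp(f)

print(mc_mean, math.exp(mu + 0.5 * sigma ** 2))  # estimate vs. closed-form mean
print(mc_median, math.exp(mu))                   # estimate vs. closed-form median
```

If the posterior mean of the scale is what you need, averaging the samples directly is the more appropriate summary.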
[9]:
# visualize the result
fig, (func, samp) = plt.subplots(1, 2, figsize=(12, 3))
line, = func.plot(test_x, mean.detach().cpu().numpy(), label='QEP prediction')
func.fill_between(
    test_x, lower.detach().cpu().numpy(),
    upper.detach().cpu().numpy(), color=line.get_color(), alpha=0.5
)

func.plot(test_x, scale(test_x), label='True latent function')
func.legend()

# sample from p(y|D,x) = \int p(y|f) p(f|D,x) df (doubly stochastic)
samp.scatter(train_x, train_y, alpha = 0.5, label='True train data')
samp.scatter(train_x, y_sim.cpu().detach().numpy(), alpha=0.5, label='Sample from the model')
samp.legend()
[9]:
<matplotlib.legend.Legend at 0x15d39bda0>
../../_images/examples_07_Pyro_Integration_Pyro_QPyTorch_Low_Level_14_1.png

Next steps

This example probably could’ve also been done (slightly more easily) using the high-level Pyro integration, or using QPyTorch’s native SVQEP implementation. The low-level interface really comes in handy when a qpytorch Likelihood is difficult to define. For an example of this, see the next example, which uses the low-level interface to infer the intensity function of a Cox process.