Scalable Exact QEP Posterior Sampling using Contour Integral Quadrature

This notebook demonstrates the simplest use of contour integral quadrature (CIQ) with msMINRES, as described in the CIQ paper, to sample from the predictive distribution of an exact QEP.

Note that to scale to problems where Cholesky would run the GPU out of memory, you’ll need to have KeOps installed (see our KeOps tutorial in this same folder). Even so, on this relatively simple 1D example with 1000 training points and 5000 test points (1000 if KeOps is unavailable), CIQ is noticeably faster than Cholesky: in the run below, its wall time is about 2.7× lower (4.56 s vs. 12.5 s).
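Sampling from an exact posterior needs a matrix root R of the predictive covariance K (any R with R @ R.T == K) applied to a random vector. Cholesky supplies the lower-triangular root, while CIQ approximates products with the symmetric square root K^(1/2). The NumPy sketch below covers only the Gaussian case, with a small RBF covariance and a jitter value we picked for illustration; it shows that both roots reproduce the covariance, so either yields valid samples. We assume, as the covar_root_decomposition setting used later suggests, that qpytorch's QEP sampler relies on such a root in the same way.

```python
import numpy as np

def cholesky_root(K, jitter=1e-6):
    # Lower-triangular L with L @ L.T == K + jitter * I
    return np.linalg.cholesky(K + jitter * np.eye(K.shape[0]))

def symmetric_root(K, jitter=1e-6):
    # Symmetric S with S @ S == K + jitter * I; this is the root CIQ approximates
    evals, evecs = np.linalg.eigh(K + jitter * np.eye(K.shape[0]))
    return (evecs * np.sqrt(evals)) @ evecs.T

# Small RBF covariance (lengthscale 0.2) on 50 points in [0, 1]
x_demo = np.linspace(0, 1, 50)
K_demo = np.exp(-0.5 * ((x_demo[:, None] - x_demo[None, :]) / 0.2) ** 2)

L = cholesky_root(K_demo)
S = symmetric_root(K_demo)

rng = np.random.default_rng(0)
z = rng.standard_normal(50)
sample_chol = L @ z  # zero-mean Gaussian draw with covariance K_demo + jitter * I
sample_sqrt = S @ z  # same distribution, but a different draw for the same z
```

The two roots differ, so the same z gives different samples, but both sets of samples share the covariance K_demo + jitter * I.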

[1]:
import math
import torch
import qpytorch
from matplotlib import pyplot as plt

import warnings
warnings.simplefilter("ignore", qpytorch.utils.warnings.NumericalWarning)

%matplotlib inline
%load_ext autoreload
%autoreload 2
[2]:
# Training data is 1000 points in [0,1] inclusive regularly spaced
train_x = torch.linspace(0, 1, 1000)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

Are we running with KeOps?

If you have KeOps installed, set the flag below to True to run with a significantly larger test set; otherwise set it to False.

[3]:
HAVE_KEOPS = True
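Instead of hard-coding the flag, a small sketch like the one below could set it by checking whether the KeOps Python package (assumed here to be importable as pykeops) is present. keops_available is our own helper, and finding the package does not guarantee that KeOps compiles or runs on your GPU.

```python
import importlib.util

def keops_available() -> bool:
    # find_spec returns None when no top-level "pykeops" package can be found;
    # it locates the package without importing (or compiling) KeOps itself.
    return importlib.util.find_spec("pykeops") is not None

HAVE_KEOPS = keops_available()
print("KeOps available:", HAVE_KEOPS)
```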

Define an Exact QEP Model and train

[4]:
POWER = 1.0
class ExactQEPModel(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactQEPModel, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()

        if HAVE_KEOPS:
            self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.keops.RBFKernel())
        else:
            self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = ExactQEPModel(train_x, train_y, likelihood)
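The covariance module above combines ScaleKernel and RBFKernel. Assuming qpytorch follows GPyTorch's conventions, this evaluates outputscale * exp(-(x - x')^2 / (2 * lengthscale^2)). The NumPy sketch below builds that dense matrix for 1000 inputs on [0, 1]; scaled_rbf is our own illustrative function, the lengthscale 0.206 is taken from the final iteration of the training log, and the outputscale of 1.0 is a placeholder, since the log does not report it.

```python
import numpy as np

def scaled_rbf(x1, x2, lengthscale, outputscale):
    # outputscale * exp(-(x1 - x2)^2 / (2 * lengthscale^2)) for 1D inputs
    sq_dist = (x1[:, None] - x2[None, :]) ** 2
    return outputscale * np.exp(-0.5 * sq_dist / lengthscale ** 2)

x_np = np.linspace(0, 1, 1000)
# lengthscale from the last training iteration in the log; outputscale is illustrative
K_rbf = scaled_rbf(x_np, x_np, lengthscale=0.206, outputscale=1.0)
print(K_rbf.shape)         # (1000, 1000)
print(K_rbf.nbytes / 1e6)  # 8.0 (MB, float64)
```

Dense storage grows as n^2: a 5000 × 5000 float32 matrix already takes 100 MB, which is why KeOps, which computes kernel products without storing the full matrix, is needed for the larger runs.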
[5]:
if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()
[6]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes QExponentialLikelihood parameters

# "Loss" for QEPs - the marginal log likelihood
mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()
Iter 1/50 - Loss: 2.124   lengthscale: 0.693   noise: 0.693
Iter 2/50 - Loss: 2.066   lengthscale: 0.644   noise: 0.644
Iter 3/50 - Loss: 1.998   lengthscale: 0.598   noise: 0.598
Iter 4/50 - Loss: 1.924   lengthscale: 0.554   noise: 0.554
Iter 5/50 - Loss: 1.846   lengthscale: 0.513   noise: 0.513
Iter 6/50 - Loss: 1.781   lengthscale: 0.474   noise: 0.473
Iter 7/50 - Loss: 1.724   lengthscale: 0.437   noise: 0.437
Iter 8/50 - Loss: 1.679   lengthscale: 0.404   noise: 0.402
Iter 9/50 - Loss: 1.642   lengthscale: 0.374   noise: 0.370
Iter 10/50 - Loss: 1.615   lengthscale: 0.348   noise: 0.340
Iter 11/50 - Loss: 1.587   lengthscale: 0.325   noise: 0.312
Iter 12/50 - Loss: 1.563   lengthscale: 0.305   noise: 0.287
Iter 13/50 - Loss: 1.542   lengthscale: 0.288   noise: 0.263
Iter 14/50 - Loss: 1.518   lengthscale: 0.274   noise: 0.241
Iter 15/50 - Loss: 1.499   lengthscale: 0.261   noise: 0.221
Iter 16/50 - Loss: 1.470   lengthscale: 0.250   noise: 0.202
Iter 17/50 - Loss: 1.453   lengthscale: 0.240   noise: 0.185
Iter 18/50 - Loss: 1.429   lengthscale: 0.232   noise: 0.169
Iter 19/50 - Loss: 1.408   lengthscale: 0.224   noise: 0.155
Iter 20/50 - Loss: 1.381   lengthscale: 0.218   noise: 0.141
Iter 21/50 - Loss: 1.368   lengthscale: 0.212   noise: 0.129
Iter 22/50 - Loss: 1.343   lengthscale: 0.207   noise: 0.117
Iter 23/50 - Loss: 1.328   lengthscale: 0.202   noise: 0.107
Iter 24/50 - Loss: 1.305   lengthscale: 0.199   noise: 0.098
Iter 25/50 - Loss: 1.278   lengthscale: 0.195   noise: 0.089
Iter 26/50 - Loss: 1.252   lengthscale: 0.193   noise: 0.081
Iter 27/50 - Loss: 1.231   lengthscale: 0.190   noise: 0.073
Iter 28/50 - Loss: 1.213   lengthscale: 0.188   noise: 0.067
Iter 29/50 - Loss: 1.189   lengthscale: 0.187   noise: 0.061
Iter 30/50 - Loss: 1.162   lengthscale: 0.186   noise: 0.055
Iter 31/50 - Loss: 1.137   lengthscale: 0.185   noise: 0.050
Iter 32/50 - Loss: 1.112   lengthscale: 0.184   noise: 0.046
Iter 33/50 - Loss: 1.102   lengthscale: 0.184   noise: 0.041
Iter 34/50 - Loss: 1.068   lengthscale: 0.184   noise: 0.037
Iter 35/50 - Loss: 1.047   lengthscale: 0.184   noise: 0.034
Iter 36/50 - Loss: 1.022   lengthscale: 0.184   noise: 0.031
Iter 37/50 - Loss: 1.005   lengthscale: 0.184   noise: 0.028
Iter 38/50 - Loss: 0.984   lengthscale: 0.185   noise: 0.025
Iter 39/50 - Loss: 0.957   lengthscale: 0.186   noise: 0.023
Iter 40/50 - Loss: 0.932   lengthscale: 0.187   noise: 0.021
Iter 41/50 - Loss: 0.910   lengthscale: 0.188   noise: 0.019
Iter 42/50 - Loss: 0.881   lengthscale: 0.189   noise: 0.017
Iter 43/50 - Loss: 0.866   lengthscale: 0.190   noise: 0.016
Iter 44/50 - Loss: 0.840   lengthscale: 0.192   noise: 0.014
Iter 45/50 - Loss: 0.818   lengthscale: 0.194   noise: 0.013
Iter 46/50 - Loss: 0.788   lengthscale: 0.196   noise: 0.012
Iter 47/50 - Loss: 0.771   lengthscale: 0.198   noise: 0.011
Iter 48/50 - Loss: 0.753   lengthscale: 0.201   noise: 0.010
Iter 49/50 - Loss: 0.727   lengthscale: 0.203   noise: 0.009
Iter 50/50 - Loss: 0.707   lengthscale: 0.206   noise: 0.008

Define test set

If we have KeOps installed, we’ll test on 5000 points instead of 1000.

[7]:
if HAVE_KEOPS:
    test_n = 5000
else:
    test_n = 1000

test_x = torch.linspace(0, 1, test_n)
if torch.cuda.is_available():
    test_x = test_x.cuda()
print(test_x.shape)
torch.Size([5000])

Draw a sample with CIQ

To do this, we simply wrap the rsample call in the ciq_samples setting. We also demonstrate the other settings that control contour integral quadrature:

  • The ciq_samples setting determines whether to use CIQ.

  • The num_contour_quadrature setting controls the number of quadrature sites (Q in the paper).

  • The minres_tolerance setting controls the error we tolerate from MINRES (here 1e-4, i.e. <0.01%).

Of these settings, increasing num_contour_quadrature is unlikely to improve sample accuracy: as Theorem 1 of the paper shows, virtually all of the error in this method is controlled by minres_tolerance. Here we use a fairly tight MINRES tolerance.
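As we understand the CIQ paper, the method writes K^(-1/2) b as an integral over shifted solves (t^2 I + K)^(-1) b, discretizes it with Q quadrature nodes (num_contour_quadrature), solves all Q shifted systems together with msMINRES (to minres_tolerance), and recovers K^(1/2) b = K (K^(-1/2) b) with one more matrix-vector product. The NumPy sketch below keeps that structure but swaps in simpler parts: a plain trapezoidal rule after substituting t = exp(s), in place of the paper's elliptic-function quadrature, and dense np.linalg.solve calls in place of msMINRES. The function names, step size h, and window margin are our own choices.

```python
import numpy as np

def inv_sqrt_times_vector(K, b, h=0.25, margin=16.0):
    # K^{-1/2} b = (2/pi) * integral_0^inf (t^2 I + K)^{-1} b dt.
    # With t = exp(s) the integrand decays exponentially as s -> +/- inf,
    # so a plain trapezoidal rule in s converges quickly.
    evals = np.linalg.eigvalsh(K)  # only used to choose the integration window
    s_lo = 0.5 * np.log(evals.min()) - margin
    s_hi = 0.5 * np.log(evals.max()) + margin
    s = np.arange(s_lo, s_hi + h, h)
    t = np.exp(s)                  # quadrature nodes (shift t_q^2 per node)
    w = (2.0 / np.pi) * h * t      # weights include the Jacobian dt/ds = t
    eye = np.eye(K.shape[0])
    out = np.zeros_like(b, dtype=float)
    for t_q, w_q in zip(t, w):
        # One shifted solve per node; msMINRES would solve all shifts together
        out += w_q * np.linalg.solve(K + t_q ** 2 * eye, b)
    return out

def sqrt_times_vector(K, b):
    # K^{1/2} b = K (K^{-1/2} b): one extra matrix-vector product
    return K @ inv_sqrt_times_vector(K, b)

# Demo: RBF covariance plus noise, so eigenvalues are bounded away from zero
x_demo = np.linspace(0, 1, 50)
K_demo = np.exp(-0.5 * ((x_demo[:, None] - x_demo[None, :]) / 0.2) ** 2) + 0.01 * np.eye(50)
b = np.random.default_rng(0).standard_normal(50)

approx = sqrt_times_vector(K_demo, b)
evals, evecs = np.linalg.eigh(K_demo)
exact = evecs @ (np.sqrt(evals) * (evecs.T @ b))
print("relative error:", np.linalg.norm(approx - exact) / np.linalg.norm(exact))
```

This simple rule needs more than a hundred nodes, whereas the notebook runs CIQ with Q = 10. Because the solves here are exact, the only remaining error comes from the quadrature; in the real method the shifted solves are approximate, which is why their tolerance matters.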

[8]:
import time

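# Toggling back to train() first clears any cached prediction state
# (mirroring GPyTorch's ExactGP behavior) before we switch to eval mode
model.train()
likelihood.train()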

# Get into evaluation (predictive posterior) mode
model.eval()
likelihood.eval()

# Test points are regularly spaced along [0,1]
# Make predictions by feeding model through likelihood

test_x.requires_grad_(True)

with torch.no_grad():
    observed_pred = likelihood(model(test_x))

    # All relevant settings for using CIQ:
    #   ciq_samples(True) -- use CIQ for sampling
    #   num_contour_quadrature(10) -- use 10 quadrature sites (Q in the paper)
    #   minres_tolerance(1e-4) -- error tolerance for MINRES (here, <0.01%)
    print("Running with CIQ")
    with qpytorch.settings.ciq_samples(True), qpytorch.settings.num_contour_quadrature(10), qpytorch.settings.minres_tolerance(1e-4):
        %time y_samples = observed_pred.rsample()

    print("Running with Cholesky")
    # Disable fast root decompositions so that sampling falls back to Cholesky
    with qpytorch.settings.fast_computations(covar_root_decomposition=False):
        %time y_samples = observed_pred.rsample()
Running with CIQ
CPU times: user 26.3 s, sys: 498 ms, total: 26.8 s
Wall time: 4.56 s
Running with Cholesky
CPU times: user 1min 10s, sys: 1.7 s, total: 1min 11s
Wall time: 12.5 s
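%time is an IPython magic, so it only works inside a notebook. In a plain script, a small helper like the one below (timed is our own name) measures wall time with time.perf_counter; the commented lines show how it might wrap the rsample call from the cell above.

```python
import time

def timed(fn, *args, **kwargs):
    # Wall-clock timing for scripts, where the IPython %time magic is unavailable
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed

# Hypothetical usage with the objects from the cell above:
#   with qpytorch.settings.ciq_samples(True):
#       y_samples, seconds = timed(observed_pred.rsample)
#   print(f"CIQ wall time: {seconds:.2f} s")

value, seconds = timed(sum, range(1_000_000))
print(f"sum took {seconds:.4f} s")
```

On a GPU, CUDA kernels run asynchronously, so calling torch.cuda.synchronize() before each clock reading gives a more faithful wall time.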