Converting Exact QEP Models to TorchScript¶
In this notebook, we’ll demonstrate converting an Exact QEP model to TorchScript. In general, this works the same way as for standard PyTorch models, using torch.jit.trace, but there are three peculiarities to keep in mind for QPyTorch:
The first time you make predictions with a QPyTorch model (exact or approximate), we cache certain computations. These computations can’t be traced, but the results of them can be. Therefore, we’ll need to pass data through the untraced model once, and then trace the model.
For exact QEPs, we can’t trace models unless qpytorch.settings.fast_pred_var is used. This is a technical issue that may not be possible to overcome due to limitations on what can be traced in PyTorch; however, if you really need to trace a QEP but can’t use the above setting, open an issue so we know there is demand for this.
You can’t trace models that return Distribution objects. Therefore, we’ll write a simple wrapper that unpacks the MultivariateQExponential that our QEPs return into just a mean and variance tensor.
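The last point can be illustrated without QPyTorch. The following is a minimal sketch using only plain PyTorch: `DistModule` and `MeanVarWrapper` are hypothetical names, and `torch.distributions.Normal` stands in for the MultivariateQExponential that a QEP would return. It shows that tracing a module which returns a Distribution raises an error, while tracing a wrapper that returns tensors succeeds.

```python
import torch


class DistModule(torch.nn.Module):
    """Hypothetical stand-in for a model whose forward returns a Distribution."""

    def forward(self, x):
        # validate_args=False skips data-dependent argument checks during tracing
        return torch.distributions.Normal(x, torch.ones_like(x), validate_args=False)


class MeanVarWrapper(torch.nn.Module):
    """Unpacks the distribution into plain tensors so the output can be traced."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner

    def forward(self, x):
        dist = self.inner(x)
        return dist.mean, dist.variance


x = torch.linspace(0, 1, 5)

# Tracing the raw module fails: a Distribution is not a valid traced output.
try:
    torch.jit.trace(DistModule(), x)
    raw_trace_failed = False
except Exception:
    raw_trace_failed = True

# Tracing the wrapper works because it returns a tuple of tensors.
traced = torch.jit.trace(MeanVarWrapper(DistModule()), x)
mean, var = traced(x)
print(raw_trace_failed, mean.shape, var.shape)
```

The wrapper used later in this notebook follows the same pattern, applied to the QEP model instead of this toy module.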
Define and train an exact QEP¶
In the next cell, we define some data, define a QEP model and train it. Nothing new here – pretty much just move on to the next cell after this one.
[2]:
import math
import torch
import qpytorch
from matplotlib import pyplot as plt

%matplotlib inline
%load_ext autoreload
%autoreload 2

# Training data is 100 points in [0,1] inclusive regularly spaced
train_x = torch.linspace(0, 1, 100)
# True function is sin(2*pi*x) with Gaussian noise
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2

# We will use the simplest form of QEP model, exact inference
POWER = 1.0
class ExactQEPModel(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactQEPModel, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood()
model = ExactQEPModel(train_x, train_y, likelihood)

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iter = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes QExponentialLikelihood parameters

# "Loss" for QEPs - the marginal log likelihood
mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    # Zero gradients from previous iteration
    optimizer.zero_grad()
    # Output from model
    output = model(train_x)
    # Calc loss and backprop gradients
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
Trace the Model¶
In the next cell, we trace our QEP model. To overcome the fact that we can’t trace Modules that return Distributions, we write a wrapper Module that unpacks the QEP output into a mean and variance.
Additionally, we’ll need to run with the qpytorch.settings.trace_mode setting enabled, because PyTorch can’t trace custom autograd Functions. Note that this results in some inefficiencies.
Then, before calling torch.jit.trace we first call the model on test_x. This step is required, as it does some precomputation using torch functionality that cannot be traced.
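To make the "call once, then trace" pattern concrete, here is a toy analogue in plain PyTorch. It is only a sketch under stated assumptions: `CachedWeights` is a hypothetical module, not QPyTorch's actual caching code, and the linear solve stands in for whatever precomputation the library performs. After the warm-up call, the trace records only the use of the cached tensor, not the step that produced it.

```python
import torch


class CachedWeights(torch.nn.Module):
    """Toy analogue of a prediction cache (hypothetical, not QPyTorch code)."""

    def __init__(self, K, y):
        super().__init__()
        self.K = K
        self.y = y
        self.cache = None  # filled lazily on the first forward call

    def forward(self, x):
        if self.cache is None:
            # One-time precomputation, standing in for a cached training-data solve
            self.cache = torch.linalg.solve(self.K, self.y)
        return x @ self.cache


K = torch.eye(3) * 2.0
y = torch.tensor([1.0, 2.0, 3.0])
x = torch.ones(2, 3)

toy = CachedWeights(K, y)
with torch.no_grad():
    eager_out = toy(x)                 # warm-up call populates the cache
    traced = torch.jit.trace(toy, x)   # the trace sees only the cached tensor
    traced_out = traced(x)

# Operations recorded in the trace; the solve is not among them
kinds = [node.kind() for node in traced.graph.nodes()]
print(kinds)
```

In this toy the solve itself would be traceable, so the warm-up only changes what ends up in the graph. In the notebook's setting the precomputation cannot be traced at all, which is why the warm-up call is required there.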
[4]:
class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, qep):
        super().__init__()
        self.qep = qep

    def forward(self, x):
        output_dist = self.qep(x)
        return output_dist.mean, output_dist.variance

with torch.no_grad(), qpytorch.settings.fast_pred_var(), qpytorch.settings.trace_mode():
    model.eval()
    test_x = torch.linspace(0, 1, 51)
    pred = model(test_x)  # Do precomputation
    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:315: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/gpytorch/kernels/kernel.py:511: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not x1_.size(-1) == x2_.size(-1):
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:366: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if res.shape != self.shape:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/linear_operator/operators/_linear_operator.py:1417: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
elif not self.is_square:
/Users/shiweilan/Projects/QPyTorch/qpytorch/distributions/multivariate_qexponential.py:425: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if variance.lt(min_variance).any():
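The TracerWarning messages above all describe the same mechanism: a Python `if` that depends on a tensor value is evaluated once during tracing, and only the branch taken is recorded. The sketch below reproduces this with a hypothetical toy function in plain PyTorch. It does not establish whether the specific checks flagged above matter for your inputs; it only shows why it may be worth validating a traced model on inputs that differ from the example used for tracing.

```python
import warnings

import torch


def scale_by_sign(x):
    # Data-dependent Python branch: tracing records only the branch taken
    if x.sum() > 0:
        return x * 2
    return x * -1


pos = torch.ones(3)
neg = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced_fn = torch.jit.trace(scale_by_sign, pos)

eager_out = scale_by_sign(neg)  # eager mode takes the second branch
traced_out = traced_fn(neg)     # the trace replays the branch recorded for `pos`
print(eager_out, traced_out)
```

Here the eager and traced results disagree on `neg`, because the trace always multiplies by 2.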
Compare Predictions from TorchScript model and Torch model¶
[5]:
with torch.no_grad():
    traced_mean, traced_var = traced_model(test_x)

print(torch.norm(traced_mean - pred.mean))
print(torch.norm(traced_var - pred.variance))
tensor(0.)
tensor(0.)
[6]:
traced_model.save('traced_exact_qep.pt')
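A saved TorchScript file can be read back with torch.jit.load. The following sketch shows the save and load round trip on a hypothetical toy module (`TinyWrapper`, which returns two tensors like the wrapper above) written to a temporary directory, so it runs without QPyTorch. Applying the same check to traced_exact_qep.pt is left as an assumption rather than something demonstrated here.

```python
import os
import tempfile

import torch


class TinyWrapper(torch.nn.Module):
    """Hypothetical toy module returning two tensors, like a mean/variance pair."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        out = self.linear(x.unsqueeze(-1)).squeeze(-1)
        return out, out.pow(2)


x = torch.linspace(0, 1, 51)
with torch.no_grad():
    traced = torch.jit.trace(TinyWrapper().eval(), x)

# Save to a temporary file and load it back as a TorchScript module
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "traced_toy.pt")
    traced.save(path)
    loaded = torch.jit.load(path)

with torch.no_grad():
    first_a, second_a = traced(x)
    first_b, second_b = loaded(x)

print(torch.norm(first_a - first_b), torch.norm(second_a - second_b))
```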