Converting Variational Models to TorchScript¶
The purpose of this notebook is to demonstrate how to convert a variational QPyTorch model to a ScriptModule that can e.g. be exported to LibTorch.
In general, the process is quite similar to that for standard torch models, which we trace using torch.jit.trace. However, there are two key differences:
1. The first time you make predictions with a QPyTorch model (exact or approximate), we cache certain computations. These computations can’t be traced, but their results can be. Therefore, we’ll need to pass data through the untraced model once, and then trace the model.
2. You can’t trace models that return Distribution objects. Therefore, we’ll write a simple wrapper that unpacks the MultivariateQExponential that our QEPs return into just a mean and a variance tensor.
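The two points above can be illustrated with a toy module in plain PyTorch. This is only a sketch under stated assumptions: `ToyDistributionModel` and `MeanVarWrapper` are hypothetical stand-ins (not QPyTorch classes), and `torch.distributions.Normal` takes the place of the distribution a QEP would return. The toy fills a cache on its first call, is warmed up once, and is then traced through a wrapper that returns plain tensors.

```python
import torch


class ToyDistributionModel(torch.nn.Module):
    """Hypothetical stand-in for a QEP: caches a value on first use and returns a Distribution."""

    def __init__(self, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(dim))
        self._cached_scale = None  # filled in by the first forward pass

    def forward(self, x):
        if self._cached_scale is None:
            # One-off computation whose result is reused by every later call.
            self._cached_scale = torch.nn.functional.softplus(self.weight).sum().detach()
        mean = x @ self.weight
        std = self._cached_scale.expand_as(mean)
        return torch.distributions.Normal(mean, std)


class MeanVarWrapper(torch.nn.Module):
    """Unpacks the returned Distribution into tensors, which the tracer can handle."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        dist = self.model(x)
        return dist.mean, dist.variance


toy_wrapped = MeanVarWrapper(ToyDistributionModel(dim=3))
toy_input = torch.randn(5, 3)

with torch.no_grad():
    eager_mean, eager_var = toy_wrapped(toy_input)       # first pass populates the cache
    toy_traced = torch.jit.trace(toy_wrapped, toy_input)  # then trace the warmed-up module
    traced_mean, traced_var = toy_traced(toy_input)
```

The cached value ends up in the trace as a constant, which mirrors the idea that the results of the cached computations can be traced even when the computations themselves cannot.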
Download Data and Define Model¶
In this tutorial, we’ll be tracing an SVQEP model trained for just 10 epochs on the elevators UCI dataset. The next two cells, copied directly from our variational tutorial, download the data and define the variational QEP model.
[1]:
import torch
import urllib.request
import os
from scipy.io import loadmat
from math import floor

# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)

if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')

if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 18), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]

train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()
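The preprocessing in the cell above shifts each feature column to start at zero and then rescales it to the range [-1, 1]. A minimal sketch of that step as a reusable function is given below; the function name `scale_to_unit_box` and the guard against constant columns are illustrative additions, not part of the original cell.

```python
import torch


def scale_to_unit_box(features):
    """Rescale each column of `features` to [-1, 1], mirroring the preprocessing above."""
    shifted = features - features.min(dim=0).values
    # Guard against constant columns (range of zero); this guard is an addition
    # that the original cell does not have.
    col_range = shifted.max(dim=0).values.clamp_min(1e-12)
    return 2 * (shifted / col_range) - 1


raw_features = torch.randn(100, 4) * 5 + 3  # synthetic data, for illustration only
scaled_features = scale_to_unit_box(raw_features)
col_min = scaled_features.min(dim=0).values
col_max = scaled_features.max(dim=0).values
```

With this scaling, every non-constant column has a minimum of -1 and a maximum of 1.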
[2]:
import qpytorch
from qpytorch.models import ApproximateQEP
from qpytorch.variational import CholeskyVariationalDistribution
from qpytorch.variational import VariationalStrategy

POWER = 1.0

class QEPModel(ApproximateQEP):
    def __init__(self, inducing_points):
        self.power = torch.tensor(POWER)
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0), power=self.power)
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super(QEPModel, self).__init__(variational_strategy)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel(ard_num_dims=18))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

inducing_points = torch.randn(500, 18)
model = QEPModel(inducing_points=inducing_points)
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=model.power)

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()
Load a Trained Model¶
To keep things simple for this notebook, we won’t be training here. Instead, we’ll load the parameters of a model that was pre-trained on elevators in the SVQEP example notebook.
[3]:
if torch.cuda.is_available():
    model_state_dict, likelihood_state_dict = torch.load('svqep_elevators.pt')
else:
    model_state_dict, likelihood_state_dict = torch.load('svqep_elevators.pt', map_location='cpu')

model.load_state_dict(model_state_dict)
likelihood.load_state_dict(likelihood_state_dict)
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/gpytorch/means/constant_mean.py:20: OldVersionWarning: You have loaded a GP model with a ConstantMean from a previous version of GPyTorch. The mean module parameter `constant` has been renamed to `raw_constant`. Additionally, the shape of `raw_constant` is now *batch_shape, whereas the shape of `constant` was *batch_shape x 1. We have updated the name/shape of the parameter in your state dict, but we recommend that you re-save your model.
warnings.warn(
[3]:
<All keys matched successfully>
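Since the loading cell unpacks two state dicts from a single file, the checkpoint was presumably written as a pair. The sketch below shows one way such a file could be produced and read back. The `torch.nn.Linear` modules are hypothetical stand-ins for the trained model and likelihood, and the temporary path is for illustration only.

```python
import os
import tempfile

import torch

model_stub = torch.nn.Linear(3, 1)       # hypothetical stand-in for the trained QEPModel
likelihood_stub = torch.nn.Linear(1, 1)  # hypothetical stand-in for the likelihood

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint_path = os.path.join(tmp_dir, "checkpoint.pt")
    # Save both state dicts together as one tuple.
    torch.save((model_stub.state_dict(), likelihood_stub.state_dict()), checkpoint_path)
    # map_location="cpu" makes the load work on machines with or without a GPU.
    model_state, likelihood_state = torch.load(checkpoint_path, map_location="cpu")

restored_model = torch.nn.Linear(3, 1)
load_result = restored_model.load_state_dict(model_state)
```

After loading on the CPU, the modules can still be moved to the GPU with `.cuda()` as in the earlier cell.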
Define a Wrapper¶
Instead of directly tracing the QEP, we’ll need to trace a PyTorch Module that returns tensors. In the next cell, we define a wrapper that calls a QEP and then unpacks the resulting Distribution into a mean and variance.
You could also return the full covariance_matrix if you wanted that rather than the variance.
[4]:
class MeanVarModelWrapper(torch.nn.Module):
    def __init__(self, qep):
        super().__init__()
        self.qep = qep

    def forward(self, x):
        output_dist = self.qep(x)
        return output_dist.mean, output_dist.variance
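For the covariance variant mentioned above, a wrapper might look like the following sketch. `MeanCovarModelWrapper` is a hypothetical name, and `ToyMVNModel` is a stand-in built on `torch.distributions.MultivariateNormal` so that the example runs without QPyTorch; it assumes only that the wrapped model's output exposes `mean` and `covariance_matrix`.

```python
import torch


class MeanCovarModelWrapper(torch.nn.Module):
    """Hypothetical variant of MeanVarModelWrapper that returns the full covariance."""

    def __init__(self, qep):
        super().__init__()
        self.qep = qep

    def forward(self, x):
        output_dist = self.qep(x)
        return output_dist.mean, output_dist.covariance_matrix


class ToyMVNModel(torch.nn.Module):
    """Stand-in model (not a QEP) returning a distribution with a dense covariance."""

    def forward(self, x):
        mean = x.sum(dim=-1)
        covar = torch.eye(x.size(0)) + 0.1  # positive definite by construction
        return torch.distributions.MultivariateNormal(mean, covariance_matrix=covar)


covar_wrapper = MeanCovarModelWrapper(ToyMVNModel())
toy_mean, toy_covar = covar_wrapper(torch.randn(5, 18))
```

The covariance output is n x n for n inputs, so this variant uses considerably more memory than the variance-only wrapper on large batches.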
Trace the Model¶
In the next cell, we trace the model as normal, with the exception that we first pass data through the wrapped model so that QPyTorch can compute all of the things it needs to cache that can’t be traced. Mostly, this just involves some complex linear algebra operations for variational QEPs.
Additionally, we’ll need to run with the qpytorch.settings.trace_mode setting enabled, because PyTorch can’t trace custom autograd Functions. Note that this results in some inefficiencies, e.g. for variational models we will always compute the full predictive posterior covariance in the traced model. This is not so bad, because we can always just process minibatches of data.
Note: You’ll get a lot of warnings from the tracer. That’s fine. QPyTorch models are pretty large graphs, and include things like .item() calls that you wouldn’t normally encounter in a basic neural network.
[5]:
wrapped_model = MeanVarModelWrapper(model)

with torch.no_grad(), qpytorch.settings.trace_mode():
    fake_input = test_x[:1024, :]
    pred = wrapped_model(fake_input)  # Compute caches
    traced_model = torch.jit.trace(wrapped_model, fake_input)
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-06 to the diagonal
warnings.warn(
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-05 to the diagonal
warnings.warn(
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/utils/cholesky.py:40: NumericalWarning: A not p.d., added jitter of 1.0e-04 to the diagonal
warnings.warn(
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/qpytorch/variational/variational_strategy.py:246: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not self.updated_strategy.item() and not prior:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/qpytorch/variational/_variational_strategy.py:361: TracerWarning: Converting a tensor to a Python number might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not self.variational_params_initialized.item():
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:1015: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
diag = torch.tensor(jitter_val, dtype=self.dtype, device=self.device)
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:966: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not self.is_square:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/operators/diag_linear_operator.py:307: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not (diag_values.dim() and diag_values.size(-1) == 1):
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/gpytorch/kernels/kernel.py:511: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if not x1_.size(-1) == x2_.size(-1):
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:366: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if res.shape != self.shape:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/utils/cholesky.py:21: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if settings.trace_mode.on() or not torch.any(info):
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/qpytorch/variational/variational_strategy.py:208: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if L.shape != induc_induc_covar.shape:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:1873: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
other = torch.tensor(other, dtype=self.dtype, device=self.device)
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:1883: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if other.numel() == 1:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/linear_operator/utils/broadcasting.py:18: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if n != shape_b[-2]:
/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.12/site-packages/qpytorch/distributions/multivariate_qexponential.py:474: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if variance.lt(min_variance).any():
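Most of the TracerWarning messages above concern Python-level control flow on tensor values, which the tracer records as a constant. The toy example below (plain PyTorch, with a hypothetical `BranchOnValue` module) sketches what that means in practice: the branch taken during tracing is replayed for every later input.

```python
import warnings

import torch


class BranchOnValue(torch.nn.Module):
    """Toy module whose Python `if` depends on the data via .item()."""

    def forward(self, x):
        if x.sum().item() > 0:
            return x * 2
        return x - 1


positive = torch.ones(3)
negative = -torch.ones(3)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the TracerWarning for this demo
    traced_branch = torch.jit.trace(BranchOnValue(), positive)

eager_out = BranchOnValue()(negative)  # eager mode takes the second branch: x - 1
traced_out = traced_branch(negative)   # the trace replays the recorded branch: x * 2
```

For the QEP traced above, the flagged conditions are checks on flags, shapes, and numerical tolerances rather than on the predictions themselves, which is presumably why the warnings are tolerable here; comparing the traced and untraced outputs remains a sensible check.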
[6]:
# Compute mean absolute errors on a minibatch (untraced wrapper vs. traced model)
mean1 = wrapped_model(test_x[:1024, :])[0]
mean2 = traced_model(test_x[:1024, :])[0]
print(torch.mean(torch.abs(mean1 - test_y[:1024])))
print(torch.mean(torch.abs(mean2 - test_y[:1024])))
tensor(0.0761, grad_fn=<MeanBackward0>)
tensor(0.0761, grad_fn=<MeanBackward0>)
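Because the traced model always computes the full predictive covariance, larger test sets are best processed in minibatches, as noted earlier. A minimal helper for that is sketched below. `predict_in_batches` is a hypothetical name, and the example traces a toy module rather than the QEP; whether the traced QEP accepts a final batch whose size differs from the tracing input is not verified here.

```python
import torch


def predict_in_batches(module, inputs, batch_size=1024):
    """Run `module` over `inputs` in chunks and concatenate the (mean, variance) outputs."""
    means, variances = [], []
    with torch.no_grad():
        for batch in torch.split(inputs, batch_size):
            mean, variance = module(batch)
            means.append(mean)
            variances.append(variance)
    return torch.cat(means), torch.cat(variances)


class ToyMeanVar(torch.nn.Module):
    """Toy stand-in that returns two per-row tensors, like the wrapped QEP does."""

    def forward(self, x):
        return x.sum(dim=-1), x.pow(2).sum(dim=-1)


toy_traced_model = torch.jit.trace(ToyMeanVar(), torch.randn(4, 18))
toy_inputs = torch.randn(10, 18)
batched_mean, batched_var = predict_in_batches(toy_traced_model, toy_inputs, batch_size=4)
```

In the notebook's setting, the call would be along the lines of `predict_in_batches(traced_model, test_x)`.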
[7]:
traced_model.save('traced_model.pt')
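A saved ScriptModule can be read back with `torch.jit.load`, without needing the Python class definitions. The round trip is sketched below on a toy module (`TinyModule` is a hypothetical stand-in, and the temporary directory is for illustration only).

```python
import os
import tempfile

import torch


class TinyModule(torch.nn.Module):
    """Toy stand-in for the wrapped model; returns two per-row tensors."""

    def forward(self, x):
        return x.mean(dim=-1), x.var(dim=-1)


example_input = torch.randn(8, 18)
tiny_traced = torch.jit.trace(TinyModule(), example_input)

with tempfile.TemporaryDirectory() as tmp_dir:
    saved_path = os.path.join(tmp_dir, "traced_model.pt")
    tiny_traced.save(saved_path)
    reloaded = torch.jit.load(saved_path, map_location="cpu")
    reloaded_mean, reloaded_var = reloaded(example_input)

original_mean, original_var = tiny_traced(example_input)
```

The same file can also be loaded from C++ through LibTorch, which is the export path mentioned at the start of this notebook.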