{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Converting Exact QEP Models to TorchScript\n", "\n", "In this notebook, we'll demonstrate converting an Exact QEP model to TorchScript. In general, this works the same way as for standard PyTorch models, using `torch.jit.trace`, but there are three peculiarities to keep in mind for QPyTorch:\n", "\n", "1. The first time you make predictions with a QPyTorch model (exact or approximate), we cache certain computations. These computations can't be traced, but their results can be. Therefore, we'll need to pass data through the untraced model once, and then trace the model.\n", "1. For exact QEPs, we can't trace models unless `qpytorch.settings.fast_pred_var` is used. This is a technical issue that may not be possible to overcome due to limitations on what can be traced in PyTorch; however, if you really need to trace a QEP but can't use the above setting, open an issue so we know there is demand for it.\n", "1. You can't trace models that return Distribution objects. Therefore, we'll write a simple wrapper that unpacks the MultivariateQExponential that our QEPs return into just a mean and a variance tensor." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define and train an exact QEP\n", "\n", "In the next cell, we define some data, define a QEP model, and train it. Nothing here is specific to tracing, so feel free to skip ahead to the next cell." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The autoreload extension is already loaded. 
To reload it, use:\n", " %reload_ext autoreload\n" ] } ], "source": [ "import math\n", "import torch\n", "import qpytorch\n", "from matplotlib import pyplot as plt\n", "\n", "%matplotlib inline\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Training data is 100 points in [0,1] inclusive regularly spaced\n", "train_x = torch.linspace(0, 1, 100)\n", "# True function is sin(2*pi*x) with Gaussian noise\n", "train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2\n", "\n", "# We will use the simplest form of QEP model, exact inference\n", "POWER = 1.0\n", "class ExactQEPModel(qpytorch.models.ExactQEP):\n", "    def __init__(self, train_x, train_y, likelihood):\n", "        super(ExactQEPModel, self).__init__(train_x, train_y, likelihood)\n", "        self.power = torch.tensor(POWER)\n", "        self.mean_module = qpytorch.means.ConstantMean()\n", "        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())\n", "\n", "    def forward(self, x):\n", "        mean_x = self.mean_module(x)\n", "        covar_x = self.covar_module(x)\n", "        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)\n", "\n", "# Initialize likelihood and model\n", "likelihood = qpytorch.likelihoods.QExponentialLikelihood()\n", "model = ExactQEPModel(train_x, train_y, likelihood)\n", "\n", "# This is for running the notebook in our testing framework\n", "import os\n", "smoke_test = ('CI' in os.environ)\n", "training_iter = 2 if smoke_test else 50\n", "\n", "\n", "# Find optimal model hyperparameters\n", "model.train()\n", "likelihood.train()\n", "\n", "# Use the Adam optimizer\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes QExponentialLikelihood parameters\n", "\n", "# \"Loss\" for QEPs - the marginal log likelihood\n", "mll = qpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)\n", "\n", "for i in range(training_iter):\n", "    # Zero gradients from previous iteration\n", "    optimizer.zero_grad()\n", "    # Output from model\n", "    output = model(train_x)\n", "    # Calc loss and backprop gradients\n", "    loss = -mll(output, train_y)\n", "    loss.backward()\n", "    optimizer.step()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Trace the Model\n", "\n", "In the next cell, we trace our QEP model. Because Modules that return Distributions can't be traced, we write a wrapper Module that unpacks the QEP output into a mean and a variance tensor.\n", "\n", "\n", "Additionally, we'll need to run with the `qpytorch.settings.trace_mode` setting enabled, because PyTorch can't trace custom autograd Functions. Note that this results in some inefficiencies.\n", "\n", "Then, before calling `torch.jit.trace`, we first call the model on `test_x`. This step is **required**, as it does some precomputation using torch functionality that cannot be traced." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/gpytorch/models/exact_prediction_strategies.py:315: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", " if joint_covar.size(-1) <= settings.max_eager_kernel_size.value():\n", "/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/gpytorch/kernels/kernel.py:511: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs!\n", " if not x1_.size(-1) == x2_.size(-1):\n", "/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py:366: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", " if res.shape != self.shape:\n", "/Users/shiweilan/miniconda/envs/qpytorch/lib/python3.10/site-packages/linear_operator/operators/_linear_operator.py:1417: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", " elif not self.is_square:\n", "/Users/shiweilan/Projects/QPyTorch/qpytorch/distributions/multivariate_qexponential.py:425: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. 
This means that the trace might not generalize to other inputs!\n", " if variance.lt(min_variance).any():\n" ] } ], "source": [ "class MeanVarModelWrapper(torch.nn.Module):\n", "    # Wraps the QEP so that forward returns plain tensors, which can be traced\n", "    def __init__(self, qep):\n", "        super().__init__()\n", "        self.qep = qep\n", "\n", "    def forward(self, x):\n", "        output_dist = self.qep(x)\n", "        return output_dist.mean, output_dist.variance\n", "\n", "with torch.no_grad(), qpytorch.settings.fast_pred_var(), qpytorch.settings.trace_mode():\n", "    model.eval()\n", "    test_x = torch.linspace(0, 1, 51)\n", "    pred = model(test_x)  # Do precomputation (fills the prediction caches)\n", "    traced_model = torch.jit.trace(MeanVarModelWrapper(model), test_x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Compare Predictions from the TorchScript Model and the Torch Model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.)\n", "tensor(0.)\n" ] } ], "source": [ "with torch.no_grad():\n", "    traced_mean, traced_var = traced_model(test_x)\n", "\n", "# Both norms should be (close to) zero if the traced model matches the original\n", "print(torch.norm(traced_mean - pred.mean))\n", "print(torch.norm(traced_var - pred.variance))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "traced_model.save('traced_exact_qep.pt')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 4 }