{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Non-Gaussian Likelihoods\n", "\n", "## Introduction\n", "\n", "This example shows the simplest way to use an RBF kernel in an `ApproximateQEP` module for classification. This basic model is suitable when there is little training data and no advanced techniques are required.\n", "\n", "In this example, we’re modeling a unit square wave with period 1/2 that is positive at x=0. We are going to classify the points as either 1 (positive) or 0 (negative).\n", "\n", "Variational inference approximates the posterior with a simpler distribution, commonly by assuming that it factors multiplicatively over the latent variables, and fits that approximation by minimizing a KL divergence. This yields a fast approximation to the posterior. For a good explanation of variational techniques, sections 4-6 of the following may be useful: https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import math\n", "import torch\n", "import qpytorch\n", "from matplotlib import pyplot as plt\n", "\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set up training data\n", "\n", "In the next cell, we set up the training data for this example. We'll be using 10 regularly spaced points on [0,1], at which we evaluate the function to obtain the training labels. The labels follow a unit square wave with period 1/2 that is positive at x=0, mapped to 0/1." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "train_x = torch.linspace(0, 1, 10)\n", "train_y = torch.sign(torch.cos(train_x * (4 * math.pi))).add(1).div(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setting up the classification model\n", "\n", "The next cell demonstrates the simplest way to define a q-exponential process (QEP) classification model in QPyTorch. 
If you have completed the [QEP regression tutorial](../01_Exact_QEPs/Simple_QEP_Regression.ipynb), you have already seen how QPyTorch model construction differs from other QEP packages. In particular, the QEP model expects a user to write out a `forward` method in a way analogous to PyTorch models. This gives the user the greatest possible flexibility.\n", "\n", "Since exact inference is intractable for QEP classification, QPyTorch approximates the classification posterior using **variational inference.** We believe that variational inference is ideal for a number of reasons. Firstly, variational inference commonly relies on gradient descent techniques, which take full advantage of PyTorch's autograd. This reduces the amount of code needed to develop complex variational models. Additionally, variational inference can be performed with stochastic gradient descent, which can be extremely scalable for large datasets.\n", "\n", "If you are unfamiliar with variational inference, we recommend the following resources:\n", "- [Variational Inference: A Review for Statisticians](https://arxiv.org/abs/1601.00670) by David M. Blei, Alp Kucukelbir, Jon D. McAuliffe.\n", "- [Scalable Variational Gaussian Process Classification](https://arxiv.org/abs/1411.2005) by James Hensman, Alex Matthews, Zoubin Ghahramani.\n", " \n", "In this example, we're using an `UnwhitenedVariationalStrategy` because we are using the training data as inducing points. In general, you'll probably want to use the standard `VariationalStrategy` class for improved optimization." 
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from qpytorch.models import ApproximateQEP\n", "from qpytorch.variational import CholeskyVariationalDistribution\n", "from qpytorch.variational import UnwhitenedVariationalStrategy\n", "POWER = 1.0\n", "\n", "class QEPClassificationModel(ApproximateQEP):\n", "    def __init__(self, train_x):\n", "        self.power = torch.tensor(POWER)\n", "        variational_distribution = CholeskyVariationalDistribution(train_x.size(0), power=self.power)\n", "        variational_strategy = UnwhitenedVariationalStrategy(\n", "            self, train_x, variational_distribution, learn_inducing_locations=False\n", "        )\n", "        super(QEPClassificationModel, self).__init__(variational_strategy)\n", "        self.mean_module = qpytorch.means.ConstantMean()\n", "        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())\n", "\n", "    def forward(self, x):\n", "        mean_x = self.mean_module(x)\n", "        covar_x = self.covar_module(x)\n", "        latent_pred = qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)\n", "        return latent_pred\n", "\n", "\n", "# Initialize model and likelihood\n", "model = QEPClassificationModel(train_x)\n", "likelihood = qpytorch.likelihoods.BernoulliLikelihood()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model modes\n", "\n", "Like most PyTorch modules, the `ApproximateQEP` has a `.train()` and `.eval()` mode.\n", "- `.train()` mode is for optimizing variational parameters and model hyperparameters.\n", "- `.eval()` mode is for computing predictions through the model posterior." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learn the variational parameters (and other hyperparameters)\n", "\n", "In the next cell, we optimize the variational parameters of our q-exponential process.\n", "This optimization loop also performs Type-II MLE to train the hyperparameters of the q-exponential process." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Iter 1/100 - Loss: 0.482\n", "Iter 2/100 - Loss: 1.555\n", "Iter 3/100 - Loss: 0.807\n", "Iter 4/100 - Loss: 0.801\n", "Iter 5/100 - Loss: 0.823\n", "Iter 6/100 - Loss: 0.729\n", "Iter 7/100 - Loss: 0.636\n", "Iter 8/100 - Loss: 0.606\n", "Iter 9/100 - Loss: 0.614\n", "Iter 10/100 - Loss: 0.613\n", "Iter 11/100 - Loss: 0.601\n", "Iter 12/100 - Loss: 0.587\n", "Iter 13/100 - Loss: 0.572\n", "Iter 14/100 - Loss: 0.554\n", "Iter 15/100 - Loss: 0.532\n", "Iter 16/100 - Loss: 0.509\n", "Iter 17/100 - Loss: 0.491\n", "Iter 18/100 - Loss: 0.480\n", "Iter 19/100 - Loss: 0.473\n", "Iter 20/100 - Loss: 0.470\n", "Iter 21/100 - Loss: 0.467\n", "Iter 22/100 - Loss: 0.464\n", "Iter 23/100 - Loss: 0.460\n", "Iter 24/100 - Loss: 0.457\n", "Iter 25/100 - Loss: 0.454\n", "Iter 26/100 - Loss: 0.451\n", "Iter 27/100 - Loss: 0.449\n", "Iter 28/100 - Loss: 0.447\n", "Iter 29/100 - Loss: 0.446\n", "Iter 30/100 - Loss: 0.445\n", "Iter 31/100 - Loss: 0.445\n", "Iter 32/100 - Loss: 0.444\n", "Iter 33/100 - Loss: 0.444\n", "Iter 34/100 - Loss: 0.443\n", "Iter 35/100 - Loss: 0.442\n", "Iter 36/100 - Loss: 0.442\n", "Iter 37/100 - Loss: 0.441\n", "Iter 38/100 - Loss: 0.440\n", "Iter 39/100 - Loss: 0.439\n", "Iter 40/100 - Loss: 0.438\n", "Iter 41/100 - Loss: 0.437\n", "Iter 42/100 - Loss: 0.437\n", "Iter 43/100 - Loss: 0.437\n", "Iter 44/100 - Loss: 0.436\n", "Iter 45/100 - Loss: 0.436\n", "Iter 46/100 - Loss: 0.436\n", "Iter 47/100 - Loss: 0.436\n", "Iter 48/100 - Loss: 0.436\n", "Iter 49/100 - Loss: 0.436\n", "Iter 50/100 - Loss: 0.435\n", "Iter 51/100 - Loss: 0.435\n", "Iter 52/100 - Loss: 0.435\n", "Iter 53/100 - Loss: 0.435\n", "Iter 54/100 - Loss: 0.434\n", "Iter 55/100 - Loss: 0.434\n", "Iter 56/100 - Loss: 0.434\n", "Iter 57/100 - Loss: 0.434\n", "Iter 58/100 - Loss: 0.434\n", "Iter 59/100 - Loss: 0.434\n", "Iter 60/100 - Loss: 
0.433\n", "Iter 61/100 - Loss: 0.433\n", "Iter 62/100 - Loss: 0.433\n", "Iter 63/100 - Loss: 0.433\n", "Iter 64/100 - Loss: 0.433\n", "Iter 65/100 - Loss: 0.433\n", "Iter 66/100 - Loss: 0.433\n", "Iter 67/100 - Loss: 0.433\n", "Iter 68/100 - Loss: 0.433\n", "Iter 69/100 - Loss: 0.433\n", "Iter 70/100 - Loss: 0.433\n", "Iter 71/100 - Loss: 0.432\n", "Iter 72/100 - Loss: 0.432\n", "Iter 73/100 - Loss: 0.432\n", "Iter 74/100 - Loss: 0.432\n", "Iter 75/100 - Loss: 0.432\n", "Iter 76/100 - Loss: 0.432\n", "Iter 77/100 - Loss: 0.432\n", "Iter 78/100 - Loss: 0.432\n", "Iter 79/100 - Loss: 0.432\n", "Iter 80/100 - Loss: 0.432\n", "Iter 81/100 - Loss: 0.432\n", "Iter 82/100 - Loss: 0.432\n", "Iter 83/100 - Loss: 0.432\n", "Iter 84/100 - Loss: 0.432\n", "Iter 85/100 - Loss: 0.432\n", "Iter 86/100 - Loss: 0.432\n", "Iter 87/100 - Loss: 0.432\n", "Iter 88/100 - Loss: 0.432\n", "Iter 89/100 - Loss: 0.432\n", "Iter 90/100 - Loss: 0.432\n", "Iter 91/100 - Loss: 0.432\n", "Iter 92/100 - Loss: 0.432\n", "Iter 93/100 - Loss: 0.432\n", "Iter 94/100 - Loss: 0.432\n", "Iter 95/100 - Loss: 0.432\n", "Iter 96/100 - Loss: 0.432\n", "Iter 97/100 - Loss: 0.431\n", "Iter 98/100 - Loss: 0.431\n", "Iter 99/100 - Loss: 0.431\n", "Iter 100/100 - Loss: 0.431\n" ] } ], "source": [ "# this is for running the notebook in our testing framework\n", "import os\n", "smoke_test = ('CI' in os.environ)\n", "training_iterations = 2 if smoke_test else 100\n", "\n", "\n", "# Find optimal model hyperparameters\n", "model.train()\n", "likelihood.train()\n", "\n", "# Use the adam optimizer\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n", "\n", "# \"Loss\" for QEPs - the marginal log likelihood\n", "# num_data refers to the number of training datapoints\n", "mll = qpytorch.mlls.VariationalELBO(likelihood, model, train_y.numel())\n", "\n", "for i in range(training_iterations):\n", " # Zero backpropped gradients from previous iteration\n", " optimizer.zero_grad()\n", " # Get predictive output\n", " 
output = model(train_x)\n", " # Calc loss and backprop gradients\n", " loss = -mll(output, train_y)\n", " loss.backward()\n", " print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))\n", " optimizer.step()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make predictions with the model\n", "\n", "In the next cell, we make predictions with the model. To do this, we simply put the model and likelihood in eval mode, and call both modules on the test data.\n", "\n", "In `.eval()` mode, calling `model()` returns the QEP's latent posterior predictions as `MultivariateQExponential` distributions. Since we are performing binary classification, we want to transform these outputs to classification probabilities using our likelihood.\n", "\n", "When we call `likelihood(model())`, we get a `torch.distributions.Bernoulli` distribution, which represents our posterior probability that the data points belong to the positive class.\n", "\n", "```python\n", "f_preds = model(test_x)\n", "y_preds = likelihood(model(test_x))\n", "\n", "f_mean = f_preds.mean\n", "f_samples = f_preds.sample(sample_shape=torch.Size((1000,)))\n", "```" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAXUAAAEYCAYAAACjl2ZMAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAALKhJREFUeJzt3XtcFOe9BvBnuS1guHgBFhQQNYrxhmI0oImXmBJjPdrTY4zxKBovScVUxaRKksYYY0iiiabWlmOMod5qairGo4bEuyd4CwgtVkQRFauAGiMXFRD2d/6wbF257QLDwszz/Xz2Izvzzry/2Zd9nJkdZnUiIiAiIlWws3UBRETUeBjqREQqwlAnIlIRhjoRkYow1ImIVIShTkSkIgx1IiIVYagTEakIQ52ISEUY6kREKqJoqMfGxuLxxx+Hm5sbvL29MXbsWGRmZta53NatWxEcHAxnZ2f06tULu3fvVrJMIiLVUDTUDx06hKioKBw7dgx79uzBvXv38LOf/Qy3b9+ucZkjR45gwoQJmDZtGlJTUzF27FiMHTsWp06dUrJUIiJV0DXlDb2uX78Ob29vHDp0CE899VS1bcaPH4/bt29j586dpmlPPPEEQkJCEBcX11SlEhG1SA5N2VlBQQEAoE2bNjW2OXr0KKKjo82mRUREYPv27dW2Ly0tRWlpqem50WjEzZs30bZtW+h0uoYXTURkYyKCoqIi+Pn5wc6u9hMsTRbqRqMRc+fOxaBBg9CzZ88a2+Xl5cHHx8dsmo+PD/Ly8qptHxsbi8WLFzdqrUREzdHly5fRoUOHWts0WahHRUXh1KlT+P777xt1vTExMWZ79gUFBQgICMDly5fh7u7eqH0REdlCYWEh/P394ebmVmfbJgn12bNnY+fOnTh8+HCd/8sYDAbk5+ebTcvPz4fBYKi2vV6vh16vrzLd3d2doU5EqmLJKWVFr34REcyePRsJCQnYv38/goKC6lwmLCwM+/btM5u2Z88ehIWFKVUmEZFqKLqnHhUVhc2bN+Prr7+Gm5ub6by4h4cHXFxcAACTJ09G+/btERsbCwCYM2cOhgwZgo8//hijRo3Cli1bkJycjDVr1ihZKhGRKii6p/7HP/4RBQUFGDp0KHx9fU2PL7/80tQmJycHubm5pufh4eHYvHkz1qxZgz59+uCrr77C9u3ba/1wlYiI7mvS69SbQmFhITw8PFBQUMBz6tToKioqcO/ePVuXQSrk5ORU4+WK1uRak16nTtRSiQjy8vJw69YtW5dCKmVnZ4egoCA4OTk1aD0MdSILVAa6t7c3XF1d+Ydt1KiMRiOuXr2K3NxcBAQENOj3i6FOVIeKigpToLdt29bW5ZBKeXl54erVqygvL4ejo2O918Nb7xLVofIcuqurq40rITWrPO1SUVHRoPUw1IksxFMupKTG+v1iqBMRqQhDnYjQsWNHrFy50tZlNBq1bY81GOpEKnb58mW89NJL8PPzg5OTEwIDAzFnzhz8+OOPti7Npt555x3odDrodDo4ODigXbt2eOqpp7By5UqzW3lb4uDBg9DpdM3mcleGOlETSk5OxvDhw5GcnKx4X9nZ2ejfvz/OnTuHP//5z8jKykJcXBz27duHsLAw3Lx5U/EaalJRUQGj0Wiz/gGgR48eyM3NRU5ODg4cOIBx48YhNjYW4eHhKCoqsmltDcFQJ2pC69evx4EDB7BhwwbF+4qKioKTkxO+++47DBkyBAEBARg5ciT27t2LK1eu4M033zRrX1RUhAkTJqBVq1Zo3749Vq9ebZonInjnnXcQEBAAvV4PPz8//PrXvzbNLy0txWuvvYb27dujVatWGDhwIA4ePGiaHx8fD09PT+zYsQOPPfYY9Ho91q5dC2dn5yp7uHPmzMHw4cNNz7///ns8+eSTcHFxgb+/P37961+bfSXmtWvXMHr0aLi4uCAoKAibNm2y6PVxcHCAwWCAn58fevXqhVdffRWHDh3CqVOn8OGHH5rabdi
wAf3794ebmxsMBgNefPFFXLt2DQBw8eJFDBs2DADQunVr6HQ6TJkyBQCQmJiIwYMHw9PTE23btsXPf/5znD9/3qLaGkRUpqCgQABIQUGBrUshlbh7966cPn1a7t69W6/lL168KMnJyZKSkiLe3t4CQLy9vSUlJUWSk5Pl4sWLjVyxyI8//ig6nU7ef//9aufPmDFDWrduLUajUUREAgMDxc3NTWJjYyUzM1N+97vfib29vXz33XciIrJ161Zxd3eX3bt3y6VLl+T48eOyZs0a0/qmT58u4eHhcvjwYcnKypJly5aJXq+Xs2fPiojIF198IY6OjhIeHi5JSUly5swZKS4uFh8fH1m7dq1pPeXl5WbTsrKypFWrVrJixQo5e/asJCUlSd++fWXKlCmmZUaOHCl9+vSRo0ePSnJysoSHh4uLi4usWLGixtdn0aJF0qdPn2rnjRkzRrp37256/vnnn8vu3bvl/PnzcvToUQkLC5ORI0ea6v3rX/8qACQzM1Nyc3Pl1q1bIiLy1VdfyV//+lc5d+6cpKamyujRo6VXr15SUVFRbb+1/Z5Zk2sMdaI6NDTUAZgeOp3O7N/KR2M7duyYAJCEhIRq53/yyScCQPLz80Xkfqg/++yzZm3Gjx9vCq+PP/5YunbtKmVlZVXWdenSJbG3t5crV66YTX/66aclJiZGRO6HOgBJS0szazNnzhwZPny46fm3334rer1efvrpJxERmTZtmsycOdNsmf/7v/8TOzs7uXv3rmRmZgoAOXHihGl+RkaGAKh3qC9YsEBcXFxqXPaHH34QAFJUVCQiIgcOHBAAppprcv36dQEg6enp1c5vrFDn6RcihW3cuBEODvf/eFv+df+8yn8dHBywceNGxfoWK+7X9/B3FoSFhSEjIwMAMG7cONy9exedOnXCjBkzkJCQgPLycgBAeno6Kioq0LVrVzzyyCOmx6FDh8xONzg5OaF3795mfUycOBEHDx7E1atXAQCbNm3CqFGj4OnpCQD429/+hvj4eLP1RkREwGg04sKFC8jIyICDgwNCQ0NN6wwODjYtXx8iYnbNeEpKCkaPHo2AgAC4ublhyJAhAO7fYbY2586dw4QJE9CpUye4u7ujY8eOFi3XULxNAJHCJk6ciO7du5sFT6Xjx4+jX79+jd5nly5doNPpkJGRgV/84hdV5mdkZKB169bw8vKyaH3+/v7IzMzE3r17sWfPHsyaNQvLli3DoUOHUFxcDHt7e6SkpMDe3t5suUceecT0s4uLS5U/sHn88cfRuXNnbNmyBb/61a+QkJCA+Ph40/zi4mK8/PLLZufvKwUEBODs2bMW1W+NjIwM0xf63L59GxEREYiIiMCmTZvg5eWFnJwcREREoKysrNb1jB49GoGBgfjss8/g5+cHo9GInj171rlcQzHUiZqQnZ0djEaj6V+ltG3bFs888wz+8Ic/YN68eaYvpQHu35xs06ZNmDx5slnIHjt2zGwdx44dQ/fu3U3PXVxcMHr0aIwePRpRUVEIDg5Geno6+vbti4qKCly7dg1PPvmk1bVOnDgRmzZtQocOHWBnZ4dRo0aZ5vXr1w+nT59Gly5dql02ODgY5eXlSElJweOPPw4AyMzMrPflhWfOnEFiYiJiYmJMz3/88Ud88MEH8Pf3B4AqVy5V9+f9P/74IzIzM/HZZ5+ZXpPG/n7mmvD0C1ET8Pb2hsFgQGhoKOLi4hAaGgqDwQBvb2/F+vz973+P0tJSRERE4PDhw7h8+TISExPxzDPPoH379li6dKlZ+6SkJHz00Uc4e/YsVq9eja1bt2LOnDkA7l+98vnnn+PUqVPIzs7Gxo0b4eLigsDAQHTt2hUTJ07E5MmTsW3bNly4cAEnTpxAbGwsdu3aVWedEydOxMmTJ7F06VL813/9l9l3Di9YsABHjhzB7NmzkZaWhnPnzuHrr7/G7NmzAQDdunXDs88+i5dffhnHjx9HSkoKpk+fbvafWE3Ky8uRl5eHq1e
vIj09HatWrcKQIUMQEhKC119/HcD9owEnJyesWrUK2dnZ2LFjB5YsWWK2nsDAQOh0OuzcuRPXr19HcXExWrdujbZt22LNmjXIysrC/v37ER0dXWdNjaLOs+4tDD8opcbW0A9KK5WUlJiuNjEajVJSUtIY5dXq4sWLEhkZKT4+PuLo6Cj+/v7y6quvyo0bN8zaBQYGyuLFi2XcuHHi6uoqBoNBPv30U9P8hIQEGThwoLi7u0urVq3kiSeekL1795rml5WVydtvvy0dO3YUR0dH8fX1lV/84hfy97//XUTuf1Dq4eFRY50DBgwQALJ///4q806cOCHPPPOMPPLII9KqVSvp3bu3LF261DQ/NzdXRo0aJXq9XgICAmT9+vUSGBhY5wel+NeH1Pb29tKmTRsZPHiwrFixosq4bN68WTp27Ch6vV7CwsJkx44dAkBSU1NNbd59910xGAyi0+kkMjJSRET27Nkj3bt3F71eL71795aDBw/W+uF1Y31Qym8+IqpDSUkJLly4gKCgIDg7O9u6HFKp2n7PrMk1nn4hIlIRhjoRkYow1ImIVIShTkSkIgx1IiIVYagTEakIQ52ISEUY6kREKsJQJyJSEYY6EZGKKBrqhw8fxujRo+Hn5wedToft27fX2r7yC1wffuTl5SlZJpFqTZkyBTqdDq+88kqVeVFRUWZfv0bqoGio3759G3369DH7rkNLZGZmIjc31/RQ8k52RGrn7++PLVu24O7du6ZpJSUl2Lx5MwICAmxYGSlB0VAfOXIk3nvvvWpv0l+bytuUVj7s7HiWiKi++vXrB39/f2zbts00bdu2bQgICEDfvn1N04xGI2JjYxEUFAQXFxf06dMHX331lWl+RUUFpk2bZprfrVs3fPrpp2Z9TZkyBWPHjsXy5cvh6+uLtm3bIioqCvfu3VN+QwlAM/2SjJCQEJSWlqJnz5545513MGjQoBrblpaWorS01PS8sLCwKUokDRMB7tyxTd+ursBDXx5kkZdeeglffPEFJk6cCABYt24dpk6dioMHD5raxMbGYuPGjYiLi8Ojjz6Kw4cP47//+7/h5eWFIUOGwGg0okOHDti6dSvatm2LI0eOYObMmfD19cXzzz9vWs+BAwfg6+uLAwcOICsrC+PHj0dISAhmzJjR0M0nS9R5c95GglruI1zpzJkzEhcXJ8nJyZKUlCRTp04VBwcHSUlJqXGZB++L/OCD91OnxvLwfa6Li0XuR3vTP4qLras9MjJSxowZI9euXRO9Xi8XL16UixcvirOzs1y/fl3GjBkjkZGRUlJSIq6urnLkyBGz5adNmyYTJkyocf1RUVHyy1/+0qy/wMBAKS8vN00bN26cjB8/3rrCNaix7qferPbUu3Xrhm7dupmeh4eH4/z581ixYgU2bNhQ7TIxMTFm3yhSWFho+topIrrPy8sLo0aNQnx8PEQEo0aNQrt27Uzzs7KycOfOHTzzzDNmy5WVlZmdolm9ejXWrVuHnJwc3L17F2VlZQgJCTFbpkePHmbfVerr64v09HRlNoyqaFahXp0BAwbU+t1+er3e7OuviJTm6goUF9uu7/p66aWXTF8D9/DFC8X/2qBdu3ahffv2ZvMq319btmzBa6+9ho8//hhhYWFwc3PDsmXLcPz4cbP2jo6OZs91Op2i38dK5pp9qKelpcHX19fWZRCZ6HRAq1a2rsJ6zz77LMrKyqDT6RAREWE277HHHoNer0dOTg6GDBlS7fJJSUkIDw/HrFmzTNPOnz+vaM1kPUVDvbi4GFlZWabnFy5cQFpaGtq0aYOAgADExMTgypUrWL9+PQBg5cqVCAoKQo8ePVBSUoK1a9di//79+O6775Qsk0gT7O3tkZGRYfr5QW5ubnjttdcwb948GI1GDB48GAUFBUhKSoK7uzsiIyPx6KOPYv369fj2228RFBSEDRs24IcffkBQUJAtNodqoGioJycnY9iwYabnlee+IyMjER8fj9zcXOTk5Jjml5W
VYf78+bhy5QpcXV3Ru3dv7N2712wdRFR/tX2/5ZIlS+Dl5YXY2FhkZ2fD09MT/fr1wxtvvAEAePnll5Gamorx48dDp9NhwoQJmDVrFr755pumKp8swC+eJqoDv3iamgK/eJqIiKpgqBMRqQhDnYhIRRjqREQqwlAnIlIRhjqRhfhXkaSkxroQsdn/RSmRrTk5OcHOzg5Xr16Fl5cXnJycoKvPrRKJaiAiuH79OnQ6XZXbLFiLoU5UBzs7OwQFBSE3NxdXr161dTmkUjqdDh06dKjy177WYqgTWcDJyQkBAQEoLy9HRUWFrcshFXJ0dGxwoAMMdSKLVR4aN/TwmEhJ/KCUiEhFGOpERCrCUCciUhGGOhGRijDUiYhUhKFORKQiDHUiIhVhqBMRqQhDnYhIRRjqREQqwlAnIlIRhjoRkYow1ImIVIShTkSkIgx1IiIVYagTEakIQ52ISEUUDfXDhw9j9OjR8PPzg06nw/bt2+tc5uDBg+jXrx/0ej26dOmC+Ph4JUs0SU5OxvDhw5GcnNwk/TW3/rXKlq87x9w21D7miob67du30adPH6xevdqi9hcuXMCoUaMwbNgwpKWlYe7cuZg+fTq+/fZbJcsEAKxfvx4HDhzAhg0bFO+rOfavVbZ83TnmtqH2MdeJiCi29gc70umQkJCAsWPH1thmwYIF2LVrF06dOmWa9sILL+DWrVtITEy0qJ/CwkJ4eHigoKAA7u7utba9dOkSbty4AZ1Oh5EjR+LatWvw9vbGN998AxFBu3btEBgYaFG/9WHr/rXKlq87x9w2WvqYW5NrzSrUn3rqKfTr1w8rV640Tfviiy8wd+5cFBQUVLtMaWkpSktLTc8LCwvh7+9v2cbrdA88WwDg51XaDBo0uNZ1NERS0vd1tmmM/iMigN/+tsGrsamCAmD6dCA3t+HraqrX3VZ9c7yran5jvhBAkulZXTFsTahDmggASUhIqLXNo48+Ku+//77ZtF27dgkAuXPnTrXLLFq0SABUeRQUFNRZ08aNG8XBweFfy6wRQFT50OlE7t61eKiapb/8xfavY0t5cLxbwmOUABAHBwfZuHFjna9HQUGBWJprDpb+b9NcxcTEIDo62vS8ck/dEhMnTkT37t0RGhoK4I8AvjHNW7ZsOTp16tTI1VaVnZ2N119/rcr0xujfaATGjbv/K3TnDuDs3KDV2VRR0f1/+/UD3nyz4etT8nW3Vd8c79o1rzFPAQAcP34c/fr1a9S+mlWoGwwG5Ofnm03Lz8+Hu7s7XFxcql1Gr9dDr9c3uG87u7/BaEyFnZ0djEYjhg9/C438Wlfr5MlbABJM/TZ2/w4OQHk5cPduw9dlS5X1BwUB//mfDV+f0q+7rfrmeNes+Y25Mn01q+vUw8LCsG/fPrNpe/bsQVhYmGJ9ent7w2AwIDQ0FHFxcQgNDYXBYIC3t7difTZl/66u9/9Vy5u8cnsaypbjrmTfHO+aqXXMq2iM8181KSoqktTUVElNTRUA8sknn0hqaqpcunRJREQWLlwokyZNMrXPzs4WV1dXef311yUjI0NWr14t9vb2kpiYaHGf1px7qlRSUiJGo1FERIxGo5SUlFi8bGNQsn8fn/vn8P72t0ZbpU0sXnx/O2bObLx12nLcleqb4127ljrmzeacenJyMoYNG2Z6XnnuOzIyEvHx8cjNzUVOTo5pflBQEHbt2oV58+bh008/RYcOHbB27VpEREQoWabZ6RudTtcop3OaS/+VZ63UsudWw1m4erHluCvVN8e7dmoc84cpGupDhw6FiNQ4v7q/Fh06dChSU1MVrEpbeDiuLRxvalbn1KnxVe7p3Llj2zoaqrL+xt5zUxuONzHUVY6H49rC8SaGuspVHr6qZc+Nh+O143gTQ13luOemLRxvYqirHD840xaONzHUVY4fnGkLx5sY6irHw3Ft4XgTQ13leDiuLRxvYqirHA/HtYXjTQx1lePhuLZwvImhrnK8bllbON7EUFc5Ney
5GY1A5TcWcs+tdhxvYqirnBo+OCsp+ffP3HOrHcebGOoqp4YPzh6snXtuteN4E0Nd5dRwOF5Zu6MjYG9v21qaO443MdRVTg2H47xm2XIcb2Koq5yaDsd5KF43jjcx1FVOTYfjfJPXjeNNDHWV4+G4tnC8iaGucjwc1xaONzHUVa7yjVFRAdy7Z9ta6ouH45bjeBNDXeUePIRtqXtv/JNxy3G8iaGuck5OgE53/+eWep6Ve26W43gTQ13ldLqW/+EZPzizHMebGOoa0NI/POMHZ9bheGsbQ10DWvq1yzwctw7HW9sY6hrAw3Ft4XhrW5OE+urVq9GxY0c4Oztj4MCBOHHiRI1t4+PjodPpzB7Ozs5NUaZq8XBcWzje2qZ4qH/55ZeIjo7GokWLcPLkSfTp0wcRERG4du1ajcu4u7sjNzfX9Lh06ZLSZaoaD8e1heOtbYqH+ieffIIZM2Zg6tSpeOyxxxAXFwdXV1esW7euxmV0Oh0MBoPp4ePjo3SZqtbSv+KM1y1bh+OtbYqGellZGVJSUjBixIh/d2hnhxEjRuDo0aM1LldcXIzAwED4+/tjzJgx+Mc//qFkmarHPTdt4Xhrm6KhfuPGDVRUVFTZ0/bx8UFeXl61y3Tr1g3r1q3D119/jY0bN8JoNCI8PBz//Oc/q21fWlqKwsJCsweZ4wdn2sLx1rZmd/VLWFgYJk+ejJCQEAwZMgTbtm2Dl5cX/ud//qfa9rGxsfDw8DA9/P39m7ji5o8fnGkLx1vbFA31du3awd7eHvn5+WbT8/PzYTAYLFqHo6Mj+vbti6ysrGrnx8TEoKCgwPS4fPlyg+tWGx6OawvHW9sUDXUnJyeEhoZi3759pmlGoxH79u1DWFiYReuoqKhAeno6fH19q52v1+vh7u5u9iBzPBzXFo63tjko3UF0dDQiIyPRv39/DBgwACtXrsTt27cxdepUAMDkyZPRvn17xMbGAgDeffddPPHEE+jSpQtu3bqFZcuW4dKlS5g+fbrSpaoWD8e1heOtbYqH+vjx43H9+nW8/fbbyMvLQ0hICBITE00fnubk5MDO7t8HDD/99BNmzJiBvLw8tG7dGqGhoThy5Agee+wxpUtVLR6OawvHW9sUD3UAmD17NmbPnl3tvIMHD5o9X7FiBVasWNEEVWkHD8e1heOtbc3u6hdqfC35cFyEh+PW4nhrG0NdA1ry4XhZ2f03OsA3uaU43trGUNeAlvxn4w/WzMNxy3C8tY2hrgEtec+tsmY7O8DR0ba1tBQcb21jqGtAS/7g7MEPzSq/e5Nqx/HWNoa6BrTkD874oZn1ON7axlDXADUcjvNNbjmOt7Yx1DVALYfjZBmOt7Yx1DWAh+PawvHWNoa6BlS+QcrKgIoK29ZiLR6OW4/jrW0MdQ148FC2pR2S86vNrMfx1jaGugY4O//755b2Jueem/U43trGUNcAO7t/v9Fb6puce26W43hrG0NdI1rqh2f84Kx+ON7axVDXiJZ67TIPx+uH461dDHWNaKnXLvNwvH443trFUNcIHo5rC8dbuxjqGsHDcW3heGsXQ10jWuo9tnndcv1wvLWLoa4R3HPTFo63djHUNYJvcm3heGsXQ10jeDiuLRxv7WKoawT33LSF461dDHWN4HXL2sLx1i6GukbwumVt4XhrF0NdI3g4ri0cb+1iqGsED8e1heOtXQx1jeDhuLZwvLWrSUJ99erV6NixI5ydnTFw4ECcOHGi1vZbt25FcHAwnJ2d0atXL+zevbspylS12g7Hk5OTMXz4cCQnJzdtUXX0fe8eUF5+/2e+ya1T1+kXW455Tf1zvBuH4qH+5ZdfIjo6GosWLcLJkyfRp08fRERE4Nq1a9W2P3LkCCZMmIBp06YhNTUVY8eOxdixY3Hq1CmlS1W12q5bXr9+PQ4cOIANGzY0bVF19P1gIPFw3Dp1XaduyzGvqX+Od+PQiYgo2cHAgQPx+OOP4/e//z0AwGg0wt/fH6+++ioWLlx
Ypf348eNx+/Zt7Ny50zTtiSeeQEhICOLi4ursr7CwEB4eHigoKIC7u3vjbUgL9/XXwNixwBNPAEePApcuXcKNGzeg0+kwcuRIXLt2Dd7e3vjmm28gImjXrh0CAwMVqcXSvvPzAYPh/jJGI6DTKVKOKj083oBtx9yS/kW88fjj/gA43g+zJtcclCykrKwMKSkpiImJMU2zs7PDiBEjcLTyN+0hR48eRXR0tNm0iIgIbN++vdr2paWlKC0tNT0vLCxseOEqVLnnk54ODB4MJCVdfmDuXwEA164BoaGVu3aXMWiQMm9wS/uuHFYXF77BrfXweAO2HXPL+s8D4M/xbiBFQ/3GjRuoqKiAj4+P2XQfHx+cOXOm2mXy8vKqbZ+Xl1dt+9jYWCxevLhxClaxjh3v/3v7NpCUBACD61zmfjslWNd3UJBSdahX1fEGbDvmlvUPcLwbStFQbwoxMTFme/aFhYXw9/e3YUXN06OPAidPAhcu/HtadnY2Xn/9tSptly1bjk6dOilajzV9h4UpWooqVTfegG3H3NL+Od4No2iot2vXDvb29sjPzzebnp+fD0PlydKHGAwGq9rr9Xro9frGKVjl+va9/6h08uQtAAmws7OD0Wg0/Tt8+Fvo10/ZWmzZt1Y8PN6A7V93W/evBYpe/eLk5ITQ0FDs27fPNM1oNGLfvn0Iq+G/47CwMLP2ALBnz54a21P9eXt7w2AwIDQ0FHFxcQgNDYXBYIC3t7eq+9YyW7/utu5fE0RhW7ZsEb1eL/Hx8XL69GmZOXOmeHp6Sl5enoiITJo0SRYuXGhqn5SUJA4ODrJ8+XLJyMiQRYsWiaOjo6Snp1vUX0FBgQCQgoICRbZHbUpKSsRoNIqIiNFolJKSEk30rWW2ft1t3X9LZE2uKX5Offz48bh+/Trefvtt5OXlISQkBImJiaYPQ3NycmBn9+8DhvDwcGzevBlvvfUW3njjDTz66KPYvn07evbsqXSpmvTgqSudTtekp7Js2beW2fp1t3X/aqf4depNjdepE5HaWJNrvPcLEZGKMNSJiFSEoU5EpCIMdSIiFWGoExGpCEOdiEhFGOpERCrCUCciUhGGOhGRijDUiYhUhKFORKQiDHUiIhVhqBMRqQhDnYhIRRjqREQqwlAnIlIRhjoRkYow1ImIVIShTkSkIgx1IiIVYagTEakIQ52ISEUY6kREKsJQJyJSEYY6EZGKMNSJiFSEoU5EpCIMdSIiFVE01G/evImJEyfC3d0dnp6emDZtGoqLi2tdZujQodDpdGaPV155RckyiYhUw0HJlU+cOBG5ubnYs2cP7t27h6lTp2LmzJnYvHlzrcvNmDED7777rum5q6urkmUSEamGYqGekZGBxMRE/PDDD+jfvz8AYNWqVXjuueewfPly+Pn51bisq6srDAaDUqUREamWYqdfjh49Ck9PT1OgA8CIESNgZ2eH48eP17rspk2b0K5dO/Ts2RMxMTG4c+dOjW1LS0tRWFho9iAi0irF9tTz8vLg7e1t3pmDA9q0aYO8vLwal3vxxRcRGBgIPz8//P3vf8eCBQuQmZmJbdu2Vds+NjYWixcvbtTaiYhaKqtDfeHChfjwww9rbZORkVHvgmbOnGn6uVevXvD19cXTTz+N8+fPo3PnzlXax8TEIDo62vS8sLAQ/v7+9e6fiKglszrU58+fjylTptTaplOnTjAYDLh27ZrZ9PLycty8edOq8+UDBw4EAGRlZVUb6nq9Hnq93uL1ERGpmdWh7uXlBS8vrzrbhYWF4datW0hJSUFoaCgAYP/+/TAajaagtkRaWhoAwNfX19pSiYg0R7EPSrt3745nn30WM2bMwIkTJ5CUlITZs2fjhRdeMF35cuXKFQQHB+PEiRMAgPPnz2PJkiVISUnBxYsXsWPHDkyePBlPPfUUevfurVSpRESqoegfH23atAnBwcF4+umn8dxzz2Hw4MFYs2aNaf69e/eQmZlpurrFyckJe/fuxc9+9jM
EBwdj/vz5+OUvf4n//d//VbJMIiLV0ImI2LqIxlRYWAgPDw8UFBTA3d3d1uUQETWYNbnGe78QEakIQ52ISEUY6kREKsJQJyJSEYY6EZGKMNSJiFSEoU5EpCIMdSIiFWGoExGpCEOdiEhFGOpERCrCUCciUhGGOhGRijDUiYhUhKFORKQiDHUiIhVhqBMRqQhDnYhIRRjqREQqwlAnIlIRhjoRkYow1ImIVIShTkSkIgx1IiIVYagTEakIQ52ISEUY6kREKqJYqC9duhTh4eFwdXWFp6enRcuICN5++234+vrCxcUFI0aMwLlz55QqkYhIdRQL9bKyMowbNw6/+tWvLF7mo48+wu9+9zvExcXh+PHjaNWqFSIiIlBSUqJUmUREqqITEVGyg/j4eMydOxe3bt2qtZ2IwM/PD/Pnz8drr70GACgoKICPjw/i4+PxwgsvWNRfYWEhPDw8UFBQAHd394aWT0Rkc9bkmkMT1VSnCxcuIC8vDyNGjDBN8/DwwMCBA3H06NEaQ720tBSlpaWm5wUFBQDuvwhERGpQmWeW7IM3m1DPy8sDAPj4+JhN9/HxMc2rTmxsLBYvXlxlur+/f+MWSERkY0VFRfDw8Ki1jVWhvnDhQnz44Ye1tsnIyEBwcLA1q22QmJgYREdHm54bjUbcvHkTbdu2hU6ns3g9hYWF8Pf3x+XLl1V72kbt26j27QPUv43cvuqJCIqKiuDn51dnW6tCff78+ZgyZUqtbTp16mTNKk0MBgMAID8/H76+vqbp+fn5CAkJqXE5vV4PvV5vNs3Sq22q4+7urspfpgepfRvVvn2A+reR21dVXXvolawKdS8vL3h5eVlViKWCgoJgMBiwb98+U4gXFhbi+PHjVl1BQ0SkZYpd0piTk4O0tDTk5OSgoqICaWlpSEtLQ3FxsalNcHAwEhISAAA6nQ5z587Fe++9hx07diA9PR2TJ0+Gn58fxo4dq1SZRESqotgHpW+//Tb+9Kc/mZ737dsXAHDgwAEMHToUAJCZmWm6WgUAfvOb3+D27duYOXMmbt26hcGDByMxMRHOzs5KlWmi1+uxaNGiKqdy1ETt26j27QPUv43cvoZT/Dp1IiJqOrz3CxGRijDUiYhUhKFORKQiDHUiIhXRVKivXr0aHTt2hLOzMwYOHIgTJ07U2n7r1q0IDg6Gs7MzevXqhd27dzdRpfVnzTZ+9tlnePLJJ9G6dWu0bt0aI0aMqPM1sTVrx7DSli1boNPpmv3lsdZu361btxAVFQVfX1/o9Xp07dq12f+eWruNK1euRLdu3eDi4gJ/f3/Mmzev2d659fDhwxg9ejT8/Pyg0+mwffv2Opc5ePAg+vXrB71ejy5duiA+Pr5hRYhGbNmyRZycnGTdunXyj3/8Q2bMmCGenp6Sn59fbfukpCSxt7eXjz76SE6fPi1vvfWWODo6Snp6ehNXbjlrt/HFF1+U1atXS2pqqmRkZMiUKVPEw8ND/vnPfzZx5ZaxdvsqXbhwQdq3by9PPvmkjBkzpmmKrQdrt6+0tFT69+8vzz33nHz//fdy4cIFOXjwoKSlpTVx5Zazdhs3bdoker1eNm3aJBcuXJBvv/1WfH19Zd68eU1cuWV2794tb775pmzbtk0ASEJCQq3ts7OzxdXVVaKjo+X06dOyatUqsbe3l8TExHrXoJlQHzBggERFRZmeV1RUiJ+fn8TGxlbb/vnnn5dRo0aZTRs4cKC8/PLLitbZENZu48PKy8vFzc1N/vSnPylVYoPUZ/vKy8slPDxc1q5dK5GRkc061K3dvj/+8Y/SqVMnKSsra6oSG8zabYyKipLhw4ebTYuOjpZBgwYpWmdjsCTUf/Ob30iPHj3Mpo0fP14iIiLq3a8mTr+UlZUhJSXF7La+dnZ2GDFiBI4ePVrtMkePHjVrDwARERE1tre1+mzjw+7cuYN79+6hTZs2SpVZb/XdvnfffRfe3t6YNm1aU5RZb/XZvh07diAsLAxRUVHw8fF
Bz5498f7776OioqKpyrZKfbYxPDwcKSkpplM02dnZ2L17N5577rkmqVlpSuRMs7n1rpJu3LiBioqKam/re+bMmWqXycvLs/o2wLZUn2182IIFC+Dn51fll6w5qM/2ff/99/j888+RlpbWBBU2TH22Lzs7G/v378fEiROxe/duZGVlYdasWbh37x4WLVrUFGVbpT7b+OKLL+LGjRsYPHgwRATl5eV45ZVX8MYbbzRFyYqrKWcKCwtx9+5duLi4WL1OTeypU90++OADbNmyBQkJCU1yWwalFRUVYdKkSfjss8/Qrl07W5ejCKPRCG9vb6xZswahoaEYP3483nzzTcTFxdm6tEZz8OBBvP/++/jDH/6AkydPYtu2bdi1axeWLFli69KaLU3sqbdr1w729vbIz883m56fn2+65e/DDAaDVe1trT7bWGn58uX44IMPsHfvXvTu3VvJMuvN2u07f/48Ll68iNGjR5umGY1GAICDgwMyMzPRuXNnZYu2Qn3Gz9fXF46OjrC3tzdN6969O/Ly8lBWVgYnJydFa7ZWfbbxt7/9LSZNmoTp06cDAHr16mW6P9Sbb74JO7uWvV9aU864u7vXay8d0MieupOTE0JDQ7Fv3z7TNKPRiH379iEsLKzaZcLCwszaA8CePXtqbG9r9dlG4P6XfS9ZsgSJiYno379/U5RaL9ZuX3BwMNLT0013B01LS8N//Md/YNiwYUhLS2t234xVn/EbNGgQsrKyTP9ZAcDZs2fh6+vb7AIdqN823rlzp0pwV/4nJiq4bZUiOVPvj1hbmC1btoher5f4+Hg5ffq0zJw5Uzw9PSUvL09ERCZNmiQLFy40tU9KShIHBwdZvny5ZGRkyKJFi1rEJY3WbOMHH3wgTk5O8tVXX0lubq7pUVRUZKtNqJW12/ew5n71i7Xbl5OTI25ubjJ79mzJzMyUnTt3ire3t7z33nu22oQ6WbuNixYtEjc3N/nzn/8s2dnZ8t1330nnzp3l+eeft9Um1KqoqEhSU1MlNTVVAMgnn3wiqampcunSJRERWbhwoUyaNMnUvvKSxtdff10yMjJk9erVvKTRGqtWrZKAgABxcnKSAQMGyLFjx0zzhgwZIpGRkWbt//KXv0jXrl3FyclJevToIbt27Wriiq1nzTYGBgYKgCqPRYsWNX3hFrJ2DB/U3ENdxPrtO3LkiAwcOFD0er106tRJli5dKuXl5U1ctXWs2cZ79+7JO++8I507dxZnZ2fx9/eXWbNmyU8//dT0hVvgwIED1b6nKrcpMjJShgwZUmWZkJAQcXJykk6dOskXX3zRoBp4610iIhXRxDl1IiKtYKgTEakIQ52ISEUY6kREKsJQJyJSEYY6EZGKMNSJiFSEoU5EpCIMdSIiFWGoExGpCEOdiEhFGOpERCry//2b7HisjFbOAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Go into eval mode\n", "model.eval()\n", "likelihood.eval()\n", "\n", "with torch.no_grad():\n", " # Test points are regularly spaced by 0.01 on [0, 1], inclusive\n", " test_x = torch.linspace(0, 1, 101)\n", " # Get classification predictions\n", " observed_pred = likelihood(model(test_x))\n", "\n", " # Initialize fig and axes for plot\n", " f, ax = plt.subplots(1, 1, figsize=(4, 3))\n", " ax.plot(train_x.numpy(), train_y.numpy(), 'k*')\n", " # Get the predicted probabilities of belonging to the positive class\n", " # and threshold them at 0.5 to obtain 0/1 labels\n", " pred_labels = observed_pred.mean.ge(0.5).float()\n", " ax.plot(test_x.numpy(), pred_labels.numpy(), 'b')\n", " ax.set_ylim([-1, 2])\n", " ax.legend(['Observed Data', 'Mean'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Notes on other Non-Gaussian Likelihoods\n", "\n", "The Bernoulli likelihood is special in that we can compute the analytic (approximate) posterior predictive in closed form. That is: $q(\\mathbf y) = E_{q(\\mathbf f)}[ p(\\mathbf y \\mid \\mathbf f) ]$ is a Bernoulli distribution when $q(\\mathbf f)$ is a multivariate Gaussian.\n", "\n", "Most other non-Gaussian likelihoods do not admit an analytic (approximate) posterior predictive. For those likelihoods, calling `likelihood(model(test_x))` will generally return Monte Carlo samples from the posterior predictive." 
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Type of output: Bernoulli\n", "Shape of output: torch.Size([101])\n" ] } ], "source": [ "# Analytic marginal\n", "likelihood = qpytorch.likelihoods.BernoulliLikelihood()\n", "observed_pred = likelihood(model(test_x))\n", "print(\n", " f\"Type of output: {observed_pred.__class__.__name__}\\n\"\n", " f\"Shape of output: {observed_pred.batch_shape + observed_pred.event_shape}\"\n", ")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Type of output: Beta\n", "Shape of output: torch.Size([15, 101])\n" ] } ], "source": [ "# Monte Carlo marginal\n", "likelihood = qpytorch.likelihoods.BetaLikelihood()\n", "with qpytorch.settings.num_likelihood_samples(15):\n", " observed_pred = likelihood(model(test_x))\n", "print(\n", " f\"Type of output: {observed_pred.__class__.__name__}\\n\"\n", " f\"Shape of output: {observed_pred.batch_shape + observed_pred.event_shape}\"\n", ")\n", "# There are 15 MC samples for each test datapoint" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See [the Likelihood documentation](https://qepytorch.readthedocs.io/en/stable/likelihoods.html#likelihood) for more details." ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 4 }