{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using Pólya-Gamma Auxiliary Variables for Binary Classification\n", "\n", "## Overview\n", "\n", "In this notebook, we'll demonstrate how to use Pólya-Gamma auxiliary variables to perform efficient inference for Gaussian Process binary classification as in reference [1]. \n", "We will also use natural gradient descent, as described in more detail in the [Natural gradient descent](./Natural_Gradient_Descent.ipynb) tutorial.\n", "\n", "\n", "[1] Florian Wenzel, Theo Galy-Fajou, Christian Donner, Marius Kloft, Manfred Opper. [Efficient Gaussian process classification using Pòlya-Gamma data augmentation](https://arxiv.org/abs/1802.06383). Proceedings of the AAAI Conference on Artificial Intelligence. 2019.\n", "\n", "## Pólya-Gamma Augmentation\n", "\n", "When a Q-Exponential Process prior is paired with a q-exponential likelihood, inference can be done exactly with a simple closed-form expression.\n", "Unfortunately, this attractive feature does not carry over to non-conjugate likelihoods like the Bernoulli likelihood that arises in binary classification with a logistic link function.\n", "Sampling-based stochastic variational inference offers a general strategy for dealing with non-conjugate likelihoods; see the [corresponding tutorial](./Non_Gaussian_Likelihoods.ipynb).\n", "\n", "Another possible strategy is to introduce additional latent variables that restore conjugacy. \n", "This is the strategy we follow here. \n", "In particular, we introduce a Pólya-Gamma auxiliary variable for each data point in our training dataset. \n", "The [Pólya-Gamma](https://arxiv.org/abs/1205.0310) distribution $\\rm{PG}$ is a univariate distribution with support on the positive real line.
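For intuition, here is a minimal plain-Python sketch of how a PG(b, c) variable can be drawn approximately. It is not part of qpytorch, and the helper names `sample_pg` and `pg_mean` are hypothetical. The sketch assumes the infinite-sum representation from the Pólya-Gamma paper linked above. Under that representation, omega is equal in distribution to `(1 / (2 pi^2)) * sum_k g_k / ((k - 1/2)^2 + c^2 / (4 pi^2))` with independent `g_k ~ Gamma(b, 1)`. The sum is truncated after `num_terms` terms. The closed-form mean `E[omega] = b / (2c) * tanh(c / 2)` is used to check the draws; it equals `b / 4` at `c = 0`.

```python
import math
import random


def sample_pg(b, c, num_terms=200, rng=random):
    # Approximate draw from PG(b, c) by truncating the infinite series
    # omega = 1 / (2 pi^2) * sum_k g_k / ((k - 1/2)^2 + c^2 / (4 pi^2)),
    # with independent g_k ~ Gamma(b, 1).
    total = 0.0
    for k in range(1, num_terms + 1):
        g_k = rng.gammavariate(b, 1.0)
        total += g_k / ((k - 0.5) ** 2 + c ** 2 / (4 * math.pi ** 2))
    return total / (2 * math.pi ** 2)


def pg_mean(b, c):
    # Closed-form mean E[omega] = b / (2c) * tanh(c / 2), with limit b / 4 at c = 0.
    if c == 0:
        return b / 4.0
    return b / (2.0 * c) * math.tanh(c / 2.0)


rng = random.Random(0)
draws = [sample_pg(1.0, 0.0, rng=rng) for _ in range(4000)]
print(min(draws) > 0)  # PG draws live on the positive real line
print(sum(draws) / len(draws), pg_mean(1.0, 0.0))  # Monte Carlo estimate vs. exact mean 0.25
```

Truncating the series slightly underestimates each draw. With a few hundred terms, however, the resulting bias in the mean is far below the Monte Carlo error of this check.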
\n", "In our context it is interesting because if $\\omega_i$ is distributed according to $\\rm{PG}(1,0)$ then the logistic likelihood $\\sigma(\\cdot)$ for data point $(x_i, y_i)$ can be represented as\n", "\n", "\\begin{align}\n", "\\sigma(y_i f_i) = \\frac{1}{1 + \\exp(-y_i f_i)} = \\tfrac{1}{2} \\mathbb{E}_{\\omega_i \\sim \\rm{PG}(1,0)} \\left[ \\exp \\left(\\tfrac{1}{2} y_i f_i - \\tfrac{\\omega_i}{2} f_i^2 \\right) \\right]\n", "\\end{align}\n", "\n", "where $y_i \\in \\{-1, 1\\}$ is the binary label of data point $i$\n", "and $f_i$ is the Q-Exponential Process prior evaluated at input $x_i$. \n", "The crucial point here is that $f_i$ appears quadratically in the exponential within the expectation. \n", "In other words, conditioned on $\\omega_i$, we can integrate out $f_i$ exactly, just as if we were doing regression with a Gaussian likelihood. For more details please see the original reference. \n", "\n", "## Setup" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import tqdm\n", "import math\n", "import torch\n", "import qpytorch\n", "from matplotlib import pyplot as plt\n", "\n", "# Make plots inline\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this example notebook, we'll create a simple artificial dataset." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import os\n", "from math import floor\n", "\n", "# this is for running the notebook in our testing framework\n", "smoke_test = ('CI' in os.environ)\n", "\n", "N = 100\n", "X = torch.linspace(-1., 1., N)\n", "# P(y = 1 | x) = (sin(pi * x) + 1) / 2\n", "probs = (torch.sin(X * math.pi).add(1.).div(2.))\n", "y = torch.distributions.Bernoulli(probs=probs).sample()\n", "X = X.unsqueeze(-1)\n", "\n", "train_n = int(floor(0.8 * N))\n", "indices = torch.randperm(N)\n", "train_x = X[indices[:train_n]].contiguous()\n", "train_y = y[indices[:train_n]].contiguous()\n", "\n", "test_x = X[indices[train_n:]].contiguous()\n", "test_y = y[indices[train_n:]].contiguous()\n", "\n", "if torch.cuda.is_available():\n", "    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's plot our artificial dataset. \n", "Note that here the binary labels are 0/1-valued; we will need to be careful to translate between this representation and the -1/1 representation that is most natural in the context of Pólya-Gamma augmentation."
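, "\n", "
As a sanity check on both the identity above and the label translation, here is a minimal plain-Python sketch. It is illustrative only: `sigmoid`, `pg_laplace` and `augmented_likelihood` are hypothetical helpers, not qpytorch functions. The sketch maps a 0/1 label to `y = 2 * label - 1` and compares `sigmoid(y * f)` with `0.5 * exp(y * f / 2) * E[exp(-omega * f^2 / 2)]`. It relies on the standard fact that a PG(1, 0) variable is a weighted sum of independent Exp(1) variables with weights `1 / (2 pi^2 (k - 1/2)^2)`. The expectation is therefore an infinite product, equal to `1 / cosh(f / 2)`, which is truncated here.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def pg_laplace(t, num_terms=5000):
    # E[exp(-omega * t)] for omega ~ PG(1, 0), using the series
    # omega = sum_k g_k / a_k with g_k ~ Exp(1) and a_k = 2 pi^2 (k - 1/2)^2,
    # so each factor is E[exp(-t * g_k / a_k)] = 1 / (1 + t / a_k).
    result = 1.0
    for k in range(1, num_terms + 1):
        a_k = 2.0 * math.pi ** 2 * (k - 0.5) ** 2
        result /= 1.0 + t / a_k
    return result


def augmented_likelihood(label01, f):
    # Translate a 0/1 label to the -1/1 convention used by the identity.
    y = 2 * label01 - 1
    return 0.5 * math.exp(0.5 * y * f) * pg_laplace(0.5 * f ** 2)


for label01 in (0, 1):
    for f in (-2.0, -0.3, 0.0, 1.5):
        y = 2 * label01 - 1
        print(label01, f, round(sigmoid(y * f), 6), round(augmented_likelihood(label01, f), 6))
```

The same translation, `target.mul(2.).sub(1.)`, is applied to the targets inside the likelihood defined later in this notebook.
"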
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAKkVJREFUeJzt3X10VPWB//FPEsgECjOBE5MApiBqUeQhCiYNVak1a6gW7XH3LGIrD2txZam1pLWQCgSkNSCU5aygdDk8eH5bC+LR6i4YW7Pm19WmsgZolaeKgqRKgkidCUGDJN/fH/4yMmSS3HszM98kvF/nzDlw833+3jvzOZOZmyRjjBEAAIAlybYHAAAALmyEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABW9bI9ACeam5v1wQcfqH///kpKSrI9HAAA4IAxRvX19Ro8eLCSk9t+/6NbhJEPPvhAOTk5tocBAAA8qKmp0cUXX9zmz7tFGOnfv7+kzyfj9/stjwYAADgRCoWUk5MTfh1vS7cIIy2/mvH7/YQRAAC6mY4+YsEHWAEAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWdYubnnU3Tc1GOw+f1PH6T5XZP015lwxUSnLHf1PHa714tZPoti9kiV5XL/1FqyMpbuesk7bPrzdu6ABVv/e3dtuJViZW12c8r2En84i2Zl7qOVmjWO2Z17adzt/m+ej1nE3k83VXeU53HUZ+//vfa8WKFaqurtaxY8f03HPP6dvf/na7dSorK1VcXKy9e/cqJydHCxYs0IwZMzwOuWsrf+uYlvznPh0Lfho+NiiQptLJIzVp1KCY14tXO4lu+0KW6HX10l+0Oul9e0uSPj79WafG7bXtaPWSk6Rmo3bbOb9MrK7PeF7DTuYRrYzXeh2tUSz3zGvbTuZh+3z0cs4m8vm6Kz2nJxljTMfFvvDiiy/qtdde07hx43THHXd0GEYOHz6sUaNG6b777tP3vvc9VVRU6Ic//KG2b9+uoqIiR32GQiEFAgEFg8EufTv48reOafZ/7NL5C9qSMZ/47jVRN9hrvXi1k+i2L2SJXlcv/bVVJ5pYnbMdtS3JcT037Xq9PtsaTzzXI5G8rr3bNXLTthO2z0cvc0vU87WTMcfiucfp67frMBJROSmpwzAyb948bd++XW+99Vb42J133qmPP/5Y5eXljvrpDmGkqdnouuX/HZEwz5UkKTuQplfnfSPiLTCv9WLVvxPxbPtCluh19dJfR3U6M26vbWf5fZKSVBtyXs9Ju16vz47GE8/1SKSWeRhjVBtqdFXP6Rq5bdtp/zbPRy/nbCKer9tb61g+9zh9/Y77B1irqqpUWFgYcayoqEhVVVVt1mlsbFQoFIp4dHU7D59s9yQ2ko4FP9XOwydjUi9W/TsRz7YvZIleVy/9dVTHaTtextNW27WhxpgGkZZ2vV6fHY0nnuuRSC3zcBsW3KxRrIPIuW3bOh+9nLOJeL5ub61tPKfHPYzU1tYqKysr4lhWVpZCoZA++eSTqHXKysoUCATCj5ycnHgPs9OO1zs70c4v57VerPqPZR0vbV/IEr2uXvrrTN+xOmcTyev16aVttz9H53TH9Y/n83Wi2+pIl/xqb0lJiYLBYPhRU1Nje0gdyuyf5qmc13qx6j+Wdby0fSFL9Lp66a8zfcfqnE0kr9enl7bd/hyd0x3XP57P14luqyNxDyPZ2dmqq6uLOFZXVye/368+ffpErePz
+eT3+yMeXV3eJQM1KJCmtn67lqTPP6Xc8hWxztaLVf9OxLPtC1mi19VLfx3VcdqOl/G01Xa236dsv7t6Ttr1en12NJ54rkcitcwj2++L+Z55bdtp/zbPRy/nbCKer9tbaxvP6XEPIwUFBaqoqIg49rvf/U4FBQXx7jqhUpKTVDp5pCS12uCW/5dOHtnqw0Be68Wqfyfi2faFLNHr6qW/9upEE6tztr22F992lRbf5rye03a9Xp/tjSee65FI585j8W1XRRxzUs/pGrlp2wnb56OXczZRz9dtrbWt53TXYeTUqVPas2eP9uzZI+nzr+7u2bNHR48elfT5r1imTZsWLn/ffffp3Xff1U9+8hMdOHBAjz/+uJ5++mnNnTs3NjPoQiaNGqQnvnuNsgORb21lB9La/ZqU13rxaifRbV/IEr2uXvprq056397h+y94HXdbbQ/ooO226p3/3BmtnfPLxOL6jPc1HG2tnczVyfydtO1krh2dD51ZRydj7Irno5dzNlHP113tOd31V3srKyt14403tjo+ffp0bd68WTNmzNCRI0dUWVkZUWfu3Lnat2+fLr74Yi1cuNDVTc+6w1d7z8UdWOEWd2DlDqxu14M7sHIHVi/zSPQdWBNyn5FE6W5hBAAAdKH7jAAAALSHMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwylMYWbt2rYYNG6a0tDTl5+dr586d7ZZfvXq1RowYoT59+ignJ0dz587Vp59+6mnAAACgZ3EdRrZu3ari4mKVlpZq165dGjt2rIqKinT8+PGo5Z966inNnz9fpaWl2r9/vzZs2KCtW7fqpz/9aacHDwAAuj/XYWTVqlWaNWuWZs6cqZEjR2rdunXq27evNm7cGLX8H/7wB33ta1/TXXfdpWHDhunmm2/W1KlTO3w3BQAAXBhchZEzZ86ourpahYWFXzSQnKzCwkJVVVVFrTNhwgRVV1eHw8e7776rHTt26JZbbmmzn8bGRoVCoYgHAADomXq5KXzixAk1NTUpKysr4nhWVpYOHDgQtc5dd92lEydO6LrrrpMxRmfPntV9993X7q9pysrKtGTJEjdDAwAA3VTcv01TWVmpRx55RI8//rh27dqlZ599Vtu3b9fSpUvbrFNSUqJgMBh+1NTUxHuYAADAElfvjGRkZCglJUV1dXURx+vq6pSdnR21zsKFC3X33Xfre9/7niRp9OjRamho0L333quHHnpIycmt85DP55PP53MzNAAA0E25emckNTVV48aNU0VFRfhYc3OzKioqVFBQELXO6dOnWwWOlJQUSZIxxu14AQBAD+PqnRFJKi4u1vTp0zV+/Hjl5eVp9erVamho0MyZMyVJ06ZN05AhQ1RWViZJmjx5slatWqWrr75a+fn5OnTokBYuXKjJkyeHQwkAALhwuQ4jU6ZM0YcffqhFixaptrZWubm5Ki8vD3+o9ejRoxHvhCxYsEBJSUlasGCB3n//fV100UWaPHmyfv7zn8duFgAAoNtKMt3gdyWhUEiBQEDBYFB+v9/2cAAAgANOX7/52zQAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKM
AAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqzyFkbVr12rYsGFKS0tTfn6+du7c2W75jz/+WHPmzNGgQYPk8/n0la98RTt27PA0YAAA0LP0clth69atKi4u1rp165Sfn6/Vq1erqKhIBw8eVGZmZqvyZ86c0d/93d8pMzNTzzzzjIYMGaL33ntP6enpsRg/AADo5pKMMcZNhfz8fF177bVas2aNJKm5uVk5OTm6//77NX/+/Fbl161bpxUrVujAgQPq3bu3p0GGQiEFAgEFg0H5/X5PbQAAgMRy+vrt6tc0Z86cUXV1tQoLC79oIDlZhYWFqqqqilrnhRdeUEFBgebMmaOsrCyNGjVKjzzyiJqamtrsp7GxUaFQKOIBAAB6Jldh5MSJE2pqalJWVlbE8aysLNXW1kat8+677+qZZ55RU1OTduzYoYULF+oXv/iFfvazn7XZT1lZmQKBQPiRk5PjZpgAAKAbifu3aZqbm5WZmal///d/17hx4zRlyhQ99NBDWrduXZt1SkpKFAwGw4+ampp4DxMAAFji6gOsGRkZSklJUV1dXcTxuro6ZWdnR60zaNAg9e7dWykpKeFjV155pWpra3XmzBmlpqa2quPz+eTz+dwMDQAAdFOu3hlJTU3VuHHjVFFRET7W3NysiooKFRQURK3zta99TYcOHVJzc3P42F/+8hcNGjQoahABAAAXFte/pikuLtb69ev15JNPav/+/Zo9e7YaGho0c+ZMSdK0adNUUlISLj979mydPHlSDzzwgP7yl79o+/bteuSRRzRnzpzYzQIAAHRbru8zMmXKFH344YdatGiRamtrlZubq/Ly8vCHWo8ePark5C8yTk5Ojl566SXNnTtXY8aM0ZAhQ/TAAw9o3rx5sZsFAADotlzfZ8QG7jMCAED3E5f7jAAAAMQaYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVZ7CyNq1azVs2DClpaUpPz9fO3fudFRvy5YtSkpK0re//W0v3QIAgB7IdRjZunWriouLVVpaql27dmns2LEqKirS8ePH26135MgR/fjHP9b111/vebAAAKDncR1GVq1apVmzZmnmzJkaOXKk1q1bp759+2rjxo1t1mlqatJ3vvMdLVmyRMOHD+/UgAEAQM/iKoycOXNG1dXVKiws/KKB5GQVFhaqqqqqzXoPP/ywMjMzdc899zjqp7GxUaFQKOIBAAB6Jldh5MSJE2pqalJWVlbE8aysLNXW1kat8+qrr2rDhg1av369437KysoU
CATCj5ycHDfDBAAA3Uhcv01TX1+vu+++W+vXr1dGRobjeiUlJQoGg+FHTU1NHEcJAABs6uWmcEZGhlJSUlRXVxdxvK6uTtnZ2a3Kv/POOzpy5IgmT54cPtbc3Px5x7166eDBg7r00ktb1fP5fPL5fG6GBgAAuilX74ykpqZq3LhxqqioCB9rbm5WRUWFCgoKWpW/4oor9Oabb2rPnj3hx2233aYbb7xRe/bs4dcvAADA3TsjklRcXKzp06dr/PjxysvL0+rVq9XQ0KCZM2dKkqZNm6YhQ4aorKxMaWlpGjVqVET99PR0SWp1HAAAXJhch5EpU6boww8/1KJFi1RbW6vc3FyVl5eHP9R69OhRJSdzY1cAAOBMkjHG2B5ER0KhkAKBgILBoPx+v+3hAAAAB5y+fvMWBgAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqT2Fk7dq1GjZsmNLS0pSfn6+dO3e2WXb9+vW6/vrrNWDAAA0YMECFhYXtlgcAABcW12Fk69atKi4uVmlpqXbt2qWxY8eqqKhIx48fj1q+srJSU6dO1SuvvKKqqirl5OTo5ptv1vvvv9/pwQMAgO4vyRhj3FTIz8/XtddeqzVr1kiSmpublZOTo/vvv1/z58/vsH5TU5MGDBigNWvWaNq0aY76DIVCCgQCCgaD8vv9boYLAAAscfr67eqdkTNnzqi6ulqFhYVfNJCcrMLCQlVVVTlq4/Tp0/rss880cODANss0NjYqFApFPAAAQM/kKoycOHFCTU1NysrKijielZWl2tpaR23MmzdPgwcPjgg05ysrK1MgEAg/cnJy3AwTAAB0Iwn9Ns2yZcu0ZcsWPffcc0pLS2uzXElJiYLBYPhRU1OTwFECAIBE6uWmcEZGhlJSUlRXVxdxvK6uTtnZ2e3WXblypZYtW6aXX35ZY8aMabesz+eTz+dzMzQAANBNuXpnJDU1VePGjVNFRUX4WHNzsyoqKlRQUNBmvUcffVRLly5VeXm5xo8f7320AACgx3H1zogkFRcXa/r06Ro/frzy8vK0evVqNTQ0aObMmZKkadOmaciQISorK5MkLV++XIsWLdJTTz2lYcOGhT9b0q9fP/Xr1y+GUwEAAN2R6zAyZcoUffjhh1q0aJFqa2uVm5ur8vLy8Idajx49quTkL95weeKJJ3TmzBn9wz/8Q0Q7paWlWrx4cedGDwAAuj3X9xmxgfuMAADQ/cTlPiMAAACxRhgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgB
AABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFW9bA/AlqZmo52HT+p4/afK7J+mvEsGSlKrYynJSe3WGzd0gKrf+1u7dbzWizbGWJWJNtfzjzmdWyzWOlpfTsYYqzE56d/rXnvdt3iuo9dzxMsYYzH3eIvVdd1d5nsh87rXHbUTy32N5fN6dzkfPYWRtWvXasWKFaqtrdXYsWP12GOPKS8vr83y27Zt08KFC3XkyBFdfvnlWr58uW655RbPg+6s8reOacl/7tOx4KfhY+l9e0uSPj79WfjYoECaSieP1KRRg9qsl5wkNRu1WcdrvWh1YlUm2lyjHXMyt4447f/8vpyOMVZjctK/1732sm9Oxux1Hb2eI17GGIu5x1ssr+vuMN8Lmde9dtJOrPY1ls/r3el8TDLGmI6LfWHr1q2aNm2a1q1bp/z8fK1evVrbtm3TwYMHlZmZ2ar8H/7wB91www0qKyvTt771LT311FNavny5du3apVGjRjnqMxQKKRAIKBgMyu/3uxluK+VvHdPs/9glJ5NuyYpPfPcaSXJU79w6k0YNctyfk75iVcar8+fWETdr3R3H5HSv3e7b+fOI9Tp6PUe8jLGzc483L9dnPPca8eN1r522E4t97cx5FU1XOB+dvn67DiP5+fm69tprtWbNGklSc3OzcnJydP/992v+/Pmtyk+ZMkUNDQ36r//6r/Cxr371q8rNzdW6detiOpmONDUbXbf8vyNSYUeSJGX5fZKSVBtyVi9JUnYgTf/3wRs1ccUrjvtz0ldL28YY1YYaPZfxqqXtV+d9o9239rysdXcck9O9drO3584jXuvo9RzxMkavc483t2vrZq+dXqOJnO+FzOten78/Ts51r/sai+uovTHZOh+dvn67+gDrmTNnVF1drcLCwi8aSE5WYWGhqqqqotapqqqKKC9JRUVFbZaXpMbGRoVCoYhHLOw8fNL1k7qRVBtqdLX5RtKx4Kf6P1VHXPXnpK+Wttt7AXFSxquWtncePtluOS9r3R3H5HSv3eztufOI1zp6PUe8jNHr3OPN7dq62Wun12gi53sh87rX5++Pk3Pd677G4jpqb0xd/Xx0FUZOnDihpqYmZWVlRRzPyspSbW1t1Dq1tbWuyktSWVmZAoFA+JGTk+NmmG06Xp+YF8cW7508ndD+EqmjtUz0WjvpM55jiuVenztOG+voRLzGmMj5eu0rXnuN+PG6zufXc9qOl/66wrlgcwxd8qu9JSUlCgaD4UdNTU1M2s3snxaTdpwaOrBvQvtLpI7WMtFr7aTPeI4plnt97jhtrKMT8RpjIufrta947TXix+s6n1/PaTte+usK54LNMbgKIxkZGUpJSVFdXV3E8bq6OmVnZ0etk52d7aq8JPl8Pvn9/ohHLORdMlCDAmly8xuxJEnZfp+y/c7rJenzTyjfXTDMVX9O+mppO9vv61QZr1rabvkaWVu8rHV3HJPTvXazt+fOI17r6PUc8TJGr3OPN7dr62avnV6jiZzvhczrXp+/P07Oda/7GovrqL0xdfXz0VUYSU1N1bhx41RRURE+1tzcrIqKChUUFEStU1BQEFFekn73u9+1WT6eUpKTVDp5pCQ52syWMotvu0qLb3NWr+XnpZNHKrVXsuP+nPR1btuLb7vKcxmvzm27ow85uV3r7jgmp3vtdm/PnUc81tHrOeJljJ2Ze7y5WVu3e+30GuXDq4nhda/P3x8n57rXfe3sdRRNdzofXf+apri4WOvXr9eTTz6p/fv3a/bs2WpoaNDMmTMlSdOmTVNJSUm4/AMPPKDy8nL94he/0IEDB7R48WK98cYb+v73vx+7WbgwadQgPfHda5QdiHw7akDf3uHvabfIDqSFv+7UVr3z9+7cOu311169turEqkx6lLlGO9bR3Dripv/z+4q2H9HqxWpMTvr3utdu983pmJ2MsaM183o9OB1jZ+ceb7G+rrv6fC9kXvfaaTux2NfOnFdeX8O6yvno+qu9
krRmzZrwTc9yc3P1b//2b8rPz5ckff3rX9ewYcO0efPmcPlt27ZpwYIF4ZuePfroo65uehbL+4y04A6s3IGVO7C6a5s7sCZurxE/3IE1sedj3O4zYkM8wggAAIivuNxnBAAAINYIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACretkegBMtN4kNhUKWRwIAAJxqed3u6Gbv3SKM1NfXS5JycnIsjwQAALhVX1+vQCDQ5s+7xd+maW5u1gcffKD+/fsrKSl2f9AnFAopJydHNTU1PfZv3vT0OTK/7q+nz5H5dX89fY7xnJ8xRvX19Ro8eLCSk9v+ZEi3eGckOTlZF198cdza9/v9PfIEO1dPnyPz6/56+hyZX/fX0+cYr/m1945ICz7ACgAArCKMAAAAqy7oMOLz+VRaWiqfz2d7KHHT0+fI/Lq/nj5H5tf99fQ5doX5dYsPsAIAgJ7rgn5nBAAA2EcYAQAAVhFGAACAVYQRAABgVY8PIz//+c81YcIE9e3bV+np6Y7qGGO0aNEiDRo0SH369FFhYaHefvvtiDInT57Ud77zHfn9fqWnp+uee+7RqVOn4jCD9rkdx5EjR5SUlBT1sW3btnC5aD/fsmVLIqYUwcs6f/3rX2819vvuuy+izNGjR3Xrrbeqb9++yszM1IMPPqizZ8/GcyptcjvHkydP6v7779eIESPUp08fffnLX9YPfvADBYPBiHK29nDt2rUaNmyY0tLSlJ+fr507d7Zbftu2bbriiiuUlpam0aNHa8eOHRE/d3I9JpqbOa5fv17XX3+9BgwYoAEDBqiwsLBV+RkzZrTaq0mTJsV7Gm1yM7/Nmze3GntaWlpEma62h27mF+35JCkpSbfeemu4TFfav9///veaPHmyBg8erKSkJP3mN7/psE5lZaWuueYa+Xw+XXbZZdq8eXOrMm6va9dMD7do0SKzatUqU1xcbAKBgKM6y5YtM4FAwPzmN78xf/rTn8xtt91mLrnkEvPJJ5+Ey0yaNMmMHTvW/PGPfzT/8z//Yy677DIzderUOM2ibW7HcfbsWXPs2LGIx5IlS0y/fv1MfX19uJwks2nTpohy584/Ubys88SJE82sWbMixh4MBsM/P3v2rBk1apQpLCw0u3fvNjt27DAZGRmmpKQk3tOJyu0c33zzTXPHHXeYF154wRw6dMhUVFSYyy+/3Pz93/99RDkbe7hlyxaTmppqNm7caPbu3WtmzZpl0tPTTV1dXdTyr732mklJSTGPPvqo2bdvn1mwYIHp3bu3efPNN8NlnFyPieR2jnfddZdZu3at2b17t9m/f7+ZMWOGCQQC5q9//Wu4zPTp082kSZMi9urkyZOJmlIEt/PbtGmT8fv9EWOvra2NKNOV9tDt/D766KOIub311lsmJSXFbNq0KVymK+3fjh07zEMPPWSeffZZI8k899xz7ZZ/9913Td++fU1xcbHZt2+feeyxx0xKSoopLy8Pl3G7Zl70+DDSYtOmTY7CSHNzs8nOzjYrVqwIH/v444+Nz+czv/71r40xxuzbt89IMv/7v/8bLvPiiy+apKQk8/7778d87G2J1Thyc3PNP/3TP0Ucc3ISx5vX+U2cONE88MADbf58x44dJjk5OeIJ84knnjB+v980NjbGZOxOxWoPn376aZOammo+++yz8DEbe5iXl2fmzJkT/n9TU5MZPHiwKSsri1r+H//xH82tt94acSw/P9/88z//szHG2fWYaG7neL6zZ8+a/v37myeffDJ8bPr06eb222+P9VA9cTu/jp5bu9oednb//vVf/9X079/fnDp1KnysK+3fuZw8B/zkJz8xV111VcSxKVOmmKKiovD/O7tmTvT4X9O4dfjwYdXW1qqwsDB8LBAIKD8/X1VVVZKkqqoqpaena/z48eEyhYWFSk5O1uuvv56wscZiHNXV1dqzZ4/uueeeVj+bM2eOMjIylJeXp40bN3b4J6Bj
rTPz+9WvfqWMjAyNGjVKJSUlOn36dES7o0ePVlZWVvhYUVGRQqGQ9u7dG/uJtCNW51IwGJTf71evXpF/biqRe3jmzBlVV1dHXDvJyckqLCwMXzvnq6qqiigvfb4XLeWdXI+J5GWO5zt9+rQ+++wzDRw4MOJ4ZWWlMjMzNWLECM2ePVsfffRRTMfuhNf5nTp1SkOHDlVOTo5uv/32iOuoK+1hLPZvw4YNuvPOO/WlL30p4nhX2D8vOroGY7FmTnSLP5SXSLW1tZIU8ULV8v+Wn9XW1iozMzPi57169dLAgQPDZRIhFuPYsGGDrrzySk2YMCHi+MMPP6xvfOMb6tu3r37729/qX/7lX3Tq1Cn94Ac/iNn4O+J1fnfddZeGDh2qwYMH689//rPmzZungwcP6tlnnw23G21/W36WSLHYwxMnTmjp0qW69957I44neg9PnDihpqamqGt74MCBqHXa2otzr7WWY22VSSQvczzfvHnzNHjw4Ign90mTJumOO+7QJZdconfeeUc//elP9c1vflNVVVVKSUmJ6Rza42V+I0aM0MaNGzVmzBgFg0GtXLlSEyZM0N69e3XxxRd3qT3s7P7t3LlTb731ljZs2BBxvKvsnxdtXYOhUEiffPKJ/va3v3X6nHeiW4aR+fPna/ny5e2W2b9/v6644ooEjSi2nM6vsz755BM99dRTWrhwYaufnXvs6quvVkNDg1asWBGTF7J4z+/cF+XRo0dr0KBBuummm/TOO+/o0ksv9dyuG4naw1AopFtvvVUjR47U4sWLI34Wzz2EN8uWLdOWLVtUWVkZ8SHPO++8M/zv0aNHa8yYMbr00ktVWVmpm266ycZQHSsoKFBBQUH4/xMmTNCVV16pX/7yl1q6dKnFkcXehg0bNHr0aOXl5UUc787711V0yzDyox/9SDNmzGi3zPDhwz21nZ2dLUmqq6vToEGDwsfr6uqUm5sbLnP8+PGIemfPntXJkyfD9TvD6fw6O45nnnlGp0+f1rRp0zosm5+fr6VLl6qxsbHTf78gUfNrkZ+fL0k6dOiQLr30UmVnZ7f6JHhdXZ0kxWT/pMTMsb6+XpMmTVL//v313HPPqXfv3u2Wj+UeRpORkaGUlJTwWraoq6trcy7Z2dntlndyPSaSlzm2WLlypZYtW6aXX35ZY8aMabfs8OHDlZGRoUOHDiX0xawz82vRu3dvXX311Tp06JCkrrWHnZlfQ0ODtmzZoocffrjDfmztnxdtXYN+v199+vRRSkpKp88JR2L26ZMuzu0HWFeuXBk+FgwGo36A9Y033giXeemll6x9gNXrOCZOnNjqGxht+dnPfmYGDBjgeaxexGqdX331VSPJ/OlPfzLGfPEB1nM/Cf7LX/7S+P1+8+mnn8ZuAg54nWMwGDRf/epXzcSJE01DQ4OjvhKxh3l5eeb73/9++P9NTU1myJAh7X6A9Vvf+lbEsYKCglYfYG3vekw0t3M0xpjly5cbv99vqqqqHPVRU1NjkpKSzPPPP9/p8brlZX7nOnv2rBkxYoSZO3euMabr7aHX+W3atMn4fD5z4sSJDvuwuX/nksMPsI4aNSri2NSpU1t9gLUz54SjscaspS7qvffeM7t37w5/fXX37t1m9+7dEV9jHTFihHn22WfD/1+2bJlJT083zz//vPnzn/9sbr/99qhf7b366qvN66+/bl599VVz+eWXW/tqb3vj+Otf/2pGjBhhXn/99Yh6b7/9tklKSjIvvvhiqzZfeOEFs379evPmm2+at99+2zz++OOmb9++ZtGiRXGfz/nczu/QoUPm4YcfNm+88YY5fPiwef75583w4cPNDTfcEK7T8tXem2++2ezZs8eUl5ebiy66yOpXe93MMRgMmvz8fDN69Ghz6NChiK8Tnj171hhjbw+3bNlifD6f2bx5s9m3b5+59957TXp6evibS3fffbeZP39+uPxrr71mevXqZVauXGn2799vSktLo361t6PrMZHcznHZsmUm
NTXVPPPMMxF71fIcVF9fb3784x+bqqoqc/jwYfPyyy+ba665xlx++eUJD8de5rdkyRLz0ksvmXfeecdUV1ebO++806SlpZm9e/eGy3SlPXQ7vxbXXXedmTJlSqvjXW3/6uvrw69zksyqVavM7t27zXvvvWeMMWb+/Pnm7rvvDpdv+Wrvgw8+aPbv32/Wrl0b9au97a1ZLPT4MDJ9+nQjqdXjlVdeCZfR/78fQ4vm5mazcOFCk5WVZXw+n7npppvMwYMHI9r96KOPzNSpU02/fv2M3+83M2fOjAg4idLROA4fPtxqvsYYU1JSYnJyckxTU1OrNl988UWTm5tr+vXrZ770pS+ZsWPHmnXr1kUtG29u53f06FFzww03mIEDBxqfz2cuu+wy8+CDD0bcZ8QYY44cOWK++c1vmj59+piMjAzzox/9KOJrsYnkdo6vvPJK1HNakjl8+LAxxu4ePvbYY+bLX/6ySU1NNXl5eeaPf/xj+GcTJ04006dPjyj/9NNPm6985SsmNTXVXHXVVWb79u0RP3dyPSaamzkOHTo06l6VlpYaY4w5ffq0ufnmm81FF11kevfubYYOHWpmzZoV0yd6t9zM74c//GG4bFZWlrnlllvMrl27Itrranvo9hw9cOCAkWR++9vftmqrq+1fW88PLXOaPn26mThxYqs6ubm5JjU11QwfPjzi9bBFe2sWC0nGJPj7mgAAAOfgPiMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACr/h/wuZBZs5pA8QAAAABJRU5ErkJggg==\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.plot(train_x.squeeze(-1).cpu(), train_y.cpu(), 'o')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following cell creates the dataloader objects. Because `batch_size=100000` exceeds the number of training points, each epoch uses a single full batch. See the [SVQEP regression notebook](./SVQEP_Regression_CUDA.ipynb) for details." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import TensorDataset, DataLoader\n", "\n", "train_dataset = TensorDataset(train_x, train_y)\n", "train_loader = DataLoader(train_dataset, batch_size=100000, shuffle=False)\n", "\n", "test_dataset = TensorDataset(test_x, test_y)\n", "test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Variational Inference with PG Auxiliaries\n", "\n", "We define a Bernoulli likelihood that leverages Pólya-Gamma augmentation. \n", "It turns out that we can derive closed-form updates for the Pólya-Gamma auxiliary variables. To handle the Q-Exponential Process, we introduce inducing points, i.e. inducing variables at learnable inducing locations. \n", "In particular, we need to learn a variational covariance matrix and a variational mean vector that control the inducing points. (See the discussion in the [SVQEP tutorial](Approximate_QEP_Objective_Functions.ipynb) for more details.) \n", "We use natural gradient updates for these two variational parameters; this allows us to take large steps and thus converge quickly.
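The closed-form update can be illustrated for a single data point with plain floats. The sketch below mirrors what `PGLikelihood.expected_log_prob` computes, assuming the marginal `q(f_i)` is summarized by a mean `m` and a variance `v`. Following reference [1], the optimal `q(omega_i)` is PG(1, c_i) with `c_i = sqrt(v + m^2)`, and its mean is `tanh(c_i / 2) / (2 c_i)`. The function names are hypothetical and not part of qpytorch. As in the notebook code, the additive constant `log(1/2)` from the identity is dropped.

```python
import math


def half_pg_mean(c):
    # 0.5 * E[omega] for omega ~ PG(1, c), i.e. tanh(c / 2) / (4c);
    # the c -> 0 limit is 1/8, used here to avoid dividing by zero.
    if c < 1e-8:
        return 0.125
    return 0.25 * math.tanh(0.5 * c) / c


def pg_expected_log_lik(label01, mean, variance):
    # One data point's term: 0.5 * y * E[f] - 0.5 * E[omega] * E[f^2]
    y = 2.0 * label01 - 1.0
    second_moment = variance + mean ** 2
    c = math.sqrt(second_moment)
    return 0.5 * y * mean - half_pg_mean(c) * second_moment


print(pg_expected_log_lik(1, 0.8, 0.2))
print(pg_expected_log_lik(0, 0.8, 0.2))
```

Because `c_i` is treated as fixed, which is the role of the `detach()` call in `PGLikelihood`, gradients do not flow through the PG update. Only the terms in `E[f]` and `E[f^2]` are differentiated.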
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "class PGLikelihood(qpytorch.likelihoods._OneDimensionalLikelihood):\n", "    # this method effectively computes the expected log likelihood\n", "    # contribution to Eqn (10) in Reference [1].\n", "    def expected_log_prob(self, target, input, *args, **kwargs):\n", "        mean, variance = input.mean, input.variance\n", "        # Compute the expectation E[f_i^2]\n", "        raw_second_moment = variance + mean.pow(2)\n", "\n", "        # Translate 0/1 targets to -1/1\n", "        target = target.to(mean.dtype).mul(2.).sub(1.)\n", "\n", "        # We detach the following variable since we do not want\n", "        # to differentiate through the closed-form PG update.\n", "        c = raw_second_moment.detach().sqrt()\n", "        # Half the mean of the PG auxiliary variable omega ~ PG(1, c):\n", "        # 0.5 * E[omega] = 0.5 * tanh(c / 2) / (2 * c)\n", "        # See Eqn (11) and Appendix A2 and A3 in Reference [1] for details.\n", "        half_omega = 0.25 * torch.tanh(0.5 * c) / c\n", "\n", "        # Expected log likelihood (up to an additive constant)\n", "        res = 0.5 * target * mean - half_omega * raw_second_moment\n", "        # Sum over data points in the mini-batch\n", "        res = res.sum(dim=-1)\n", "\n", "        return res\n", "\n", "    # define the likelihood\n", "    def forward(self, function_samples):\n", "        return torch.distributions.Bernoulli(logits=function_samples)\n", "\n", "    # define the marginal (predictive) distribution using Gauss-Hermite quadrature\n", "    def marginal(self, function_dist):\n", "        prob_lambda = lambda function_samples: self.forward(function_samples).probs\n", "        probs = self.quadrature(prob_lambda, function_dist)\n", "        return torch.distributions.Bernoulli(probs=probs)\n", "\n", "\n", "POWER = 1.0\n", "# define the actual QEP model (kernels, inducing points, etc.)
\n", "class QEPModel(qpytorch.models.ApproximateQEP):\n", " def __init__(self, inducing_points):\n", " self.power = torch.tensor(POWER)\n", " variational_distribution = qpytorch.variational.NaturalVariationalDistribution(inducing_points.size(0), power=self.power)\n", " variational_strategy = qpytorch.variational.VariationalStrategy(\n", " self, inducing_points, variational_distribution, learn_inducing_locations=True\n", " )\n", " super(QEPModel, self).__init__(variational_strategy)\n", " self.mean_module = qpytorch.means.ZeroMean()\n", " self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())\n", "\n", " def forward(self, x):\n", " mean_x = self.mean_module(x)\n", " covar_x = self.covar_module(x)\n", " return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)\n", "\n", "# we initialize our model with M = 30 inducing points\n", "M = 30\n", "inducing_points = torch.linspace(-2., 2., M, dtype=train_x.dtype, device=train_x.device).unsqueeze(-1)\n", "model = QEPModel(inducing_points=inducing_points)\n", "model.covar_module.base_kernel.initialize(lengthscale=0.2)\n", "likelihood = PGLikelihood()\n", "\n", "if torch.cuda.is_available():\n", " model = model.cuda()\n", " likelihood = likelihood.cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup optimizers\n", "\n", "We will use a `NGD` (Natural Gradient Descent) optimizer to deal with the inducing point covariance matrix and corresponding mean vector, while we will use the `Adam` optimizer for all other parameters (the kernel hyperparmaeters as well as the inducing point locations). \n", "Note that we use a pretty large learning rate for the `NGD` optimizer." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "variational_ngd_optimizer = qpytorch.optim.NGD(model.variational_parameters(), num_data=train_y.size(0), lr=0.1)\n", "\n", "hyperparameter_optimizer = torch.optim.Adam([\n", " {'params': model.hyperparameters()},\n", " {'params': likelihood.parameters()},\n", "], lr=0.01)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define training loop" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "61c1e53ad2ff4c51b1791cbe2df29845", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/100 [00:00]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAi8AAAGeCAYAAABcquEJAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjMsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvZiW1igAAAAlwSFlzAAAPYQAAD2EBqD+naQAAQ2ZJREFUeJzt3Xt0VOW9//HPJJIAlQQpkIAGDGpBBcFLwdCWQGUJ1KIebY/angrWYtvjpQo/KakiIrVRwMtaHqrSo4JtObRa71qrYgAvUSpKvXBZxQZBJBGlJFw0QPL8/hj2ZGayZ7JnsvfM7Jn3a61ZYWb25dmXmfmyn+/z3QFjjBEAAIBP5KW7AQAAAIkgeAEAAL5C8AIAAHyF4AUAAPgKwQsAAPAVghcAAOArBC8AAMBXCF4AAICvELwAAABfOSLdDXBba2urPvnkE/Xo0UOBQCDdzQEAAA4YY7Rnzx71799feXkdXFsxHvrNb35jzjjjDHPkkUeaPn36mPPOO89s3Lixw/n+/Oc/m8GDB5vCwkIzdOhQ8+yzzzpe57Zt24wkHjx48ODBg4cPH9u2bevwt97TKy+rVq3SlVdeqa9//es6dOiQfvWrX+nss8/W+vXr9ZWvfMV2ntdff12XXHKJqqur9d3vflfLli3T+eefr7fffltDhw7tcJ09evSQJG3btk1FRUWubg8AAPBGU1OTysrKQr/j8QSMSd2NGXfu3Km+fftq1apVGjNmjO00F110kfbt26dnnnkm9NqZZ56pESNG6L777utwHU1NTSouLlZjYyPBCwAAPpHI73dKE3YbGxslSb169Yo5TW1trcaPHx/x2oQJE1RbW2s7fXNzs5qamiIeAAAge6UseGltbdW1116rb3zjG3G7f+rr61VSUhLxWklJierr622nr66uVnFxcehRVlbmarsBAEBmSVnwcuWVV+r999/X8uXLXV1uVVWVGhsbQ49t27a5unwAAJBZUjJU+qqrrtIzzzyj1atX65hjjok7bWlpqRoaGiJea2hoUGlpqe30hYWFKiwsdK2tAAAgs3l65cUYo6uuukqPP/64Xn75ZZWXl3c4T0VFhVasWBHx2osvvqiKigqvmgkAAHzE0ysvV155pZYtW6Ynn3xSPXr0COWtFBcXq1u3bpKkSy+9VEcffbSqq
6slSb/4xS9UWVmpO+64Q+ecc46WL1+ut956S4sXL/ayqQAAwCc8vfJy7733qrGxUWPHjlW/fv1Cjz/96U+habZu3aodO3aEno8ePVrLli3T4sWLNXz4cD366KN64oknHNV4AQAA2S+ldV5SgTovAAD4T8bWeQEAAOgsghcAQHJqqqVV8+3fWzU/+D7gAYIXAEDH7AKVvHyp5lZp6eTIQGXV/ODrefmpbSNyRkrqvAAAfM4KVCSpcmbb3y2vSHWr26azApdxNwTfr6kOzmvNE27VfKm1RRpX5X37kVUIXgAA7UUHHdbfmluDAcuA0cH361ZL5WOCf+f1kVoOtAUukn3QI0UGOUCCCF4AAO11dKVly2uSaWkLVKzAJb8gMkgJD3qs59FXZ4AEEbwAQK6z69oJDzrqVktTnwkGHXWrpUB+MHCxApVV89sCl5YDweexlrWyWjKt9oEL3UhwiIRdAMh11lWWWCOHtrwSvLJSc2uwi8gKXFoOHE7WPXwVZfbO4F+7ZVXODM5jWu3XQZIvEsCVFwDIdR117axeEAxUAodzXKyrJksnt+W82OXGhD+PvjpDNxI6geAFABAZdFjBipVMawUupiUyUBkwOvi3bnVkV5H1t7Ul+Dc6OLGeR6+LwAUOEbwAQK6JNXy5cqa08ra2KyRSW9DR2iJtfT0yULFyU6xclehlWe9FX1UJD5TsknyBDhC8AECuiTV8eenk4NWVQH5b1070FRErGAmfN17g0doS/6qKta7oJF8gDoIXAMg1dnkp4fkrU56Wlnw3mKgba97oKy2x2I0citWNFL58IA6CFwDIRRHDl29ry2eZ8nTwdWtotF1Q0ZkAo6NupM4uHzmB4AUAslm88vySFMhr6yqyAhdLoldZnIjVjRS9Lm4rgDgIXgAgm3VUnl+KXVwueh43xAs4wtfFbQUQB8ELAGSzeDVcpMzNO+G2AoiD4AUAsp1dDRcp8/NOYtWeyYS2Ia24PQAA5AKrPL9VcC5W3olV0yVThLebejA4jOAFALJFTXXs+xMtndwWAJg4wUl48blMYHfTR+Q8ghcAyBaxbrAYXsMl3s0TM014jouf2g3PkfMCANnCSfG5WNNlGurBIA6CFwDIJnZJruGBS/R0mZTfEs5pPRjkpIAxxqS7EW5qampScXGxGhsbVVRUlO7mAEB6zOvTlisye2e6WwN0KJHfb3JeAMCvYiXoWkmu4Tc9BLII3UYA4Fd2VWjDi82NnRX8S44IsgzBCwD4VXQCa/i/o/NFsjGA4f5HOYvgBQD8LDyACRzOBIgOXLI1yZX7H+UsghcA8LvKmW0ji2JVoc2mKy4W7n+UswheAMDv7KrQ5soPN/c/ykmMNgIAP6MKLfc/ykFceQEAv6IKbVAuX3nKUQQvAOBXVKFtH8CFDxUngMlannYbrV69WpMnT1b//v0VCAT0xBNPxJ1+5cqVCgQC7R719fVeNhMA/GlcVewf6Ey7O7QXogOXmurg63ZdZ6vmt70P3/M0eNm3b5+GDx+uRYsWJTTfpk2btGPHjtCjb9++HrUQAOBb0VeewodOj7uh7cqTFeTk5aennXCdp91GkyZN0qRJkxKer2/fvurZs6f7DQIAZI/oK0vh+T7jbgi+z9DprJSROS8jRoxQc3Ozhg4dqptvvlnf+MY3Yk7b3Nys5ubm0POmpqZUNBEAkIkYOp0TMmqodL9+/XTffffpL3/5i/7yl7+orKxMY8eO1dtvvx1znurqahUXF4ceZWVlKWwxACDjMHQ66wWMMSYlKwoE9Pjjj+v8889PaL7KykoNGDBAv//9723ft7vyUlZW5uiW2gCALGR1FVkBDFdefKGpqUnFxcWOfr8zstso3MiRI/Xqq6/GfL+wsFCFhYUpbBEAIGMxdDonZHzwsm7dOvXr1y/dzQAAZDqK9uUMT4OXvXv3avPmzaHndXV1WrdunXr16qUBAwaoqqpK27dv18MPPyxJuvvuu1VeXq6TTz5ZX375pf73f
/9XL7/8sl544QUvmwkAyAZ2RftqqoNDpMOHTltWzT88T5bXw8lCngYvb731lsaNGxd6Pn36dEnSlClTtGTJEu3YsUNbt24NvX/gwAHNmDFD27dvV/fu3XXKKafopZdeilgGAAC27IIQq/aLNXTaEn6VBr6TsoTdVEkk4QcAkANi5cGQyJtRsiphFwCyntW1YfdDStdG51H7JetkVJ0XAMhJVtdG+L14JMrau4naL1mFKy8AkG52I2Lo2nDXqvltgUvLgeBz9qtvEbwAQCaga8M71H7JOgQvAJApKme2BS50bbiD2i9ZieAFADIFXRvui679Ep0cHV77heRo3yB4AYBMQNeGN6IDESs5Worcr9R98RWCFwBIN7o2Uofk6KxA8AIAqWbXdRF+xcXqurDr2kDnkRztewQvAJBq0V0XVteGXdcFP6jeIDna1wheACDV6LpIP5KjfY3gBQDSga6L9CE52vcIXgAgXei6SD2So7MCwQsApAtdF6kXXffFQnK0rxC8AEA60HWRHvEK0LHffYPgBQBSja4LoFPy0t0AAMhqNdXBYCWc1XVhvW+pnBl8na4LIC6CFwDwklXTJTyAsbouam4Nvh8uvO4LAFt0GwGAl6jpAriO4AUAvEZNF8BVdBsBQCpUzmwbEk1NF6BTCF4AIBXsaroASArBCwB4LTzHZfbO4N/oJF4AjpHzAgBeoqYL4DqCFwBwU011cPhzeLn58Cq6rS3BodCUo89c0ccwXPgxRNoQvACAm6y6LlJkzZbwKzAWrrhkpuhjaLE7hkgLghcAcBN1XfyPY5jxCF4AwG3UdfE/jmFGY7QRAHiBui7+xzHMWAQvAOAF6rr4H8cwYxG8AIDbqOvifxzDjEbOCwC4ibou/scxzHgELxmgpdVoTd0ufbrnS/Xt0VUjy3spPy/g+jypWFYqlovU7ttk1xU93+kDj9Laj/7tyTnrZNl22yGpw+VET9Nhuw/XdWn51vVa8+HnbfN963rlW+/HaVMy+8TJ/rDbjmS2P9l9lOz54GQfJbP9cdcfXpsnXIzaPG6dj8mcs25+9pNtUzq+1wPGGJPytXqoqalJxcXFamxsVFFRUbqb06Hn39+huU+v147GL0Ov9SvuqjmTT9LEof1cm8fN9adzuUjtvk12XXbz5QWk1rBvGzfP2Y6WbTdPz+5dJEm79x+MuRy7aSKWHaOY2fPv79DWx2/W/uYDuvvQ9xy3KZl94mR/2G1Hstuf8D5y2Ea7bXeyj5Ld/kw7H5M9Z9367CfbJje/exL5/fY0eFm9erUWLFigtWvXaseOHXr88cd1/vnnx51n5cqVmj59uj744AOVlZXpxhtv1NSpUx2v00/By/Pv79DP//C2og+AFcPe+1+ntTshkpnHzfWnc7lI7b5Ndl2x5ovm5jkbb9mSHM3jVES7P/99u+6F59/fofX/d6Omd3lUdxz8nu5pucBxmxLdJ073R6ols/+jt93J+ed02Ym22Xbf2wSqVhuvyn9M+YHWUKAab9mx2pzsdrnx2e/MvnbzuyeR329PE3b37dun4cOHa9GiRY6mr6ur0znnnKNx48Zp3bp1uvbaa/WTn/xEf/vb37xsZlq0tBrNfXq97clpvTb36fVqCQvfk5nHzfU74dVykdp9m+y64s2XyHISbVOsZd/81Ae6+Sln8zgV0e5vXR+RyNnSarT18ZvbBS5O25TIPklkf6Ra+Hbc/NQHCZ8PBw61Ojr/nC7biQ73vVV193DCrrX/r8p/TDO6PKoWE/vnNJFjn+g529nPvtPPeqx9na7vdU9zXiZNmqRJkyY5nv6+++5TeXm57rjjDknSiSeeqFdffVV33XWXJkyYYDtPc3OzmpubQ8+bmpo61+gUWVO3K+LSWzQjaUfjl1pTt0sVx3016XncXL8TXi0Xqd23ya6ro/mcLieZNtktu76pucPpkhHR7rBEzsCqBbqi9UC7wCWRNjndJ4nuj1SztiOZeX5fu8XR+ee2uPs+KmF3z
TGX63t7l2mGTaAaa9lOjn0y52xnPvtOP+terT9ZGTVUura2VuPHj494bcKECaqtrY05T3V1tYqLi0OPsrIyr5vpik/3OPvghU+XzDxurt/N6RNdLlK7b5NdV7LrdvOcTaVQmw4XM8trPaBmc0SHP2QJLTvJ9/3so13707r+mPu2cmboStvIPw52HLikSjLnhJvnUSrPyYwKXurr61VSUhLxWklJiZqamvTFF1/YzlNVVaXGxsbQY9u2baloaqf17dE14emSmcfN9bs5faLLRWr3bbLrSnbdbp6zqRRq0+FiZq15BSoMHNLV+Y+5t+wk3/ezgb26p3X9cfft4UA1v/Wga4GqW5I5J9w8j1J5TmZU8JKMwsJCFRUVRTz8YGR5L/Ur7qpYA8wCCmZxW0Pjkp3HzfU74dVykdp9m+y6OprP6XKSaZPdskuLClVa5HwepyLaHVYTxNz4qRbnX6wZXR61DWCctMnpPkl0f6SatR2lRYUJnw8/qjjW0fmXyLITWX/cfX84UDX5wUD1GoeBqtNjn8w525nPvtPPerx9nY7v9YwKXkpLS9XQ0BDxWkNDg4qKitStW7c0tcob+XkBzZl8kiS1OyGs53MmnxQxfj6ZedxcvxNeLRep3bfJrivefNHcPGdjLfvmc0/Wzec6m8epiHa/siBitFF+XkAD/uNm3Xnwe+0CGCdtSmSfJLI/Ui18O24+9+SI15zMU3BEnqPzz+mynXC078MC1cDsnfrnSddoepdHOwxgEjn2iZ6znf3sO/2sx9rX6fpez6jgpaKiQitWrIh47cUXX1RFRUWaWuStiUP76d7/Ok2lxZGX2kqLu8YcdpbMPG6uP53LRWr3bbLrijVf9Peam+dsvGXHmqdn9y6huhmxlhM9zbVHPKqqrzzV1u7wYmar5ks11Zo4tJ9OuuTXWpx/sfIDrQm1KdF94nR/2G3rUUlsv9NpnGxrR+eDk33kdNlOtrXDfW9TdfeE/5wXCmDCA9Vkzkcn09hthxuf/c60KV3f657Wedm7d682b94sSTr11FN15513aty4cerVq5cGDBigqqoqbd++XQ8//LCk4FDpoUOH6sorr9SPf/xjvfzyy7rmmmv07LPPxhxtFM1PdV4sVNhFoqiwm54Ku6fW/U4D/nFX++qrNj9sqaxWSoVdlyvs2olRkFCSWlferu279urtQT+jwm4nZEyRupUrV2rcuHHtXp8yZYqWLFmiqVOnasuWLVq5cmXEPNddd53Wr1+vY445RrNnz87aInUAfCg6ULG7Dw6AhGVM8JIOBC8APGcFLPkFUssBApdcFedqjFbNP9ytWJX6dvlUxlTYBYCsdHi4rFoOBP8SuOSmqKq7IVZwm5efnnblAO4qDQCJOjxcNhTArJpPAJOLoqru0o2YOgQvAJCIWDkvEj9WuSg8gFm9gG7EFCF4AQCn7P5Xbfe/b+SWypltgQvdiClB8AIAToXXdQlnPW9tSX2bkH50I6YcwQsA2LEbSWKNHLEbScKPVW6iGzEtCF4AwI41kkSKXZAOuY1uxLQheAEAO4wkQUfoRkwbitQBQDwUpANSgiJ1AOAWCtLBqZrq9gXrLIdv3Al3ELwAQDx2I0kAO1TcTRlyXgAgFkaSIBHkSaUMwQsA2GEkCZJBxd2UIHgBADuMJEGyqLjrOYIXAKAgHdxExV3PkbALACRawi3h3Y2zdwb/2p1b6BSuvAAAiZZwA3lSKUPwAgASiZboPPKkUoYKuwAQbl6ftnyF2TvT3RogZ1BhFwCSQUE6uI2qu54geAEAiURLeINkcE+Q8wIAJFrCKySDe4LgBQBItISXSAZ3HQm7AHKPXVE6i11ROsANJIPHRcIuAMRDHgJSjWRwV9FtBCD3kIeAVOLu5K4jeAGQm8hDQCqQDO4JghcAuYu7/8JrJIN7guAFQO7i7r/wWrzEb861pJGwCyA3UZQO6UDFXVdw5QVA7iEPAelijXSTIs+x8HMSHSJ4AZB7yENAujDSz
RUUqQMAINWsgMXKtyJwybwidYsWLdKxxx6rrl27atSoUVqzZk3MaZcsWaJAIBDx6Nq1ayqaCSAbkWOATFQ5sy1wYaRbwjwPXv70pz9p+vTpmjNnjt5++20NHz5cEyZM0KeffhpznqKiIu3YsSP0+Oijj7xuJoBsRTVdZCIq7naK58HLnXfeqWnTpumyyy7TSSedpPvuu0/du3fXgw8+GHOeQCCg0tLS0KOkpMTrZgLIVpUz248kIscA6cRIt07zNGH3wIEDWrt2raqq2sa55+Xlafz48aqtrY053969ezVw4EC1trbqtNNO029+8xudfPLJttM2Nzerubk59Lypqcm9DQCQHaimi0zBSDdXeHrl5bPPPlNLS0u7KyclJSWqr6+3nWfw4MF68MEH9eSTT+oPf/iDWltbNXr0aH388ce201dXV6u4uDj0KCsrc307AGQBcgyQCeKNdBt3AyPdHMq4odIVFRWqqKgIPR89erROPPFE3X///Zo3b1676auqqjR9+vTQ86amJgIYAO1RTReZgIq7rvA0eOndu7fy8/PV0NAQ8XpDQ4NKS0sdLaNLly469dRTtXnzZtv3CwsLVVhY2Om2Ashi3NUXyCqedhsVFBTo9NNP14oVK0Kvtba2asWKFRFXV+JpaWnRe++9p379+nnVTADZLFaOAUmSgG953m00ffp0TZkyRWeccYZGjhypu+++W/v27dNll10mSbr00kt19NFHq7o6WGvhlltu0Zlnnqnjjz9eu3fv1oIFC/TRRx/pJz/5iddNBZCNqKYLZB3Pg5eLLrpIO3fu1E033aT6+nqNGDFCzz//fCiJd+vWrcrLa7sA9O9//1vTpk1TfX29jjrqKJ1++ul6/fXXddJJJ3ndVADZiBwDIOtwewAAAJB2GXd7AAAAALcQvAAAAF8heAEAAL5C8AIAAHyF4AUAgExSUx27/tCq+cH3cxzBCwAAmSQv376AolVwMS8/Pe3KIBl3byMAAHKa3V2m7SpF5zCCFwAAMk14ALN6QfBmogQuIXQbAfAvcgOQzSpntt0FPb+AwCUMwQsA/yI3ANls1fy2wKXlADcRDUO3EQD/IjcA2Sr6PLaeS5zXIngB4HfkBiDb2AXgdoF6DqPbCID/kRuAbNLa0j4At/K3xt0QfD9cDuZ3EbwA8D9yA5BNxlW1D8Ct/C7rfUuO5nfRbQTA38gNQC4gvysCwQsA/yI3ALmE/K4QghcA/mWXGyC1PY/ODQD8rnJmW+CSw/ldBC8A/Cu87z9ajn6pI8vZ5Xfl4LlOwi4A/6CiLnJZeDfp7J3Bv3ZFGnMAV14A+Ef4iIvw/22Gf6kD2Yj8rggELwD8gxEXyFXR+V011cFg3i6/a9X8w9PH6Vb1OYIXAP7CiAvkouhAJMevQhK8APAfRlwg1+X4VUiCFwD+w4gLIKevQjLaCIC/MOICaJOj9/XiygsA/2DEBRApR69CErwA8A8q6gJtcvi+XgQvAPyDirpAUI5fhSR4AZCZoutYhMuBOhZAXDl+FZLgBUBmyvE6FkBcOX4VkuAFQGbK8ToWAGIjeAGQuXK4jgWA2KjzAiCz5WgdCwCxpSR4WbRokY499lh17dpVo0aN0po1a+JO/8gjj2jIkCHq2rWrhg0bpueeey4VzQSQiezqWADIaZ4HL3/60580ffp0zZkzR2+//baGDx+uCRMm6NNPP7Wd/vXXX9cll1yiyy+/XO+8847OP/98nX/++Xr//fe9biqATEM1XQA2AsYY4+UKRo0apa9//ev6n//5H0lSa2urysrKdPXVV2vWrFntpr/ooou0b98+PfPMM6HXzjzzTI0YMUL33Xdfh+trampScXGxGhsbVVRU5N6GAEitWMm5JO0Czvmo5EAiv9+eXnk5cOCA1q5dq/Hjx7etMC9P48ePV21tre08tbW1EdNL0oQJE2JO39zcrKampogHgCwQr47FuBuyvo4F4Aqr5ED01UrrPwF5+elpVyd5Otros88+U
0tLi0pKSiJeLykp0caNG23nqa+vt52+vr7edvrq6mrNnTvXnQYDyBw5XscCcEWWlhzw/VDpqqoqTZ8+PfS8qalJZWVlaWwRAAAZJAtLDnjabdS7d2/l5+eroaEh4vWGhgaVlpbazlNaWprQ9IWFhSoqKop4AACAMFlWcsDT4KWgoECnn366VqxYEXqttbVVK1asUEVFhe08FRUVEdNL0osvvhhzegAA0IEsKzngebfR9OnTNWXKFJ1xxhkaOXKk7r77bu3bt0+XXXaZJOnSSy/V0UcfrerqaknSL37xC1VWVuqOO+7QOeeco+XLl+utt97S4sWLvW4qAADZJzzHpbVF2vp67PuGZdDoo3g8r/Ny0UUXaeHChbrppps0YsQIrVu3Ts8//3woKXfr1q3asWNHaPrRo0dr2bJlWrx4sYYPH65HH31UTzzxhIYOHep1UwEAyC7Rybl5+VLdaql8TOQoJJ+NPvK8zkuqUecFAIDD7Oq8WIFK+RhpwOi24dRpTuJN5Pfb96ONAABADHZdQOGjj7a+4cvRR9yYEQCAXOPz0UcELwAA5Jro0UdLJ8eerqY6tW1zgOAFAIBcEn3D0/IxwSTe6AAmg5N4CV4ApF5Ndew6Exn6Pz0gK9jdGmDK0+0DmAy/hQDBC4DUy9KbxQEZL9YNT8MDmHl9MjpwkRhtBCAdsvRmcUDGi1eAbsrTwcDFB0m8BC8A0iMLbxYH+JqVxBvIb7uFQPTnMUOq8NJtBCB9fD5cE8ga4Vc+x84KvhbdtZtB3bpceQGQPnY3iyOAAVIrVpdtza1tXbvW8wy5OkrwAiA9or8wredSRnw5AjnDLok3vFt3ZbVkWjMmcJEIXgCkg93/9OySeAF4L1b+SuXMtny0QJyuojTkwZDzAiD1Yg3XrJwZfL21JT3tAtAmvFvXtGRUeQPuKg0AACLF69aNfs2l7qREfr8JXgB4q6Y6+L8yuy+3DBl2CSBMrKAkPICxkuxdzINJ5PebbiMA3qKaLuAvHXXrBvLSXt6AhF0A3qKaLuAvHV0JNa1pL29A8ALAe1TTBfwvg8obELwASI3wYZdU0wX8JcPKGxC8AEgNqukC/hUvD8Z6P4UIXgB4L4MuNwNIQrw8GHJeAGSdDLvcDMD/CF4AeCvDLjcD8D+K1AEAgLSjSB0AAMhaBC8A3FVT3b6armXV/OD7ANAJBC8A3MXtAAB4jIRdAO7idgAAPEbwAsB93A4AgIfoNgLgjcqZbdV0uR0AABcRvADwht3tAADABQQvANwXnuMye2fwr10SLwAkgZwXAO7idgAAPObplZddu3bphz/8oYqKitSzZ09dfvnl2rt3b9x5xo4dq0AgEPH42c9+5mUzAbgp3u0Axt3A7QAAdJqntweYNGmSduzYofvvv18HDx7UZZddpq9//etatmxZzHnGjh2rr33ta7rllltCr3Xv3t1xqX9uDwAAgP8k8vvtWbfRhg0b9Pzzz+vvf/+7zjjjDEnSPffco+985ztauHCh+vfvH3Pe7t27q7S01KumAQAAH/Os26i2tlY9e/YMBS6SNH78eOXl5enNN9+MO+8f//hH9e7dW0OHDlVVVZX2798fc9rm5mY1NTVFPACkCLcCAJAGnl15qa+vV9++fSNXdsQR6tWrl+rr62PO94Mf/EADBw5U//799e677+qXv/ylNm3apMcee8x2+urqas2dO9fVtgNwyLoVgBSZ4xKetAsALks4eJk1a5Zuv/32uNNs2LAh6QZdccUVoX8PGzZM/fr101lnnaUPP/xQxx13XLvpq6qqNH369NDzpqYmlZWVJb1+AAngVgAA0iDh4GXGjBmaOnVq3GkGDRqk0tJSffrppxGvHzp0SLt27Uoon2XUqFGSpM2bN9sGL4WFhSosLHS8PAAu41YAAFIs4eClT58+6tOnT4fTVVRUaPfu3Vq7dq1OP/10SdLLL7+s1tbWUEDixLp16yRJ/fr1S7SpAFKlcmZb4MKtAAB4zLOE3
RNPPFETJ07UtGnTtGbNGr322mu66qqrdPHFF4dGGm3fvl1DhgzRmjVrJEkffvih5s2bp7Vr12rLli166qmndOmll2rMmDE65ZRTvGoqgM7iVgAAUsjTInV//OMfNWTIEJ111ln6zne+o29+85tavHhx6P2DBw9q06ZNodFEBQUFeumll3T22WdryJAhmjFjhi688EI9/fTTXjYTQGdwKwAAKeZpkbp0oEgdkEKxknNJ2gWQoIwoUgcgB8S7FYD1PgC4jCsvAAAg7RL5/fY05wVAlqGiLoAMQPACwDmrom50AGPluOTlp6ddAHIKOS8AnKOiLoAMQPACIDFU1AWQZnQbAUhc5cy2gnRU1AWQYgQvABJHRV0AaUTwAiAxVNQFkGbkvABwzi451y6JFwA8RPACwDkq6gLIAFTYBQAAaUeFXQAAkLUIXgDY41YAADIUwQsAe9wKAECGImEXgD1uBQAgQxG8AIiNWwEAyEB0GwGIj1sBAMgwBC8A4uNWAAAyDMELgNi4FQCADETOCwB73AoAQIYieAFgj1sBAMhQ3B4AAACkHbcHAJA4KuoC8AmCFwBBVNQF4BPkvAAIoqIuAJ8geAHQhoq6AHyAbiMAkaioCyDDEbwAiERFXQAZjuAFQBsq6gLwAXJeAARRUReATxC8AAiioi4AnyB4cail1WhN3S59uudL9e3RVacPPEprP/p36PnI8l7Kzwu4Ml/0PCPLe0lSu9eczOfWNNHrt9sOJ21MZl8nu36nx8iN9iS7freOmSv7ccAVwWns1hcV0CT7eXDSRre230tufa79sK25LNnvYqfLcuPYOl2uV+dsuhC8OPD8+zs09+n12tH4Zei1vIDUGnZjhX7FXTVn8kmaOLRfp+azm6dn9y6SpN37DyY0n1vT2K0/ejuctNEJN9fv5Bi51Z5k1u/WMetMu5M5jsl+Hpy00a3t95Jbn2s/bGsuS/a72Omy3Di2Tpfr1TmbTp7d2+jWW2/Vs88+q3Xr1qmgoEC7d+/ucB5jjObMmaPf/e532r17t77xjW/o3nvv1QknnOB4vW7f2+j593fo5394Wx3tJCsWvfe/TtPEof2Smk+So3mczufWNMmK3icdcbrP/NoeL46Z3XY4afe1RzyqFpOne1ouaPfe1fmPKT/QqrsPfc9xm6Il20Y3tt9Lbn6uM31bc1kin303zvVkjq3T5Xp1znpxPiby++1Z8DJnzhz17NlTH3/8sR544AFHwcvtt9+u6upqLV26VOXl5Zo9e7bee+89rV+/Xl27dnW0XjeDl5ZWo2/e/nJE5BlPQFJpcVetun6cKhfUJDyfMUb1Tc2O2xeQVFJUKCmg+ib7dTlZdrLrd9rG0uKuevWX3457uTHRfe3X9jg9Hk6Pa/R2OG331fmPaUaXR3XHwe9FBDCxXk/mHEm2jZ3Zfi8l+33Q2c9eOrY1lyXz2e/MuZ7MsXW6XC9+i7w8HzPixoxz587Vddddp2HDhjma3hiju+++WzfeeKPOO+88nXLKKXr44Yf1ySef6Iknnog5X3Nzs5qamiIebllTtyuhE9hI2tH4pX5fuyWp+RINHIyk+qbmmF/wTped7PqdtnFH45daU7cr7nSJ7mu/tsfp8XB6XKO3w2m772m5QHcc/J5mdHlUV+c/Jil24OK03W61sTPb76Vkvw86+9lLx7bmsmQ++50515M5tk6X68VvUaacjxlT56Wurk719fUaP3586LXi4mKNGjVKtbW1Meerrq5WcXFx6FFWVuZamz7dk9yP10e79rvWhmzR0b5Mdl8nK9Pak6zodibS7vAAZlPhpTEDl3S2MdFleynd50S6158rOrOfkz3XE12n0+m9/C1K9/mYMcFLfX29JKmkpCTi9ZKSktB7dqqqqtTY2Bh6bNu2zbU29e3hrKsq2sBe3V1rQ7boaF8mu6+TlWntSVZ0OxNt9z0tF6jZHKHCwCE1myNcD1zs2uTmvk3lcUr3OZHu9eeKzuznZM/1RNfpd
Hovf4vSfT4mFLzMmjVLgUAg7mPjxo1etdVWYWGhioqKIh5uGVneS/2Ku8ppr15AwWzsH1Ucm9R8pUWFjuex5istKlRpUex1OVl2sut32sZ+xW1DDGNJdF/7tT1Oj4fT4xq9HYm2++r8x0KBS2HgUKgLKZl2u9XGzmy/l5L9PujsZy8d25rLkvnsd+ZcT+bYOl2uF79FmXI+JhS8zJgxQxs2bIj7GDRoUFINKS0tlSQ1NDREvN7Q0BB6L9Xy8wKaM/kkSerw4Fvvz5l8kgqOyEtqvpvPPdnRPOHT3Hzuybr5XPt1OVl2sut3InzZHSV2JbKv/dqeRI6H0+MavR2JtDs8x2Vw88PtcmASaXe0ZNvY2e33UrLfB5357KVrW3NZop/9zp7ryRxbp8t1+7cok87HhIKXPn36aMiQIXEfBQUFSTWkvLxcpaWlWrFiRei1pqYmvfnmm6qoqEhqmW6YuHOJXjjtDZUWBy+RXXtE8Ms9LxD88r/2iEclBbOvXzjtDU3cucR2PovdfNaws4lD++ne/zqt3Tw9u3cJ1RewOJnPrWns1h993h7VQRudcHv90dN42Z5E1+/WMUuk3eFttAKXOw+15bjc03KBFgUuahfAOGlTovvay+33UjLb79dtzWVOPkMWN851N9sYvVyvztl082yo9NatW7Vr1y499dRTWrBggV555RVJ0vHHH68jjzxSkjRkyBBVV1frP/7jPyQFh0rfdtttEUOl33333bQNlZYUut9L69hf6c2yn6jn3+/SiRvvUcuxY5S/ZbU2DLlau79+nUZt+1/lrfxNW3n1qPmOWrNQ3bsWqn/Pbjpi1W9C840s76X8VxYcLs1eJYkKu1TYTXyaRNttnY+lk29q18bA6vnavmuv3h70M8dtosIuFXazERV2U3s+ZkSdl6lTp2rp0qXtXq+pqdHYsWODKw8E9NBDD2nq1KmS2orULV68WLt379Y3v/lN/fa3v9XXvvY1x+t1PXiR2t+wbulkqW61VD5GmvK0/Q3toueT2m5uZ01XUy1tfT24LLt5wwIaAACyWUYEL+niSfAitQUi+QVSy4Fg4FK3uu253Q3t7OaTYgdB0fPEWiaQqJpqKS8/9jlKoAwgzTKiSF3WqZzZFoDkFwSDjfDnsYKM6PnG3RAMTOb1aQtc6lYfDliqgwFNzKs41anZVmSfvPzgebVqfuTrVqCcl5+edgFAErgxo1Or5rcFIC0HgkFG+PNV82P/rzZ8Oql9EGT9gATyJdMSDGjidT8BibLOJ6vrMiwviyt8APyGKy9OhH/Jz97ZdrWkfEzwuXU1Jdb/aq35rOnsgp78gmDgEshvuxIjSUu+G/sHhqsxSETlzMgrfwQuAHyKKy8dif7f6ar57bt7kvlf7Zjr2+bZ8kpkQFM+Jvj66gVtV2vitQtwqnJm23kVr7sTADIYwUtHWlsiA5Dw51aioxT5vt180cGGlSC55RX7kUuB/LYfmDHXtwVGrS2MUELyorsxY3V3AkAGY7RRqtiN9rAClfIx0oDRbUGHNQrJyoEJH2odnhfDCCUkwu4qIucMgAyRyO83V15Sxe5qSPTVGSmyW2rA6LZRIuNuaPvfcnheDD9CcMLuHLHr7gQAHyB4SafogCZeEGL9wMTKi4keoRS+TLqSYBcoS+27OwHABxhtlEli/cBYjv1W26glq1sp+kpMOGp4wDKuKn4tIoJbAD7ClZdMYvcDEutyv5XoG54DQw0PSFTTBZD1CF4yXSJ5MdFdSQQuuck6HySKHQLISgQvmS7RvJjwIdYELrmJaroAshzBi9/ES7y0upKo4YHwAIYrcQCyDAm7fhMr8dLqSoq+FUF0Ei9yR/RNQQlcAGQJrrxkAyc1PFpbSOLMNVTTBZCluPKSDeJ1JY27oS1wiXfzSIZTZ5dYNwXlShyALMCVl2wQ74pJrGJ3JHFmL6rpAshyBC+5hCTO3EA1XQBZjhsz5qJ5fdpyIWbvTHdrAABI6PebnJdcE53Eu
XRy7OlqqlPbNgAAHCB4ySXRSZzlY4LDq6MDGJJ4AQAZjOAlV9glcU55un0AQxIvACDDkbCbK2IlcU55Ohi41K1uy4UhcAEAZDASdhFkBS6BfGnOrvbvU8gOAOAhEnaRGCuJN5AvmRZyYAAAGY3gJdeF57jM2UUODAAg45HzkstiJfFaOTBzewWvxBC4AAAyCFdeclm8JF6rC4m7EWeWmurY9yeiNg+AHEHwksvGVcW+y7QVuFDILrNwg00AIHhBFArZZTbrTuHhAQx5SQByDEOl0SbWj6CVA1M+JtilxI9l+lnHwLo6xrEA4HOJ/H4TvKBNTXXwSordj6AVwPBjmTm4wSaALEKdFyQnVg6MFLziYgUuJPGmX/QNNmMl8QJAFvIseLn11ls1evRode/eXT179nQ0z9SpUxUIBCIeEydO9KqJSESokF1e7B9LEnhTIzovKToHBgCynGd1Xg4cOKDvf//7qqio0AMPPOB4vokTJ+qhhx4KPS8sLPSieUhE+I+lFPx3za3Bf1tXYKKngTfs8o2sv9HHBACylGfBy9y5cyVJS5YsSWi+wsJClZaWetAiJCVWcm54AGM9Jw/Ge7Fq81jPW1tS3yYASLGMq7C7cuVK9e3bV0cddZS+/e1v69e//rW++tWvxpy+ublZzc3NoedNTU2paGbusPuxDP+f/srbqMLrFbsEauvGmHY3ymT/A8gRGZWwO3HiRD388MNasWKFbr/9dq1atUqTJk1SS0vs/01WV1eruLg49CgrK0thi3NArCTeypnBZFGq8HqHgnQAYCuh4GXWrFntEmqjHxs3bky6MRdffLHOPfdcDRs2TOeff76eeeYZ/f3vf9fKlStjzlNVVaXGxsbQY9u2bUmvHwlgtIv3KEgHALYS6jaaMWOGpk6dGneaQYMGdaY97ZbVu3dvbd68WWeddZbtNIWFhST1plr0D+iS78ZOFrXr3oBz4V10qxdQYwcAlGDw0qdPH/Xp08ertrTz8ccf6/PPP1e/fv1Stk50wO5//uVjpC2vMALJK5Uz2wIXuugAwLucl61bt2rdunXaunWrWlpatG7dOq1bt0579+4NTTNkyBA9/vjjkqS9e/fq+uuv1xtvvKEtW7ZoxYoVOu+883T88cdrwoQJXjUTiYqVwGsFKHWrg3/p3nAPXXQAEMGz0UY33XSTli5dGnp+6qmnSpJqamo0duxYSdKmTZvU2NgoScrPz9e7776rpUuXavfu3erfv7/OPvtszZs3j26hTBKr+ye8e8MqW0/g0nnRQaD1XGLfAshZ3NsI7uJ+O+6JdfWKq1oAslAiv98ZV+cFPmZ3CwESeJNHQToAsEXwAndwCwH3xQvwuOICIIcRvKDzuIUAACCFCF7QedxCAACQQiTswlsk8AIAHEjk9zuj7m2ELEN9EgCABwhe4I3wPJjZO9vfoyfX1VTH3her5gffBwDYIucF7rNL4A3PgQl/nqusO0ZLsWu4AABsEbzAfdQn6ZhdMEfxOQBwhIRdIJ2sgMXKCyJwAZCjSNgF/KJyZlvgwh2jAcARghcgnRiRBQAJI3gB0oURWQCQFBJ2gXRgRBYAJI3gBZmjpjo4hNjuRzvb7kbNiCwASBrBCzJHLtU+4Y7RAJA0ghdkDmqfAAAcIHhBZgkPYFYv8Gftk1zq/gKANGC0ETKP32ufWN1f0aOGrKtIefnpaRcAZAmuvCDzRNc+WTpZmvK0/XSZeBWD7i8A8BRXXpBZomuflI+R6lYHAxi76TL1KkblzLa6LfP6ELgAgIsIXpA57K5OTHm6fQDjl6sYfu/+AoAMRfCCzBGr9kl4ADP3qNiBy6r5wWTZVKuptq+Ka3V/BfIo/Q8ALiLnBZkjXu7KlKeD3S8tB+zfT2ctGLv6NFZ7JGns4e2ici4AuILgBf4QncSbScmw0Qm64f+ObhMBDAB0WsAYY9LdCDc1NTWpuLhYjY2NKioqSndz4Ibo4CT8qoYVzJSPSc2IpHg1XJZODnZtB
fIlE6MLLFNHSAFAmiXy+03OCzJbrBsYWt1DLQeCwULd6tTUVYlXwyU8cImVoFs5k8AFADqJbiNktlhJvBYrWCgf03Y1prVF2vp6MJiInjeRKx92V1nCu4jqVktTn2kLkqykYutq0Kr5dA8BgAcIXpDZ7IKMWN1IVgATHtB05gaPsW4UadnySlsSsRW42HVtEcAAgKsIXuAvsbqRpMjAJbwrKdmk3o4q5Vr3XrLWFatN4c8BAJ1G8AJ/idWNVDkzeCUkvNvGuhLTmRs8xrpRpBQ5+in6Kk/4vK0tiW8nACAmEnbhL+OqYt+t2br6MXtn8K+VQNvZCrfRlXKlyFsYWOuyK0JHgi4AuI7gBf4XqyupfExbF1JnKtza1ZixG/1kNwoJAOA6z4KXLVu26PLLL1d5ebm6deum4447TnPmzNGBAzEqpB725Zdf6sorr9RXv/pVHXnkkbrwwgvV0NDgVTORDey6kqwrMeVjpDHXJx9cRN8o8thv2U9nBTB0EQGA5zzLedm4caNaW1t1//336/jjj9f777+vadOmad++fVq4cGHM+a677jo9++yzeuSRR1RcXKyrrrpKF1xwgV577TWvmgq/i+6WiZecm0gCrd1ywodGRy+HpFwASImUVthdsGCB7r33Xv3rX/+yfb+xsVF9+vTRsmXL9L3vfU9SMAg68cQTVVtbqzPPPLPDdVBhF3Gr4Ha2zksyywEAdCiR3++UjjZqbGxUr169Yr6/du1aHTx4UOPHjw+9NmTIEA0YMCBm8NLc3Kzm5ubQ86amJncbDf+JF1AkcnXEreUAAFyVsoTdzZs365577tFPf/rTmNPU19eroKBAPXv2jHi9pKRE9fX1tvNUV1eruLg49CgrK3Oz2QAAIMMkHLzMmjVLgUAg7mPjxo0R82zfvl0TJ07U97//fU2bNs21xktSVVWVGhsbQ49t27a5unwAAJBZEu42mjFjhqZOnRp3mkGDBoX+/cknn2jcuHEaPXq0Fi9eHHe+0tJSHThwQLt37464+tLQ0KDS0lLbeQoLC1VYWOi4/QAAwN8SDl769OmjPn36OJp2+/btGjdunE4//XQ99NBDysuLf6Hn9NNPV5cuXbRixQpdeOGFkqRNmzZp69atqqioSLSpAAAgC3mW87J9+3aNHTtWAwYM0MKFC7Vz507V19dH5K5s375dQ4YM0Zo1ayRJxcXFuvzyyzV9+nTV1NRo7dq1uuyyy1RRUeFopBEAAMh+no02evHFF7V582Zt3rxZxxxzTMR71ujsgwcPatOmTdq/f3/ovbvuukt5eXm68MIL1dzcrAkTJui3v/2tV80EAAA+k9I6L6lAnRcAAPwnkd9v7m0EAAB8heAFAAD4CsELAADwlZTeHiAVrBQebhMAAIB/WL/bTlJxsy542bNnjyRxmwAAAHxoz549Ki4ujjtN1o02am1t1SeffKIePXooEAi4uuympiaVlZVp27ZtWTmSKdu3T8r+bWT7/C/bt5Ht8z+vttEYoz179qh///4dFrXNuisveXl57erKuK2oqChrT0op+7dPyv5tZPv8L9u3ke3zPy+2saMrLhYSdgEAgK8QvAAAAF8heElAYWGh5syZk7V3sc727ZOyfxvZPv/L9m1k+/wvE7Yx6xJ2AQBAduPKCwAA8BWCFwAA4CsELwAAwFcIXgAAgK8QvAAAAF8heAlz6623avTo0erevbt69uzpaB5jjG666Sb169dP3bp10/jx4/XPf/4zYppdu3bphz/8oYqKitSzZ09dfvnl2rt3rwdb0LFE27JlyxYFAgHbxyOPPBKazu795cuXp2KTIiSzr8eOHduu7T/72c8iptm6davOOeccde/eXX379tX111+vQ4cOebkpthLdvl27dunqq6/W4MGD1a1bNw0YMEDXXHONGhsbI6ZL5/FbtGiRjj32WHXt2lWjRo3SmjVr4k7/yCOPaMiQIeratauGDRum5557LuJ9J5/JVEpk+373u9/pW9/6l
o466igdddRRGj9+fLvpp06d2u5YTZw40evNiCuRbVyyZEm79nft2jViGj8fQ7vvk0AgoHPOOSc0TSYdw9WrV2vy5Mnq37+/AoGAnnjiiQ7nWblypU477TQVFhbq+OOP15IlS9pNk+jnOmEGITfddJO58847zfTp001xcbGjeW677TZTXFxsnnjiCfOPf/zDnHvuuaa8vNx88cUXoWkmTpxohg8fbt544w3zyiuvmOOPP95ccsklHm1FfIm25dChQ2bHjh0Rj7lz55ojjzzS7NmzJzSdJPPQQw9FTBe+D1IlmX1dWVlppk2bFtH2xsbG0PuHDh0yQ4cONePHjzfvvPOOee6550zv3r1NVVWV15vTTqLb995775kLLrjAPPXUU2bz5s1mxYoV5oQTTjAXXnhhxHTpOn7Lly83BQUF5sEHHzQffPCBmTZtmunZs6dpaGiwnf61114z+fn5Zv78+Wb9+vXmxhtvNF26dDHvvfdeaBonn8lUSXT7fvCDH5hFixaZd955x2zYsMFMnTrVFBcXm48//jg0zZQpU8zEiRMjjtWuXbtStUntJLqNDz30kCkqKopof319fcQ0fj6Gn3/+ecS2vf/++yY/P9889NBDoWky6Rg+99xz5oYbbjCPPfaYkWQef/zxuNP/61//Mt27dzfTp08369evN/fcc4/Jz883zz//fGiaRPdZMghebDz00EOOgpfW1lZTWlpqFixYEHpt9+7dprCw0Pzf//2fMcaY9evXG0nm73//e2iav/71ryYQCJjt27e73vZ43GrLiBEjzI9//OOI15yc9F5LdvsqKyvNL37xi5jvP/fccyYvLy/iC/bee+81RUVFprm52ZW2O+HW8fvzn/9sCgoKzMGDB0Ovpev4jRw50lx55ZWh5y0tLaZ///6murradvr//M//NOecc07Ea6NGjTI//elPjTHOPpOplOj2RTt06JDp0aOHWbp0aei1KVOmmPPOO8/tpiYt0W3s6Ps1247hXXfdZXr06GH27t0bei3TjqHFyffAzJkzzcknnxzx2kUXXWQmTJgQet7ZfeYE3UadUFdXp/r6eo0fPz70WnFxsUaNGqXa2lpJUm1trXr27KkzzjgjNM348eOVl5enN998M6XtdaMta9eu1bp163T55Ze3e+/KK69U7969NXLkSD344IMyKa5/2Jnt++Mf/6jevXtr6NChqqqq0v79+yOWO2zYMJWUlIRemzBhgpqamvTBBx+4vyExuHUuNTY2qqioSEccEXlf1lQfvwMHDmjt2rURn5+8vDyNHz8+9PmJVltbGzG9FDwW1vROPpOpksz2Rdu/f78OHjyoXr16Rby+cuVK9e3bV4MHD9bPf/5zff7556623alkt3Hv3r0aOHCgysrKdN5550V8jrLtGD7wwAO6+OKL9ZWvfCXi9Uw5honq6DPoxj5zIuvuKp1K9fX1khTxo2Y9t96rr69X3759I94/4ogj1KtXr9A0qeJGWx544AGdeOKJGj16dMTrt9xyi7797W+re/fueuGFF/Tf//3f2rt3r6655hrX2t+RZLfvBz/4gQYOHKj+/fvr3Xff1S9/+Utt2rRJjz32WGi5dsfYei9V3Dh+n332mebNm6crrrgi4vV0HL/PPvtMLS0ttvt248aNtvPEOhbhnzfrtVjTpEoy2xftl7/8pfr37x/xQzBx4kRdcMEFKi8v14cffqhf/epXmjRpkmpra5Wfn+/qNnQkmW0cPHiwHnzwQZ1yyilqbGzUwoULNXr0aH3wwQc65phjsuoYrlmzRu+//74eeOCBiNcz6RgmKtZnsKmpSV988YX+/e9/d/q8dyLrg5dZs2bp9ttvjzvNhg0bNGTIkBS1yH1Ot7GzvvjiCy1btkyzZ89u9174a6eeeqr27dunBQsWuPLj5/X2hf+QDxs2TP369dNZZ52lDz/8UMcdd1zSy3UqVcevqalJ55xzjk466STdfPPNEe95efyQnNtuu03Lly/XypUrI
xJaL7744tC/hw0bplNOOUXHHXecVq5cqbPOOisdTU1IRUWFKioqQs9Hjx6tE088Uffff7/mzZuXxpa574EHHtCwYcM0cuTIiNf9fgwzQdYHLzNmzNDUqVPjTjNo0KCkll1aWipJamhoUL9+/UKvNzQ0aMSIEaFpPv3004j5Dh06pF27doXm7yyn29jZtjz66KPav3+/Lr300g6nHTVqlObNm6fm5uZO37wrVdtnGTVqlCRp8+bNOu6441RaWtouU76hoUGSXDmGqdi+PXv2aOLEierRo4cef/xxdenSJe70bh6/WHr37q38/PzQvrQ0NDTE3J7S0tK40zv5TKZKMttnWbhwoW677Ta99NJLOuWUU+JOO2jQIPXu3VubN29O+Q9fZ7bR0qVLF5166qnavHmzpOw5hvv27dPy5ct1yy23dLiedB7DRMX6DBYVFalbt27Kz8/v9DnhiGvZM1kk0YTdhQsXhl5rbGy0Tdh96623QtP87W9/S2vCbrJtqaysbDdKJZZf//rX5qijjkq6rclwa1+/+uqrRpL5xz/+YYxpS9gNz5S///77TVFRkfnyyy/d24AOJLt9jY2N5swzzzSVlZVm3759jtaVquM3cuRIc9VVV4Wet7S0mKOPPjpuwu53v/vdiNcqKiraJezG+0ymUqLbZ4wxt99+uykqKjK1tbWO1rFt2zYTCATMk08+2en2JiOZbQx36NAhM3jwYHPdddcZY7LjGBoT/B0pLCw0n332WYfrSPcxtMhhwu7QoUMjXrvkkkvaJex25pxw1FbXlpQFPvroI/POO++EhgK/88475p133okYEjx48GDz2GOPhZ7fdtttpmfPnubJJ5807777rjnvvPNsh0qfeuqp5s033zSvvvqqOeGEE9I6VDpeWz7++GMzePBg8+abb0bM989//tMEAgHz17/+td0yn3rqKfO73/3OvPfee+af//yn+e1vf2u6d+9ubrrpJs+3J1qi27d582Zzyy23mLfeesvU1dWZJ5980gwaNMiMGTMmNI81VPrss88269atM88//7zp06dP2oZKJ7J9jY2NZtSoUWbYsGFm8+bNEUMzDx06ZIxJ7/Fbvny5KSwsNEuWLDHr1683V1xxhenZs2doZNePfvQjM2vWrND0r732mjniiCPMwoULzYYNG8ycOXNsh0p39JlMlUS377bbbjMFBQXm0UcfjThW1nfQnj17zP/7f//P1NbWmrq6OvPSSy+Z0047zZxwwgkpDaQ7s41z5841f/vb38yHH35o1q5day6++GLTtWtX88EHH4Sm8fMxtHzzm980F110UbvXM+0Y7tmzJ/RbJ8nceeed5p133jEfffSRMcaYWbNmmR/96Eeh6a2h0tdff73ZsGGDWbRoke1Q6Xj7zA0EL2GmTJliJLV71NTUhKbR4XoYltbWVjN79mxTUlJiCgsLzVlnnWU2bdoUsdzPP//cXHLJJebII480RUVF5rLLLosIiFKpo7bU1dW122ZjjKmqqjJlZWWmpaWl3TL/+te/mhEjRpgjjzzSfOUrXzHDhw839913n+20Xkt0+7Zu3WrGjBljevXqZQoLC83xxx9vrr/++og6L8YYs2XLFjNp0iTTrVs307t3bzNjxoyIocapkuj21dTU2J7TkkxdXZ0xJv3H75577jEDBgwwBQUFZuTIkeaNN94IvVdZWWmmTJkSMf2f//xn87Wvfc0UFBSYk08+2Tz77LMR7zv5TKZSIts3cOBA22M1Z84cY4wx+/fvN2effbbp06eP6dKlixk4cKCZNm2aqz8KyUhkG6+99trQtCUlJeY73/mOefvttyOW5+djaIwxGzduNJLMCy+80G5ZmXYMY31HWNs0ZcoUU1lZ2W6eESNGmIKCAjNo0KCI30RLvH3mhoAxKR7PCgAA0AnUeQEAAL5C8AIAAHyF4AUAAPgKwQsAAPAVghcAAOArBC8AAMBXCF4AAICvELwAAABfIXgBAAC+QvACAAB8heAFA
AD4yv8HHXCGI2jrUIgAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# push the training inputs through the model and extract the posterior mean of f\n", "train_mean_f = model(train_x).loc.data.cpu()\n", "# plot the training labels, mapped from {0, 1} to {-1, 1}\n", "plt.plot(train_x.squeeze(-1).cpu(), train_y.mul(2.).sub(1.).cpu(), 'o')\n", "# plot the Q-Exponential Process posterior mean evaluated at the training inputs\n", "plt.plot(train_x.squeeze(-1).cpu(), train_mean_f, 'x')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, the Q-Exponential Process posterior mean (plotted in orange) gives confident predictions in regions\n", "where the correct label is unambiguous (e.g. for x ~ 0.5) and unconfident predictions (values close to zero) in regions where\n", "the correct label is ambiguous (e.g. x ~ 0.0).\n", "\n", "Finally, we compute the average negative log likelihood (NLL) and the classification accuracy on the held-out test data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test NLL: 0.4710\n", "Test Acc: 0.8000\n" ] } ], "source": [ "model.eval()\n", "likelihood.eval()\n", "with torch.no_grad():\n", "    # per-point negative log marginal likelihood of the test labels\n", "    nlls = -likelihood.log_marginal(test_y, model(test_x))\n", "    # predict label 1 when the predictive probability exceeds 0.5\n", "    acc = (likelihood(model(test_x)).probs.gt(0.5) == test_y.bool()).float().mean()\n", "print('Test NLL: {:.4f}'.format(nlls.mean().item()))\n", "print('Test Acc: {:.4f}'.format(acc.item()))" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 4 }