{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# QEP-LVM for Regularizing Latent Representations\n", "\n", "## Introduction\n", "\n", "In this notebook we demonstrate the QEP-LVM model class introduced in [Obite et al., 2025](https://openreview.net/pdf?id=VOoJEQlLW5.pdf); see this [notebook](./QExponential_Process_Latent_Variable_Models_with_Stochastic_Variational_Inference.ipynb) for an introduction to QEP-LVM.\n", "\n", "We focus on illustrating how the parameter $q>0$ of QEP-LVM regularizes the learned latent representation, using the simulated [Swiss roll](https://scikit-learn.org/stable/auto_examples/manifold/plot_swissroll.html) dataset. QEP-LVM tends to contract the learned latent representation towards the coordinate axes of the latent space as $q$ decreases." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Standard imports\n", "import matplotlib.pyplot as plt\n", "import os\n", "import numpy as np\n", "from sklearn.datasets import make_swiss_roll\n", "import torch\n", "from torch.utils.data import TensorDataset, DataLoader\n", "import tqdm\n", "\n", "%matplotlib inline\n", "\n", "# Set manual seeds for reproducibility\n", "seed = 2024\n", "torch.manual_seed(seed)\n", "np.random.seed(seed)\n", "\n", "# This is for running the notebook in our testing framework\n", "smoke_test = ('CI' in os.environ)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set up training data\n", "\n", "We use the canonical Swiss roll dataset generated with `make_swiss_roll` from [scikit-learn](https://scikit-learn.org). `Y` holds the 3-D points and `t` holds each point's position along the roll, which we use to color the plots." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "n_samples = 1000\n", "sr_points, sr_color = make_swiss_roll(n_samples=n_samples, noise=0.05, random_state=0)\n", "Y, t = torch.tensor(sr_points), torch.tensor(sr_color)\n", "train_dataset = TensorDataset(Y, t, torch.arange(n_samples))\n", "batch_size = 256\n", "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Defining the QEP-LVM model\n", "\n", "Now we construct a Bayesian QEP-LVM using `BayesianQEPLVM` in QPyTorch.\n", "`BayesianQEPLVM` is built on top of the sparse variational QEP formulation. Similar to the [SVQEP example](../04_Variational_and_Approximate_QEPs/SVQEP_Regression_CUDA.ipynb), we use a `CholeskyVariationalDistribution` to model $q(\\mathbf{u})$ and the standard `VariationalStrategy` as defined by Hensman et al. (2015).\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# qpytorch imports\n", "import qpytorch\n", "from qpytorch.models.qeplvm.latent_variable import *\n", "from qpytorch.models.qeplvm.bayesian_qeplvm import BayesianQEPLVM\n", "from qpytorch.means import ZeroMean\n", "from qpytorch.mlls import VariationalELBO\n", "from qpytorch.priors import QExponentialPrior\n", "from qpytorch.likelihoods import QExponentialLikelihood\n", "from qpytorch.variational import VariationalStrategy\n", "from qpytorch.variational import CholeskyVariationalDistribution\n", "from qpytorch.kernels import ScaleKernel, RBFKernel\n", "from qpytorch.distributions import MultivariateQExponential\n", "\n", "\n", "def _pca(Y, latent_dim):\n", "    # Project Y onto its top `latent_dim` principal directions (Y itself is not centered here)\n", "    U, S, V = torch.pca_lowrank(Y, q=latent_dim)\n", "    return torch.matmul(Y, V[:, :latent_dim])\n", "\n", "\n", "class bQEPLVM(BayesianQEPLVM):\n", "    def __init__(self, power, n, data_dim, latent_dim, n_inducing, pca=False):\n", "        self.power = torch.tensor(power)\n", "        self.n = n\n", "        self.batch_shape = torch.Size([data_dim])\n", "\n", "        # Locations Z_{d} corresponding to u_{d}; they can be randomly initialized or\n", "        # regularly placed with shape (D x n_inducing x latent_dim).\n", "        self.inducing_inputs = torch.randn(data_dim, n_inducing, latent_dim)\n", "\n", "        # Sparse variational formulation (inducing variables initialised as randn)\n", "        q_u = CholeskyVariationalDistribution(n_inducing, batch_shape=self.batch_shape, power=self.power)\n", "        q_f = VariationalStrategy(self, self.inducing_inputs, q_u, learn_inducing_locations=True)\n", "\n", "        # Define the q-exponential prior for X\n", "        X_prior_mean = torch.zeros(n, latent_dim)  # shape: N x Q\n", "        prior_x = QExponentialPrior(X_prior_mean, torch.ones_like(X_prior_mean), power=self.power)\n", "\n", "        # Initialise X with PCA (of the global training data Y) or randn\n", "        if pca:\n", "            X_init = torch.nn.Parameter(_pca(Y.float(), latent_dim))\n", "        else:\n", "            X_init = torch.nn.Parameter(torch.randn(n, latent_dim))\n", "\n", "        # Latent variable: (c) variational posterior over X\n", "        X = VariationalLatentVariable(n, data_dim, latent_dim, X_init, prior_x)\n", "\n", "        # For (a) a point estimate or (b) a MAP estimate of X, use one of these instead:\n", "        # X = PointLatentVariable(n, latent_dim, X_init)\n", "        # X = MAPLatentVariable(n, latent_dim, X_init, prior_x)\n", "\n", "        super().__init__(X, q_f)\n", "\n", "        # Mean and kernel acting on the latent dimensions (ARD: one lengthscale per dimension)\n", "        self.mean_module = ZeroMean(ard_num_dims=latent_dim)\n", "        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=latent_dim))\n", "\n", "    def forward(self, X):\n", "        mean_x = self.mean_module(X)\n", "        covar_x = self.covar_module(X)\n", "        dist = MultivariateQExponential(mean_x, covar_x, power=self.power)\n", "        return dist\n", "\n", "    def _get_batch_idx(self, batch_size, seed=None):\n", "        # Sample `batch_size` distinct data indices (reproducibly if `seed` is given)\n", "        valid_indices = np.arange(self.n)\n", "        if seed is None:\n", "            batch_indices = np.random.choice(valid_indices, size=batch_size, replace=False)\n", "        else:\n", "            batch_indices = np.random.default_rng(seed).choice(valid_indices, size=batch_size, replace=False)\n", "        return np.sort(batch_indices)" ] }, { 
"cell_type": "markdown", "metadata": {}, "source": [ "## Training the model\n", "\n", "Although we must specify the dimensionality of the latent variables at the outset, one advantage of the Bayesian framework is that, with an ARD (automatic relevance determination) kernel, we can prune latent dimensions with small inverse lengthscales: the kernel varies little along such dimensions, so they contribute little to the model.\n", "We train several QEP-LVM models with different values of $q$ to illustrate the regularization effect." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "N = len(Y)\n", "data_dim = Y.shape[1]\n", "latent_dim = data_dim\n", "n_inducing = 25\n", "pca = False\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We use mini-batch training for scalability: in each iteration, only the local variational parameters of the data points in the current mini-batch are optimised." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "313820624edd43988ee3eb1b276a9732", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Epoch: 0%| | 0/1000 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Initialize plots\n", "fig, ax = plt.subplots(2, 4, figsize=(14, 10))\n", "idx2plot = np.random.default_rng(seed).choice(n_samples, size=500, replace=False)\n", "colors = t[idx2plot]\n", "\n", "# Visualize the 3-D dataset\n", "ax[0, 0].remove()\n", "ax[0, 0] = fig.add_subplot(241, projection=\"3d\")\n", "ax[0, 0].scatter(sr_points[:, 0], sr_points[:, 1], sr_points[:, 2], c=sr_color, alpha=0.8)\n", "ax[0, 0].set_title('Swiss Roll')\n", "\n", "# PCA\n", "X = _pca(Y, 2)[idx2plot]\n", "ax[1, 0].scatter(X[:, 0], X[:, 1], c=colors, alpha=0.8)\n", "ax[1, 0].set_title('PCA')\n", "ax[1, 0].set_xlabel('Latent dim 1')\n", "ax[1, 0].set_ylabel('Latent dim 2')\n", "\n", "# QEP-LVM\n", "for i, q in enumerate([1.0, 1.5, 2.0]):\n", "    # Obtain the trained model\n", "    model = model_list[i]\n", "    model.eval()\n", "\n", "    # Select the two latent dimensions with the smallest lengthscales (largest inverse lengthscales)\n", "    inv_lengthscale = 1 / model.covar_module.base_kernel.lengthscale\n", "    _, indices = torch.topk(model.covar_module.base_kernel.lengthscale, k=2, largest=False)\n", "    l1, l2 = indices.detach().numpy().flatten()[:2]\n", "\n", "    idx2plot = model._get_batch_idx(500, seed)\n", "    # Variational mean of X if available, otherwise the point/MAP estimate\n", "    X = (model.X.q_mu if hasattr(model.X, 'q_mu') else model.X.X).detach().numpy()[idx2plot]\n", "    colors = t[idx2plot]\n", "\n", "    ax[0, i+1].scatter(X[:, l1], X[:, l2], c=colors, alpha=0.8)\n", "    ax[0, i+1].set_title(f'q = {q}' + (' (Gaussian)' if q == 2 else ''))\n", "\n", "    ax[1, i+1].bar(np.arange(latent_dim), height=inv_lengthscale.detach().numpy().flatten())\n", "    ax[1, i+1].set_title('Inverse lengthscales of kernel')\n", "\n", "None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The latent space learned by QEP-LVM tends to contract towards the coordinate axes as $q$ decreases. As a result, the learned latent representation becomes more compact for smaller $q$." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 4 }