Saving and Loading Models

In this bite-sized notebook, we’ll go over how to save and load models. In general, the process is the same as for any PyTorch module.
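Because the process mirrors plain PyTorch, here is a minimal sketch of the generic pattern. It uses an ordinary torch.nn.Linear module, chosen only for illustration; it is not a QPyTorch model. The sketch saves the state dict to a temporary file, rebuilds the module with the same constructor arguments, and loads the state back.

```python
import os
import tempfile

import torch

# An ordinary PyTorch module standing in for any model (illustrative only).
model = torch.nn.Linear(3, 1)

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "linear_state.pth")
    # Save only the state dict (parameters and buffers), not the Python object.
    torch.save(model.state_dict(), path)

    # Rebuild the module with the same constructor arguments, then load the state.
    restored = torch.nn.Linear(3, 1)
    result = restored.load_state_dict(torch.load(path))

print(result)  # <All keys matched successfully>

# Check that every saved tensor was restored exactly.
same = all(
    torch.equal(model.state_dict()[k], restored.state_dict()[k])
    for k in model.state_dict()
)
print(same)  # True
```

The state dict stores only tensors. You therefore need the class definition and its constructor arguments to recreate the model before loading.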

[1]:
import math
import torch
import qpytorch
from matplotlib import pyplot as plt

Saving a Simple Model

First, we define a QEP model that we’d like to save. The model used below is the same as the model from our Simple QEP Regression tutorial.

[2]:
train_x = torch.linspace(0, 1, 100)  # 100 evenly spaced inputs on [0, 1]
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2  # noisy sine targets
[4]:
# We will use the simplest form of QEP model, exact inference
POWER = 1.0
class ExactQEPModel(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactQEPModel, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = ExactQEPModel(train_x, train_y, likelihood)

Change Model State

To demonstrate model saving, we first change the hyperparameters below from their default values. For more information on what is happening here, see our tutorial notebook on Initializing Hyperparameters.

[5]:
model.covar_module.outputscale = 1.2
model.covar_module.base_kernel.lengthscale = 2.2

Getting Model State

To get the full state of a QPyTorch model, simply call state_dict as you would on any PyTorch model. Note that the state dict contains raw parameter values (for example, raw_outputscale rather than outputscale). The raw values are the actual torch.nn.Parameters that QPyTorch learns, and the constrained hyperparameters are computed from them. Again, see our notebook on hyperparameters for more information on this.

[6]:
model.state_dict()
[6]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf))])
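The raw values above can be reproduced by hand. This sketch assumes that the positive constraints on outputscale and lengthscale use a softplus transform and that the noise constraint adds its lower bound (1e-4) after the softplus, as in GPyTorch. The printed numbers agree with that assumption. Only Python's math module is used.

```python
import math


def softplus(raw: float) -> float:
    # softplus(r) = log(1 + exp(r)) maps any real raw value to a positive value.
    return math.log1p(math.exp(raw))


def inverse_softplus(value: float) -> float:
    # Inverse transform: r = log(exp(v) - 1); expm1 is accurate for small v.
    return math.log(math.expm1(value))


# Setting outputscale = 1.2 and lengthscale = 2.2 stores these raw values.
raw_outputscale = inverse_softplus(1.2)
raw_lengthscale = inverse_softplus(2.2)
print(round(raw_outputscale, 4))  # 0.8416
print(round(raw_lengthscale, 4))  # 2.0826

# Assumed lower-bounded noise: softplus(raw) + lower_bound, so raw_noise = 0
# corresponds to an actual noise value of log(2) + 1e-4.
noise = softplus(0.0) + 1e-4
print(round(noise, 4))  # 0.6932
```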

Saving Model State

The state dictionary above contains all of the model's learnable (raw) parameters, together with buffers such as the constraint bounds. Therefore, we can save it to a file as follows:

[7]:
torch.save(model.state_dict(), 'model_state.pth')
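A state dict records only parameters and registered buffers. In ExactQEPModel, self.power is a plain tensor attribute, which is why no power entry appears in the printed state dict. The value has to be supplied again when the model is rebuilt. One common option is to save a checkpoint dictionary that bundles the state dict with such constructor settings. The sketch below uses a hypothetical TinyModel in place of the QEP model, and the checkpoint key names are our own choice, not a library convention.

```python
import io

import torch

POWER = 1.0


class TinyModel(torch.nn.Module):
    # Hypothetical stand-in for ExactQEPModel: one learnable layer plus a
    # plain tensor attribute, mirroring how self.power is stored there.
    def __init__(self, power: float):
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)
        self.power = torch.tensor(power)  # plain attribute: NOT in state_dict()


model = TinyModel(POWER)
print("power" in model.state_dict())  # False

# Bundle the state dict with the settings needed to rebuild the model.
checkpoint = {"state_dict": model.state_dict(), "power": POWER}

buffer = io.BytesIO()  # in-memory file; a path such as 'model_state.pth' works the same way
torch.save(checkpoint, buffer)
buffer.seek(0)

loaded = torch.load(buffer)
restored = TinyModel(loaded["power"])  # rebuild with the saved setting
restored.load_state_dict(loaded["state_dict"])
print(restored.power)  # tensor(1.)
```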

Loading Model State

Next, we load this state into a new model and demonstrate that the parameters were updated correctly.

[8]:
state_dict = torch.load('model_state.pth')
model = ExactQEPModel(train_x, train_y, likelihood)  # Create a new QEP model

model.load_state_dict(state_dict)
[8]:
<All keys matched successfully>
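The message above is the printed form of the value returned by load_state_dict, a named tuple with missing_keys and unexpected_keys fields. By default (strict=True), any mismatch raises a RuntimeError. With strict=False, the mismatches are reported instead. The sketch below demonstrates this on a plain torch.nn.Linear, with a deliberately altered state dict; the extra key name is hypothetical.

```python
import torch

source = torch.nn.Linear(2, 1)
state = source.state_dict()
state["unused_entry"] = torch.zeros(1)  # hypothetical key the target does not have
del state["bias"]                       # remove a key the target expects

# Default strict loading refuses the mismatched state dict.
try:
    torch.nn.Linear(2, 1).load_state_dict(state)
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)  # True

# Non-strict loading copies what matches and reports the rest.
target = torch.nn.Linear(2, 1)
result = target.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['unused_entry']
```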
[9]:
model.state_dict()
[9]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf))])
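Comparing two printed state dicts by eye does not scale to larger models. A small helper can perform the check instead. The helper name state_dicts_equal is our own, and the sketch exercises it on plain torch.nn.Linear modules.

```python
import torch


def state_dicts_equal(a, b) -> bool:
    """Return True if two state dicts have the same keys and identical tensors."""
    if a.keys() != b.keys():
        return False
    # torch.equal checks shape and values. inf == inf, so constraint bounds
    # such as upper_bound = inf compare as equal (NaN entries would not).
    return all(torch.equal(a[k], b[k]) for k in a)


torch.manual_seed(0)
first = torch.nn.Linear(3, 1)
second = torch.nn.Linear(3, 1)
second.load_state_dict(first.state_dict())  # copies values into second's own tensors
print(state_dicts_equal(first.state_dict(), second.state_dict()))  # True

with torch.no_grad():
    second.weight.add_(1.0)  # perturb one parameter
print(state_dicts_equal(first.state_dict(), second.state_dict()))  # False
```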

A More Complex Example

Next, we demonstrate the same principle on a more complex exact QEP model that includes a simple feed-forward neural network feature extractor.

[11]:
class QEPWithNNFeatureExtractor(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(QEPWithNNFeatureExtractor, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Linear(1, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
            torch.nn.Linear(2, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = QEPWithNNFeatureExtractor(train_x, train_y, likelihood)

Getting Model State

In the next cell, we once again print the model state via model.state_dict(). The state is now substantially larger, because it also includes the feature extractor's Linear weights and biases and the BatchNorm1d running statistics (running_mean, running_var, num_batches_tracked). Nevertheless, saving and loading work exactly as before.

[12]:
model.state_dict()
[12]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf)),
             ('feature_extractor.0.weight',
              tensor([[-0.1177],
                      [ 0.6034]])),
             ('feature_extractor.0.bias', tensor([0.3458, 0.1124])),
             ('feature_extractor.1.weight', tensor([1., 1.])),
             ('feature_extractor.1.bias', tensor([0., 0.])),
             ('feature_extractor.1.running_mean', tensor([0., 0.])),
             ('feature_extractor.1.running_var', tensor([1., 1.])),
             ('feature_extractor.1.num_batches_tracked', tensor(0)),
             ('feature_extractor.3.weight',
              tensor([[ 0.6297,  0.2143],
                      [-0.5057, -0.6424]])),
             ('feature_extractor.3.bias', tensor([-0.5611,  0.5896])),
             ('feature_extractor.4.weight', tensor([1., 1.])),
             ('feature_extractor.4.bias', tensor([0., 0.])),
             ('feature_extractor.4.running_mean', tensor([0., 0.])),
             ('feature_extractor.4.running_var', tensor([1., 1.])),
             ('feature_extractor.4.num_batches_tracked', tensor(0))])
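Some of these entries are not learnable parameters. BatchNorm1d registers running_mean, running_var and num_batches_tracked as buffers, and state_dict includes both parameters and buffers. The sketch below checks this on a shortened copy of the feature extractor, using only plain PyTorch.

```python
import torch

# A shortened copy of the feature extractor used above.
feature_extractor = torch.nn.Sequential(
    torch.nn.Linear(1, 2),
    torch.nn.BatchNorm1d(2),
    torch.nn.ReLU(),
)

param_names = {name for name, _ in feature_extractor.named_parameters()}
buffer_names = {name for name, _ in feature_extractor.named_buffers()}
state_keys = set(feature_extractor.state_dict().keys())

print(sorted(param_names))   # ['0.bias', '0.weight', '1.bias', '1.weight']
print(sorted(buffer_names))  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']
print(state_keys == param_names | buffer_names)  # True
```

The printed QEP state dicts also contain entries such as raw_noise_constraint.lower_bound. These are buffers as well, so they are saved and restored alongside the parameters.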
[13]:
torch.save(model.state_dict(), 'my_qep_with_nn_model.pth')
state_dict = torch.load('my_qep_with_nn_model.pth')
model = QEPWithNNFeatureExtractor(train_x, train_y, likelihood)
model.load_state_dict(state_dict)
[13]:
<All keys matched successfully>
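Because the feature extractor contains BatchNorm1d layers, a model's outputs depend on its mode. In training mode, BatchNorm normalizes with statistics of the current batch. In eval mode, it uses the saved running statistics. The sketch below round-trips the feature extractor alone (plain PyTorch, not the full QEP model) and confirms that the original and restored copies agree in eval mode. The helper make_extractor is hypothetical.

```python
import io

import torch


def make_extractor() -> torch.nn.Sequential:
    # Same architecture as the feature extractor in QEPWithNNFeatureExtractor.
    return torch.nn.Sequential(
        torch.nn.Linear(1, 2), torch.nn.BatchNorm1d(2), torch.nn.ReLU(),
        torch.nn.Linear(2, 2), torch.nn.BatchNorm1d(2), torch.nn.ReLU(),
    )


torch.manual_seed(0)
original = make_extractor()
x = torch.linspace(0, 1, 100).unsqueeze(-1)  # shape (100, 1)
original(x)  # one training-mode pass updates the BatchNorm running statistics

buffer = io.BytesIO()
torch.save(original.state_dict(), buffer)
buffer.seek(0)

restored = make_extractor()
restored.load_state_dict(torch.load(buffer))

# In eval mode both copies use the same saved running statistics.
original.eval()
restored.eval()
with torch.no_grad():
    outputs_match = torch.allclose(original(x), restored(x))
print(outputs_match)  # True
```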
[14]:
model.state_dict()
[14]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('likelihood.noise_covar.raw_noise_constraint.lower_bound',
              tensor(1.0000e-04)),
             ('likelihood.noise_covar.raw_noise_constraint.upper_bound',
              tensor(inf)),
             ('mean_module.raw_constant', tensor(0.)),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
              tensor(inf)),
             ('covar_module.raw_outputscale_constraint.lower_bound',
              tensor(0.)),
             ('covar_module.raw_outputscale_constraint.upper_bound',
              tensor(inf)),
             ('feature_extractor.0.weight',
              tensor([[-0.1177],
                      [ 0.6034]])),
             ('feature_extractor.0.bias', tensor([0.3458, 0.1124])),
             ('feature_extractor.1.weight', tensor([1., 1.])),
             ('feature_extractor.1.bias', tensor([0., 0.])),
             ('feature_extractor.1.running_mean', tensor([0., 0.])),
             ('feature_extractor.1.running_var', tensor([1., 1.])),
             ('feature_extractor.1.num_batches_tracked', tensor(0)),
             ('feature_extractor.3.weight',
              tensor([[ 0.6297,  0.2143],
                      [-0.5057, -0.6424]])),
             ('feature_extractor.3.bias', tensor([-0.5611,  0.5896])),
             ('feature_extractor.4.weight', tensor([1., 1.])),
             ('feature_extractor.4.bias', tensor([0., 0.])),
             ('feature_extractor.4.running_mean', tensor([0., 0.])),
             ('feature_extractor.4.running_var', tensor([1., 1.])),
             ('feature_extractor.4.num_batches_tracked', tensor(0))])