Saving and Loading Models
In this bite-sized notebook, we’ll go over how to save and load models. In general, the process is the same as for any PyTorch module.
[1]:
import math
import torch
import qpytorch
from matplotlib import pyplot as plt
Saving a Simple Model
First, we define a QEP model that we'd like to save. The model used below is the same as the one from our Simple QEP Regression tutorial.
[2]:
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2
[4]:
# We will use the simplest form of QEP model, exact inference
POWER = 1.0
class ExactQEPModel(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactQEPModel, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = ExactQEPModel(train_x, train_y, likelihood)
Change Model State
To demonstrate model saving, we change the hyperparameters below from their default values. For more information on what is happening here, see our tutorial notebook on Initializing Hyperparameters.
[5]:
model.covar_module.outputscale = 1.2
model.covar_module.base_kernel.lengthscale = 2.2
Getting Model State
To get the full state of a QPyTorch model, call state_dict as you would on any PyTorch model. Note that the state dict contains raw parameter values, because these are the actual torch.nn.Parameters that the model learns. Again, see our notebook on hyperparameters for more information.
[6]:
model.state_dict()
[6]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
('likelihood.noise_covar.raw_noise_constraint.lower_bound',
tensor(1.0000e-04)),
('likelihood.noise_covar.raw_noise_constraint.upper_bound',
tensor(inf)),
('mean_module.raw_constant', tensor(0.)),
('covar_module.raw_outputscale', tensor(0.8416)),
('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]])),
('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
tensor(0.)),
('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
tensor(inf)),
('covar_module.raw_outputscale_constraint.lower_bound',
tensor(0.)),
('covar_module.raw_outputscale_constraint.upper_bound',
tensor(inf))])
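The raw values above can be related back to the hyperparameters set earlier (1.2 and 2.2). The following is a minimal sketch that uses only torch, under the assumption that the default positivity constraint maps raw values through a softplus transform, as GPyTorch's Positive constraint does by default; that QPyTorch behaves the same way is assumed here, not confirmed by this notebook. The helper inverse_softplus is a hypothetical name introduced for illustration.

```python
import torch
import torch.nn.functional as F

# Raw values copied from the state dict printed above.
raw_outputscale = torch.tensor(0.8416)
raw_lengthscale = torch.tensor([[2.0826]])

# Assumption: the constraint maps raw -> actual via softplus, log(1 + exp(raw)).
outputscale = F.softplus(raw_outputscale)
lengthscale = F.softplus(raw_lengthscale)


def inverse_softplus(value):
    """Hypothetical helper: map an actual (positive) value back to its raw value."""
    return torch.log(torch.expm1(value))


# Under the softplus assumption these come out close to 1.2 and 2.2.
print(outputscale.item(), lengthscale.item())
# Going the other way, 1.2 maps back to roughly the stored raw outputscale.
print(inverse_softplus(torch.tensor(1.2)).item())
```

If the assumption holds, this explains why the state dict shows 0.8416 and 2.0826 rather than the values assigned in the previous cell.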
Saving Model State
The state dictionary above contains all trainable parameters of the model, along with buffers such as the constraint bounds. We can therefore save it to a file as follows:
[7]:
torch.save(model.state_dict(), 'model_state.pth')
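As a hedged variation on the cell above, the sketch below writes the state dict to a temporary directory and reads it back with map_location="cpu", which can help when a state saved on a GPU is loaded on a CPU-only machine. A plain torch.nn.Linear is used as a stand-in so the snippet runs without QPyTorch; the stand-in module and the temporary path are illustrative assumptions, and a QEP model's state dict would be handled the same way.

```python
import os
import tempfile

import torch

# Stand-in module (hypothetical); a QEP model's state dict is saved the same way.
module = torch.nn.Linear(1, 2)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_state.pth")
    torch.save(module.state_dict(), path)
    # map_location="cpu" places the loaded tensors on the CPU,
    # regardless of the device they were saved from.
    loaded_state = torch.load(path, map_location="cpu")

print(list(loaded_state.keys()))
```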
Loading Model State
Next, we load this state into a new model and demonstrate that the parameters were updated correctly.
[8]:
state_dict = torch.load('model_state.pth')
model = ExactQEPModel(train_x, train_y, likelihood) # Create a new QEP model
model.load_state_dict(state_dict)
[8]:
<All keys matched successfully>
[9]:
model.state_dict()
[9]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
('likelihood.noise_covar.raw_noise_constraint.lower_bound',
tensor(1.0000e-04)),
('likelihood.noise_covar.raw_noise_constraint.upper_bound',
tensor(inf)),
('mean_module.raw_constant', tensor(0.)),
('covar_module.raw_outputscale', tensor(0.8416)),
('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]])),
('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
tensor(0.)),
('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
tensor(inf)),
('covar_module.raw_outputscale_constraint.lower_bound',
tensor(0.)),
('covar_module.raw_outputscale_constraint.upper_bound',
tensor(inf))])
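Comparing two printed state dicts by eye works for a small model, but the check can also be done programmatically. The sketch below is one possible way to do so; state_dicts_match is a hypothetical helper, and two torch.nn.Linear modules stand in for the saved and freshly created models so that the snippet runs without QPyTorch.

```python
import torch


def state_dicts_match(first, second):
    """Hypothetical helper: True if both state dicts share keys and equal tensors."""
    if first.keys() != second.keys():
        return False
    return all(torch.equal(first[key], second[key]) for key in first)


# Stand-ins for the original model and the newly constructed one.
source = torch.nn.Linear(1, 2)
target = torch.nn.Linear(1, 2)

# Copy the source state into the target, then verify every entry agrees.
target.load_state_dict(source.state_dict())
matches = state_dicts_match(source.state_dict(), target.state_dict())
print(matches)
```

The same helper could be applied to the state dict loaded from disk and the state dict of the new model after load_state_dict.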
A More Complex Example
Next, we demonstrate the same principle on a more complex exact QEP that includes a simple feed-forward neural network feature extractor as part of the model.
[11]:
class QEPWithNNFeatureExtractor(qpytorch.models.ExactQEP):
    def __init__(self, train_x, train_y, likelihood):
        super(QEPWithNNFeatureExtractor, self).__init__(train_x, train_y, likelihood)
        self.power = torch.tensor(POWER)
        self.mean_module = qpytorch.means.ConstantMean()
        self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Linear(1, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
            torch.nn.Linear(2, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
        )

    def forward(self, x):
        x = self.feature_extractor(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return qpytorch.distributions.MultivariateQExponential(mean_x, covar_x, power=self.power)

# initialize likelihood and model
likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(POWER))
model = QEPWithNNFeatureExtractor(train_x, train_y, likelihood)
Getting Model State
In the next cell, we once again print the model state via model.state_dict(). As you can see, the state is substantially more complex, as the model now includes our neural network parameters. Nevertheless, saving and loading is straightforward.
[12]:
model.state_dict()
[12]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
('likelihood.noise_covar.raw_noise_constraint.lower_bound',
tensor(1.0000e-04)),
('likelihood.noise_covar.raw_noise_constraint.upper_bound',
tensor(inf)),
('mean_module.raw_constant', tensor(0.)),
('covar_module.raw_outputscale', tensor(0.)),
('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
tensor(0.)),
('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
tensor(inf)),
('covar_module.raw_outputscale_constraint.lower_bound',
tensor(0.)),
('covar_module.raw_outputscale_constraint.upper_bound',
tensor(inf)),
('feature_extractor.0.weight',
tensor([[-0.1177],
[ 0.6034]])),
('feature_extractor.0.bias', tensor([0.3458, 0.1124])),
('feature_extractor.1.weight', tensor([1., 1.])),
('feature_extractor.1.bias', tensor([0., 0.])),
('feature_extractor.1.running_mean', tensor([0., 0.])),
('feature_extractor.1.running_var', tensor([1., 1.])),
('feature_extractor.1.num_batches_tracked', tensor(0)),
('feature_extractor.3.weight',
tensor([[ 0.6297, 0.2143],
[-0.5057, -0.6424]])),
('feature_extractor.3.bias', tensor([-0.5611, 0.5896])),
('feature_extractor.4.weight', tensor([1., 1.])),
('feature_extractor.4.bias', tensor([0., 0.])),
('feature_extractor.4.running_mean', tensor([0., 0.])),
('feature_extractor.4.running_var', tensor([1., 1.])),
('feature_extractor.4.num_batches_tracked', tensor(0))])
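Not every entry in the state dict above is a trainable parameter: the BatchNorm running statistics are buffers, which state_dict() also stores. The sketch below rebuilds only the feature extractor in plain PyTorch (so it runs without QPyTorch) and separates the two groups; the variable names are illustrative.

```python
import torch

# Same layer stack as the feature extractor defined above.
feature_extractor = torch.nn.Sequential(
    torch.nn.Linear(1, 2),
    torch.nn.BatchNorm1d(2),
    torch.nn.ReLU(),
    torch.nn.Linear(2, 2),
    torch.nn.BatchNorm1d(2),
    torch.nn.ReLU(),
)

# Trainable parameters: Linear and BatchNorm weights and biases.
parameter_names = [name for name, _ in feature_extractor.named_parameters()]
# Buffers: BatchNorm running_mean, running_var and num_batches_tracked.
buffer_names = [name for name, _ in feature_extractor.named_buffers()]
# The state dict holds both groups.
state_keys = list(feature_extractor.state_dict().keys())

print(parameter_names)
print(buffer_names)
```

Because both groups are saved and restored together, the running statistics survive a save and load round trip along with the learned weights.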
[13]:
torch.save(model.state_dict(), 'my_qep_with_nn_model.pth')
state_dict = torch.load('my_qep_with_nn_model.pth')
model = QEPWithNNFeatureExtractor(train_x, train_y, likelihood)
model.load_state_dict(state_dict)
[13]:
<All keys matched successfully>
[14]:
model.state_dict()
[14]:
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
('likelihood.noise_covar.raw_noise_constraint.lower_bound',
tensor(1.0000e-04)),
('likelihood.noise_covar.raw_noise_constraint.upper_bound',
tensor(inf)),
('mean_module.raw_constant', tensor(0.)),
('covar_module.raw_outputscale', tensor(0.)),
('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
('covar_module.base_kernel.raw_lengthscale_constraint.lower_bound',
tensor(0.)),
('covar_module.base_kernel.raw_lengthscale_constraint.upper_bound',
tensor(inf)),
('covar_module.raw_outputscale_constraint.lower_bound',
tensor(0.)),
('covar_module.raw_outputscale_constraint.upper_bound',
tensor(inf)),
('feature_extractor.0.weight',
tensor([[-0.1177],
[ 0.6034]])),
('feature_extractor.0.bias', tensor([0.3458, 0.1124])),
('feature_extractor.1.weight', tensor([1., 1.])),
('feature_extractor.1.bias', tensor([0., 0.])),
('feature_extractor.1.running_mean', tensor([0., 0.])),
('feature_extractor.1.running_var', tensor([1., 1.])),
('feature_extractor.1.num_batches_tracked', tensor(0)),
('feature_extractor.3.weight',
tensor([[ 0.6297, 0.2143],
[-0.5057, -0.6424]])),
('feature_extractor.3.bias', tensor([-0.5611, 0.5896])),
('feature_extractor.4.weight', tensor([1., 1.])),
('feature_extractor.4.bias', tensor([0., 0.])),
('feature_extractor.4.running_mean', tensor([0., 0.])),
('feature_extractor.4.running_var', tensor([1., 1.])),
('feature_extractor.4.num_batches_tracked', tensor(0))])