#!/usr/bin/env python3
from __future__ import annotations
import warnings
from collections.abc import Iterable
from copy import deepcopy
import torch
from torch import Tensor
from .. import settings
from ..distributions import MultitaskMultivariateQExponential, MultivariateQExponential
from ..likelihoods import _QExponentialLikelihoodBase
from gpytorch.utils.generic import length_safe_zip
from ..utils.warnings import QEPInputWarning
from .exact_prediction_strategies import prediction_strategy
from .qep import QEP
class ExactQEP(QEP):
    r"""
    The base class for any Q-Exponential process latent function to be used in conjunction
    with exact inference.

    :param torch.Tensor train_inputs: (size n x d) The training features :math:`\mathbf X`.
    :param torch.Tensor train_targets: (size n) The training targets :math:`\mathbf y`.
    :param ~qpytorch.likelihoods.QExponentialLikelihood likelihood: The Q-Exponential likelihood that defines
        the observational distribution. Since we're using exact inference, the likelihood must be Q-Exponential.

    The :meth:`forward` function should describe how to compute the prior latent distribution
    on a given input. Typically, this will involve a mean and kernel function.
    The result must be a :obj:`~qpytorch.distributions.MultivariateQExponential`.

    Calling this model will return the posterior of the latent Q-Exponential process when conditioned
    on the training data. The output will be a :obj:`~qpytorch.distributions.MultivariateQExponential`.

    Example:
        >>> class MyQEP(qpytorch.models.ExactQEP):
        >>>     def __init__(self, train_x, train_y, likelihood):
        >>>         super().__init__(train_x, train_y, likelihood)
        >>>         self.mean_module = qpytorch.means.ZeroMean()
        >>>         self.covar_module = qpytorch.kernels.ScaleKernel(qpytorch.kernels.RBFKernel())
        >>>
        >>>     def forward(self, x):
        >>>         mean = self.mean_module(x)
        >>>         covar = self.covar_module(x)
        >>>         return qpytorch.distributions.MultivariateQExponential(mean, covar, self.likelihood.power)
        >>>
        >>> # train_x = ...; train_y = ...
        >>> likelihood = qpytorch.likelihoods.QExponentialLikelihood(power=torch.tensor(1.0))
        >>> model = MyQEP(train_x, train_y, likelihood)
        >>>
        >>> # test_x = ...;
        >>> model(test_x)  # Returns the QEP latent function at test_x
        >>> likelihood(model(test_x))  # Returns the (approximate) predictive posterior distribution at test_x
    """

    def __init__(
        self,
        train_inputs: Tensor | Iterable[Tensor] | None,
        train_targets: Tensor | None,
        likelihood: _QExponentialLikelihoodBase,
    ):
        if train_inputs is not None:
            if isinstance(train_inputs, Tensor):
                train_inputs = (train_inputs,)
            # Materialize exactly once: a one-shot iterable (e.g. a generator) would otherwise be
            # exhausted by the type check below and leave the model with no training inputs.
            train_inputs = tuple(train_inputs)
            if not all(isinstance(train_input, Tensor) for train_input in train_inputs):
                raise RuntimeError("Train inputs must be a tensor, or a list/tuple of tensors")
        if not isinstance(likelihood, _QExponentialLikelihoodBase):
            raise RuntimeError("ExactQEP can only handle Q-Exponential likelihoods")

        super().__init__()
        if train_inputs is not None:
            # 1-D inputs are stored as n x 1 (a single feature dimension).
            self.train_inputs = tuple(tri.unsqueeze(-1) if tri.ndimension() == 1 else tri for tri in train_inputs)
            self.train_targets = train_targets
        else:
            self.train_inputs = None
            self.train_targets = None
        self.likelihood = likelihood
        # Test-time caches; built lazily on the first posterior call.
        self.prediction_strategy = None

    @property
    def train_targets(self) -> Tensor | None:
        """The training targets, or ``None`` when the model has no training data."""
        return self._train_targets

    @train_targets.setter
    def train_targets(self, value: Tensor | None) -> None:
        # Store the targets as a plain instance attribute, bypassing the base class's ``__setattr__``
        # (presumably ``torch.nn.Module``'s), so no attribute-registration logic runs.
        object.__setattr__(self, "_train_targets", value)

    def _apply(self, fn):
        """
        Apply ``fn`` to the training data as well as to the module's own tensors.

        The training data are held as plain attributes, so they are converted here explicitly to stay
        consistent with the rest of the module (e.g. after a device or dtype change).
        """
        if self.train_inputs is not None:
            self.train_inputs = tuple(fn(train_input) for train_input in self.train_inputs)
            # Targets may legitimately be ``None`` even when inputs are set (see ``__init__``).
            if self.train_targets is not None:
                self.train_targets = fn(self.train_targets)
        return super()._apply(fn)

    def _clear_cache(self) -> None:
        # The precomputed caches from test time live in prediction_strategy
        self.prediction_strategy = None

    def local_load_samples(self, samples_dict, memo, prefix):
        """
        Replace the model's learned hyperparameters with samples from a posterior distribution.

        The training data (when present) are expanded along a new leading batch dimension whose size is
        the number of samples, so each sample gets its own copy of the data.
        """
        # Pyro always puts the samples in the first batch dimension
        num_samples = next(iter(samples_dict.values())).size(0)
        if self.train_inputs is not None:
            self.train_inputs = tuple(tri.unsqueeze(0).expand(num_samples, *tri.shape) for tri in self.train_inputs)
        if self.train_targets is not None:
            self.train_targets = self.train_targets.unsqueeze(0).expand(num_samples, *self.train_targets.shape)
        super().local_load_samples(samples_dict, memo, prefix)

    @staticmethod
    def _check_unchanged_attrs(kind: str, expected, found) -> None:
        """Raise ``RuntimeError`` if ``found`` differs from ``expected`` in shape, dtype, or device."""
        # An ordered tuple (not a set) so the reported attribute is deterministic across runs.
        for attr in ("shape", "dtype", "device"):
            expected_attr = getattr(expected, attr, None)
            found_attr = getattr(found, attr, None)
            if expected_attr != found_attr:
                raise RuntimeError(f"Cannot modify {attr} of {kind} (expected {expected_attr}, found {found_attr}).")

    def set_train_data(
        self, inputs: Tensor | Iterable[Tensor] | None = None, targets: Tensor | None = None, strict: bool = True
    ) -> None:
        """
        Set training data (does not re-fit model hyper-parameters).

        :param inputs: The new training inputs.
        :param targets: The new training targets.
        :param strict: If `True`, the new inputs and targets must have the same shape,
            dtype, and device as the current inputs and targets. Otherwise, any
            shape/dtype/device are allowed.
        :raises RuntimeError: If ``strict`` is set and the new data differ in shape, dtype, or device.
        """
        if inputs is not None:
            if isinstance(inputs, Tensor):
                inputs = (inputs,)
            inputs = tuple(input_.unsqueeze(-1) if input_.ndimension() == 1 else input_ for input_ in inputs)
            if strict:
                for input_, t_input in length_safe_zip(inputs, self.train_inputs or (None,)):
                    self._check_unchanged_attrs("inputs", t_input, input_)
            self.train_inputs = inputs
        if targets is not None:
            if strict:
                self._check_unchanged_attrs("targets", self.train_targets, targets)
            self.train_targets = targets
        # Cached test-time quantities depend on the training data and are now stale.
        self.prediction_strategy = None

    def get_fantasy_model(self, inputs, targets, **kwargs):
        """
        Returns a new QEP model that incorporates the specified inputs and targets as new training data.

        Using this method is more efficient than updating with `set_train_data` when the number of inputs is relatively
        small, because any computed test-time caches will be updated in linear time rather than computed from scratch.

        .. note::
            If `targets` is a batch (e.g. `b x m`), then the QEP returned from this method will be a batch mode QEP.
            If `inputs` is of the same (or lesser) dimension as `targets`, then it is assumed that the fantasy points
            are the same for each target batch.

        :param torch.Tensor inputs: (`b1 x ... x bk x m x d` or `f x b1 x ... x bk x m x d`) Locations of fantasy
            observations. A list or tuple of such tensors is accepted for models with several inputs.
        :param torch.Tensor targets: (`b1 x ... x bk x m` or `f x b1 x ... x bk x m`) Labels of fantasy observations.
        :param kwargs: Forwarded to the model call on the concatenated inputs. An optional ``noise`` entry is
            removed first and passed to the likelihood's and prediction strategy's fantasy updates instead.
        :return: An `ExactQEP` model with `n + m` training examples, where the `m` fantasy examples have been added
            and all test-time caches have been updated.
        :rtype: ~qpytorch.models.ExactQEP
        :raises RuntimeError: If no prediction has been made yet (so no test-time caches exist), or if the
            model, input, and target batch shapes are incompatible.
        """
        if self.prediction_strategy is None:
            raise RuntimeError(
                "Fantasy observations can only be added after making predictions with a model so that "
                "all test independent caches exist. Call the model on some data first!"
            )

        model_batch_shape = self.train_inputs[0].shape[:-2]

        # Accept a single tensor, or a list/tuple of tensors (one per model input).
        if isinstance(inputs, tuple):
            inputs = list(inputs)
        elif not isinstance(inputs, list):
            inputs = [inputs]
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in inputs]

        # Multitask targets carry an extra trailing dimension beyond the data dimension.
        if not isinstance(self.prediction_strategy.train_prior_dist, MultitaskMultivariateQExponential):
            data_dim_start = -1
        else:
            data_dim_start = -2

        target_batch_shape = targets.shape[:data_dim_start]
        input_batch_shape = inputs[0].shape[:-2]
        tbdim, ibdim = len(target_batch_shape), len(input_batch_shape)

        if not (tbdim == ibdim + 1 or tbdim == ibdim):
            raise RuntimeError(
                f"Unsupported batch shapes: The target batch shape ({target_batch_shape}) must have either the "
                f"same dimension as or one more dimension than the input batch shape ({input_batch_shape})"
            )

        # Check whether we can properly broadcast batch dimensions
        try:
            torch.broadcast_shapes(model_batch_shape, target_batch_shape)
        except RuntimeError as err:
            raise RuntimeError(
                f"Model batch shape ({model_batch_shape}) and target batch shape "
                f"({target_batch_shape}) are not broadcastable."
            ) from err

        if len(model_batch_shape) > len(input_batch_shape):
            input_batch_shape = model_batch_shape
        if len(model_batch_shape) > len(target_batch_shape):
            target_batch_shape = model_batch_shape

        # If input has no fantasy batch dimension but target does, we can save memory and computation by not
        # computing the covariance for each element of the batch. Therefore we don't expand the inputs to the
        # size of the fantasy model here - this is done below, after the evaluation and fast fantasy update
        train_inputs = [tin.expand(input_batch_shape + tin.shape[-2:]) for tin in self.train_inputs]
        train_targets = self.train_targets.expand(target_batch_shape + self.train_targets.shape[data_dim_start:])

        full_inputs = [
            torch.cat([train_input, input_.expand(input_batch_shape + input_.shape[-2:])], dim=-2)
            for train_input, input_ in length_safe_zip(train_inputs, inputs)
        ]
        full_targets = torch.cat(
            [train_targets, targets.expand(target_batch_shape + targets.shape[data_dim_start:])], dim=data_dim_start
        )

        # ``noise`` is meant for the fantasy updates, not for the forward pass below.
        fantasy_kwargs = {"noise": kwargs.pop("noise")} if "noise" in kwargs else {}

        full_output = super().__call__(*full_inputs, **kwargs)

        # Copy model without copying training data or prediction strategy (since we'll overwrite those).
        # The original state is restored in ``finally`` so a failing deepcopy cannot leave this model stripped.
        old_pred_strat = self.prediction_strategy
        old_train_inputs = self.train_inputs
        old_train_targets = self.train_targets
        old_likelihood = self.likelihood
        self.prediction_strategy = None
        self.train_inputs = None
        self.train_targets = None
        self.likelihood = None
        try:
            new_model = deepcopy(self)
        finally:
            self.prediction_strategy = old_pred_strat
            self.train_inputs = old_train_inputs
            self.train_targets = old_train_targets
            self.likelihood = old_likelihood

        new_model.likelihood = old_likelihood.get_fantasy_likelihood(**fantasy_kwargs)
        new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
            inputs, targets, full_inputs, full_targets, full_output, **fantasy_kwargs
        )

        # if the fantasies are at the same points, we need to expand the inputs for the new model
        if tbdim == ibdim + 1:
            new_model.train_inputs = tuple(fi.expand(target_batch_shape + fi.shape[-2:]) for fi in full_inputs)
        else:
            new_model.train_inputs = tuple(full_inputs)
        new_model.train_targets = full_targets

        return new_model

    def __call__(self, *args, **kwargs):
        """
        Evaluate the model at the given inputs.

        * Training mode: returns the output of ``forward`` on the inputs, which must be the training inputs.
        * Prior mode (``settings.prior_mode`` on, or no training inputs/targets): returns ``forward`` on the inputs.
        * Otherwise: returns the posterior at the inputs, conditioned on the training data.

        :raises RuntimeError: In training mode without training inputs, or (with ``settings.debug`` on) when
            training on non-training inputs or when ``forward`` does not return a MultivariateQExponential.
        """
        train_inputs = list(self.train_inputs) if self.train_inputs is not None else []
        # 1-D inputs are treated as n x 1, matching how the training inputs are stored.
        inputs = [i.unsqueeze(-1) if i.ndimension() == 1 else i for i in args]

        # Training mode: optimizing
        if self.training:
            if self.train_inputs is None:
                raise RuntimeError(
                    "train_inputs cannot be None in training mode. "
                    "Call .eval() for prior predictions, or call .set_train_data() to add training data."
                )
            if settings.debug.on():
                if not all(
                    torch.equal(train_input, input_) for train_input, input_ in length_safe_zip(train_inputs, inputs)
                ):
                    raise RuntimeError("You must train on the training inputs!")
            return super().__call__(*inputs, **kwargs)

        # Prior mode
        if settings.prior_mode.on() or self.train_inputs is None or self.train_targets is None:
            # Use the normalized inputs (not the raw args) so 1-D inputs are shaped as in the other modes.
            full_output = super().__call__(*inputs, **kwargs)
            if settings.debug.on() and not isinstance(full_output, MultivariateQExponential):
                raise RuntimeError("ExactQEP.forward must return a MultivariateQExponential")
            return full_output

        # Posterior mode
        if settings.debug.on():
            if all(torch.equal(train_input, input_) for train_input, input_ in length_safe_zip(train_inputs, inputs)):
                warnings.warn(
                    "The input matches the stored training data. Did you forget to call model.train()?",
                    QEPInputWarning,
                )

        # Get the terms that only depend on training data; cached until reset by
        # ``set_train_data`` or ``_clear_cache``.
        if self.prediction_strategy is None:
            train_output = super().__call__(*train_inputs, **kwargs)
            self.prediction_strategy = prediction_strategy(
                train_inputs=train_inputs,
                train_prior_dist=train_output,
                train_labels=self.train_targets,
                likelihood=self.likelihood,
            )

        # Concatenate the input to the training input
        full_inputs = []
        batch_shape = train_inputs[0].shape[:-2]
        for train_input, input_ in length_safe_zip(train_inputs, inputs):
            # Make sure the batch shapes agree for training/test data
            if batch_shape != train_input.shape[:-2]:
                batch_shape = torch.broadcast_shapes(batch_shape, train_input.shape[:-2])
                train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
            if batch_shape != input_.shape[:-2]:
                batch_shape = torch.broadcast_shapes(batch_shape, input_.shape[:-2])
                train_input = train_input.expand(*batch_shape, *train_input.shape[-2:])
                input_ = input_.expand(*batch_shape, *input_.shape[-2:])
            full_inputs.append(torch.cat([train_input, input_], dim=-2))

        # Get the joint distribution for training/test data
        full_output = super().__call__(*full_inputs, **kwargs)
        if settings.debug.on() and not isinstance(full_output, MultivariateQExponential):
            raise RuntimeError("ExactQEP.forward must return a MultivariateQExponential")
        full_mean, full_covar = full_output.loc, full_output.lazy_covariance_matrix

        # Determine the shape of the joint distribution
        batch_shape = full_output.batch_shape
        joint_shape = full_output.event_shape
        tasks_shape = joint_shape[1:]  # For multitask learning
        test_shape = torch.Size([joint_shape[0] - self.prediction_strategy.train_shape[0], *tasks_shape])

        # Make the prediction
        with settings.cg_tolerance(settings.eval_cg_tolerance.value()):
            predictive_mean, predictive_covar = self.prediction_strategy.exact_prediction(full_mean, full_covar)

        # Reshape predictive mean to match the appropriate event shape
        predictive_mean = predictive_mean.view(*batch_shape, *test_shape).contiguous()
        return full_output.__class__(predictive_mean, predictive_covar, power=full_output.power)