# Source code for qpytorch.distributions.power

#!/usr/bin/env python3

from __future__ import annotations

from typing import Optional

import torch

from ..module import Module
from ..constraints import Interval, Positive
from ..priors import Prior

class Power(Module):
    """
    Constructs a power parameter for the (multivariate) q-exponential distribution.
    See :class:`qpytorch.distributions.QExponential` or :class:`qpytorch.distributions.MultivariateQExponential`
    for description of the power parameter.

    .. note::

        This object works similarly as a hyperparameter of kernel, which can be imposed with a prior
        and optimized over.

    :param power_init: initial value of power parameter of qep distribution. Non-tensor values
        (e.g. a Python float) are converted to a tensor of the default dtype. (Default: 1.0)
    :param power_constraint: Set this if you want to apply a constraint to the power parameter.
        (Default: :class:`~qpytorch.constraints.Positive`.)
    :param power_prior: Set this if you want to apply a prior to the power parameter. (Default: `None`.)

    :ivar torch.Size shape: The dimension of the power object.
    :ivar torch.Tensor power: The power parameter. The size/shape is the same as the `power_init` argument.
    :ivar torch.Tensor data: The data of the power object in :obj:`torch.tensor` format.

    Example:
        >>> power_init = torch.tensor(1.0)
        >>> power_prior = qpytorch.priors.GammaPrior(4.0, 2.0)
        >>> power = qpytorch.distributions.Power(power_init, power_prior=power_prior)
    """

    def __init__(
        self,
        power_init: torch.Tensor = torch.tensor(1.0),
        power_constraint: Optional[Interval] = None,
        power_prior: Optional[Prior] = None,
    ):
        """
        Register the raw (unconstrained) power parameter, its constraint and, optionally, its prior.

        :raises TypeError: if `power_prior` is given but is not a :class:`~qpytorch.priors.Prior`.
        """
        super().__init__()
        if power_constraint is None:
            power_constraint = Positive()

        # Mirror `_set_power`: tolerate plain Python numbers instead of failing
        # inside the constraint or `torch.nn.Parameter`.
        if not torch.is_tensor(power_init):
            power_init = torch.as_tensor(power_init, dtype=torch.get_default_dtype())

        # set parameter
        # `torch.nn.Parameter` wraps the tensor it is given without copying it.
        # The default `power_init` is a single tensor created at definition
        # time, so clone to guarantee the parameter never shares storage with
        # it (or with the caller's tensor) in case `inverse_transform` hands
        # its input back as-is.
        raw_init = power_constraint.inverse_transform(power_init).detach().clone()
        self.register_parameter(name="raw_power", parameter=torch.nn.Parameter(raw_init))
        self.shape = self.raw_power.shape

        # set constraint
        self.register_constraint("raw_power", power_constraint)

        # set prior
        if power_prior is not None:
            if not isinstance(power_prior, Prior):
                raise TypeError("Expected qpytorch.priors.Prior but got " + type(power_prior).__name__)
            self.register_prior("power_prior", power_prior, self._power_param, self._power_closure)

    def _power_param(self, q: Power) -> torch.Tensor:
        # Used by the power_prior: returns the (constrained) value the prior is evaluated on.
        return q.power

    def _power_closure(self, q: Power, v: torch.Tensor) -> None:
        # Used by the power_prior: writes a new (constrained) value back into the module.
        q._set_power(v)

    @property
    def power(self) -> torch.Tensor:
        """The power parameter, i.e. `raw_power` mapped through its registered constraint."""
        return self.raw_power_constraint.transform(self.raw_power)

    @power.setter
    def power(self, value: torch.Tensor) -> None:
        self._set_power(value)

    def _set_power(self, value: torch.Tensor) -> None:
        """Set the power from a constrained value by storing its inverse-transformed raw value."""
        if not torch.is_tensor(value):
            # Match the dtype/device of the existing raw parameter.
            value = torch.as_tensor(value).to(self.raw_power)
        self.initialize(raw_power=self.raw_power_constraint.inverse_transform(value))

    @property
    def data(self) -> torch.Tensor:
        """The `.data` of the constrained power tensor."""
        return self.power.data

    # The operators below delegate to the constrained `power` tensor so that a
    # `Power` object can be used directly in arithmetic and comparisons.

    def __truediv__(self, other):
        return self.power / other

    def __rtruediv__(self, other):
        return other / self.power

    def __rpow__(self, other):
        return other ** self.power

    def __ne__(self, other):
        return self.power != other

    def __lt__(self, other):
        return self.power < other

    def __gt__(self, other):
        return self.power > other

    def numel(self):
        """Number of elements in the power tensor."""
        return self.power.numel()

    def size(self):
        """Size of the power tensor."""
        return self.power.size()