• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# mypy: allow-untyped-defs
2import math
3import warnings
4from numbers import Number
5from typing import Optional, Union
6
7import torch
8from torch import nan
9from torch.distributions import constraints
10from torch.distributions.exp_family import ExponentialFamily
11from torch.distributions.multivariate_normal import _precision_to_scale_tril
12from torch.distributions.utils import lazy_property
13from torch.types import _size
14
15
# Names exported by a star-import of this module.
__all__ = ["Wishart"]

# Natural log of 2, computed once at import time; reused by the log-density,
# entropy and log-normalizer expressions of the Wishart class below.
_log_2 = math.log(2)
19
20
21def _mvdigamma(x: torch.Tensor, p: int) -> torch.Tensor:
22    assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function."
23    return torch.digamma(
24        x.unsqueeze(-1)
25        - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,))
26    ).sum(-1)
27
28
29def _clamp_above_eps(x: torch.Tensor) -> torch.Tensor:
30    # We assume positive input for this function
31    return x.clamp(min=torch.finfo(x.dtype).eps)
32
33
class Wishart(ExponentialFamily):
    r"""
    Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`,
    or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`

    Example:
        >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional")
        >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2))
        >>> m.sample()  # Wishart distributed with mean=`df * I` and
        >>>             # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j

    Args:
        df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1
        covariance_matrix (Tensor): positive-definite covariance matrix
        precision_matrix (Tensor): positive-definite precision matrix
        scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal
    Note:
        Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or
        :attr:`scale_tril` can be specified.
        Using :attr:`scale_tril` will be more efficient: all computations internally
        are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or
        :attr:`precision_matrix` is passed instead, it is only used to compute
        the corresponding lower triangular matrices using a Cholesky decomposition.
        'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1]

    **References**

    [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`.
    [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`.
    [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`.
    [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a Sample Covariance Matrix`. JASA, 61(313):199-203.
    [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`.
    """

    arg_constraints = {
        "covariance_matrix": constraints.positive_definite,
        "precision_matrix": constraints.positive_definite,
        "scale_tril": constraints.lower_cholesky,
        "df": constraints.greater_than(0),
    }
    support = constraints.positive_definite
    has_rsample = True
    _mean_carrier_measure = 0

    def __init__(
        self,
        df: Union[torch.Tensor, Number],
        covariance_matrix: Optional[torch.Tensor] = None,
        precision_matrix: Optional[torch.Tensor] = None,
        scale_tril: Optional[torch.Tensor] = None,
        validate_args=None,
    ):
        """Build the distribution from ``df`` and exactly one matrix parameter.

        Raises:
            AssertionError: if zero or more than one of the matrix parameters is given.
            ValueError: if the matrix parameter has fewer than two dimensions, or
                if any ``df`` value is not greater than ``ndim - 1``.
        """
        # Booleans add up as integers: the sum counts how many parameterizations were given.
        assert (covariance_matrix is not None) + (scale_tril is not None) + (
            precision_matrix is not None
        ) == 1, "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified."

        # The single matrix argument that was actually supplied.
        param = next(
            p
            for p in (covariance_matrix, precision_matrix, scale_tril)
            if p is not None
        )

        if param.dim() < 2:
            raise ValueError(
                "scale_tril must be at least two-dimensional, with optional leading batch dimensions"
            )

        if isinstance(df, Number):
            # Scalar df: batch shape comes from the matrix alone; df adopts its dtype/device.
            batch_shape = torch.Size(param.shape[:-2])
            self.df = torch.tensor(df, dtype=param.dtype, device=param.device)
        else:
            # Tensor df: broadcast it against the matrix's leading (batch) dimensions.
            batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape)
            self.df = df.expand(batch_shape)
        event_shape = param.shape[-2:]

        if self.df.le(event_shape[-1] - 1).any():
            raise ValueError(
                f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1]-1}."
            )

        # Store the supplied parameter under its own name, expanded to the batch shape.
        # Assigning on the instance shadows the lazy property of the same name.
        if scale_tril is not None:
            self.scale_tril = param.expand(batch_shape + (-1, -1))
        elif covariance_matrix is not None:
            self.covariance_matrix = param.expand(batch_shape + (-1, -1))
        elif precision_matrix is not None:
            self.precision_matrix = param.expand(batch_shape + (-1, -1))

        # NOTE: this item assignment writes into the dict defined on the class, so the
        # tightened df constraint is visible through every instance, not just this one.
        self.arg_constraints["df"] = constraints.greater_than(event_shape[-1] - 1)
        if self.df.lt(event_shape[-1]).any():
            warnings.warn(
                "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim."
            )

        super().__init__(batch_shape, event_shape, validate_args=validate_args)
        # Negative indices -1, -2, ..., one per batch dimension; rsample reduces over them.
        self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))]

        # Every computation below works from the (unexpanded) lower Cholesky factor.
        if scale_tril is not None:
            self._unbroadcasted_scale_tril = scale_tril
        elif covariance_matrix is not None:
            self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix)
        else:  # precision_matrix is not None
            self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix)

        # Chi2 distribution is needed for Bartlett decomposition sampling:
        # one Chi2 per diagonal entry, with degrees of freedom df, df - 1, ..., df - (p - 1).
        self._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                self.df.unsqueeze(-1)
                - torch.arange(
                    self._event_shape[-1],
                    dtype=self._unbroadcasted_scale_tril.dtype,
                    device=self._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

    def expand(self, batch_shape, _instance=None):
        """Return a new ``Wishart`` whose parameters are expanded to ``batch_shape``.

        Only the matrix attributes already materialized on this instance
        (present in ``__dict__``) are carried over; the rest stay lazy.
        """
        new = self._get_checked_instance(Wishart, _instance)
        batch_shape = torch.Size(batch_shape)
        cov_shape = batch_shape + self.event_shape
        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
        new.df = self.df.expand(batch_shape)

        new._batch_dims = [-(x + 1) for x in range(len(batch_shape))]

        if "covariance_matrix" in self.__dict__:
            new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
        if "scale_tril" in self.__dict__:
            new.scale_tril = self.scale_tril.expand(cov_shape)
        if "precision_matrix" in self.__dict__:
            new.precision_matrix = self.precision_matrix.expand(cov_shape)

        # Chi2 distribution is needed for Bartlett decomposition sampling
        new._dist_chi2 = torch.distributions.chi2.Chi2(
            df=(
                new.df.unsqueeze(-1)
                - torch.arange(
                    self.event_shape[-1],
                    dtype=new._unbroadcasted_scale_tril.dtype,
                    device=new._unbroadcasted_scale_tril.device,
                ).expand(batch_shape + (-1,))
            )
        )

        # Parameters were validated when `self` was built; skip re-validation here
        # and then restore the original validation flag on the new instance.
        super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

    @lazy_property
    def scale_tril(self):
        """Lower Cholesky factor of the covariance, expanded to batch + event shape."""
        return self._unbroadcasted_scale_tril.expand(
            self._batch_shape + self._event_shape
        )

    @lazy_property
    def covariance_matrix(self):
        """Covariance ``L @ L^T`` rebuilt from the Cholesky factor, expanded to batch + event shape."""
        return (
            self._unbroadcasted_scale_tril
            @ self._unbroadcasted_scale_tril.transpose(-2, -1)
        ).expand(self._batch_shape + self._event_shape)

    @lazy_property
    def precision_matrix(self):
        """Precision matrix, obtained by solving against the identity with the Cholesky factor."""
        identity = torch.eye(
            self._event_shape[-1],
            device=self._unbroadcasted_scale_tril.device,
            dtype=self._unbroadcasted_scale_tril.dtype,
        )
        return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand(
            self._batch_shape + self._event_shape
        )

    @property
    def mean(self):
        """``df * covariance_matrix``, with ``df`` broadcast over the two event dimensions."""
        return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def mode(self):
        """``(df - p - 1) * covariance_matrix``; NaN wherever ``df - p - 1 <= 0``."""
        factor = self.df - self.covariance_matrix.shape[-1] - 1
        # `factor` is a fresh tensor produced by the subtraction, so this
        # in-place masking does not touch `self.df`.
        factor[factor <= 0] = nan
        return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix

    @property
    def variance(self):
        """Element-wise variance ``df * (V_ij ** 2 + V_ii * V_jj)``."""
        V = self.covariance_matrix  # has shape (batch_shape x event_shape)
        diag_V = V.diagonal(dim1=-2, dim2=-1)
        return self.df.view(self._batch_shape + (1, 1)) * (
            V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V)
        )

    def _bartlett_sampling(self, sample_shape=torch.Size()):
        """Draw samples of shape ``sample_shape + batch_shape + event_shape``.

        No singularity check is performed here; see :meth:`rsample`.
        """
        p = self._event_shape[-1]  # has singleton shape

        # Implemented Sampling using Bartlett decomposition
        # Diagonal: square roots of the Chi2 draws, kept at least eps.
        noise = _clamp_above_eps(
            self._dist_chi2.rsample(sample_shape).sqrt()
        ).diag_embed(dim1=-2, dim2=-1)

        # Strictly-lower triangle: independent standard normal draws.
        i, j = torch.tril_indices(p, p, offset=-1)
        noise[..., i, j] = torch.randn(
            torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),),
            dtype=noise.dtype,
            device=noise.device,
        )
        chol = self._unbroadcasted_scale_tril @ noise
        return chol @ chol.transpose(-2, -1)

    def rsample(
        self, sample_shape: _size = torch.Size(), max_try_correction=None
    ) -> torch.Tensor:
        r"""
        Draw reparameterized samples, re-drawing those that fail the support check.

        Args:
            sample_shape: leading shape of the returned samples.
            max_try_correction (int, optional): maximum number of re-draw rounds for
                samples that are not positive definite. Defaults to 3 while tracing
                and 10 otherwise.

        .. warning::
            In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples.
            Several tries to correct singular samples are performed by default, but it may end up returning
            singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`.
            In those cases, the user should validate the samples and either fix the value of `df`
            or adjust `max_try_correction` value for argument in `.rsample` accordingly.
        """

        if max_try_correction is None:
            max_try_correction = 3 if torch._C._get_tracing_state() else 10

        sample_shape = torch.Size(sample_shape)
        sample = self._bartlett_sampling(sample_shape)

        # Below part is to improve numerical stability temporally and should be removed in the future
        # A sample is "singular" when it FAILS the positive-definite support check,
        # hence the negation (kept consistent with the re-checks inside the loops below).
        is_singular = ~self.support.check(sample)
        if self._batch_shape:
            # Collapse the batch dimensions: a draw is flagged if any batch member is singular.
            is_singular = is_singular.amax(self._batch_dims)

        if torch._C._get_tracing_state():
            # Less optimized version for JIT
            # Fixed iteration count, no data-dependent branching: always draw a full
            # replacement and keep it only where the current sample is flagged.
            for _ in range(max_try_correction):
                sample_new = self._bartlett_sampling(sample_shape)
                sample = torch.where(is_singular, sample_new, sample)

                is_singular = ~self.support.check(sample)
                if self._batch_shape:
                    is_singular = is_singular.amax(self._batch_dims)

        else:
            # More optimized version with data-dependent control flow.
            if is_singular.any():
                warnings.warn("Singular sample detected.")

                for _ in range(max_try_correction):
                    # Draw exactly as many replacements as there are flagged entries.
                    sample_new = self._bartlett_sampling(is_singular[is_singular].shape)
                    sample[is_singular] = sample_new

                    is_singular_new = ~self.support.check(sample_new)
                    if self._batch_shape:
                        is_singular_new = is_singular_new.amax(self._batch_dims)
                    # Clone the mask used for indexing because the same tensor is being written.
                    is_singular[is_singular.clone()] = is_singular_new

                    if not is_singular.any():
                        break

        return sample

    def log_prob(self, value):
        """Log-density of ``value`` (validated against the support when validation is enabled)."""
        if self._validate_args:
            self._validate_sample(value)
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            -nu
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            - torch.mvlgamma(nu / 2, p=p)
            + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet
            # Trace of the Cholesky solve of `value` against the scale factor, halved.
            - torch.cholesky_solve(value, self._unbroadcasted_scale_tril)
            .diagonal(dim1=-2, dim2=-1)
            .sum(dim=-1)
            / 2
        )

    def entropy(self):
        """Entropy of the distribution, computed from ``df`` and the Cholesky factor."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return (
            (p + 1)
            * (
                p * _log_2 / 2
                + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1)
                .log()
                .sum(-1)
            )
            + torch.mvlgamma(nu / 2, p=p)
            - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p)
            + nu * p / 2
        )

    @property
    def _natural_params(self):
        """Natural parameters as the pair ``(-precision_matrix / 2, (df - p - 1) / 2)``."""
        nu = self.df  # has shape (batch_shape)
        p = self._event_shape[-1]  # has singleton shape
        return -self.precision_matrix / 2, (nu - p - 1) / 2

    def _log_normalizer(self, x, y):
        """Log-normalizer as a function of the two natural parameters ``x`` and ``y``."""
        p = self._event_shape[-1]
        return (y + (p + 1) / 2) * (
            -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p
        ) + torch.mvlgamma(y + (p + 1) / 2, p=p)
340