# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LogNormal Distribution"""
import numpy as np
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
import mindspore.nn.probability.bijector as msb
import mindspore.nn.probability.distribution as msd
from ._utils.utils import check_distribution_name
from ._utils.custom_ops import exp_generic, log_generic


class LogNormal(msd.TransformedDistribution):
    """
    LogNormal distribution.
    A log-normal (or lognormal) distribution is a continuous probability distribution of a random variable whose
    logarithm is normally distributed. It is constructed as the exponential transformation of a Normal distribution.

    Args:
        loc (int, float, list, numpy.ndarray, Tensor): The mean of the underlying Normal distribution. Default: None.
        scale (int, float, list, numpy.ndarray, Tensor): The standard deviation of the underlying
          Normal distribution. Default: None.
        seed (int): The seed used in sampling. The global seed is used if it is None. Default: 0.
        dtype (mindspore.dtype): The type of the distribution. Default: mstype.float32.
        name (str): The name of the distribution. Default: 'LogNormal'.

    Supported Platforms:
        ``Ascend`` ``GPU``

    Note:
        `scale` must be greater than zero.
        `dist_spec_args` are `loc` and `scale`.
        `dtype` must be a float type because LogNormal distributions are continuous.

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> import mindspore.nn as nn
        >>> import mindspore.nn.probability.distribution as msd
        >>> from mindspore import Tensor
        >>> class Prob(nn.Cell):
        ...     def __init__(self):
        ...         super(Prob, self).__init__()
        ...         self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=mindspore.float32)
        ...     def construct(self, x_):
        ...         return self.ln.prob(x_)
        >>> pdf = Prob()
        >>> output = pdf(Tensor([1.0, 2.0], dtype=mindspore.float32))
        >>> print(output.shape)
        (2, 2)
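        >>> # A shape-only sampling sketch (an assumed extension of the example
        >>> # above; drawn values depend on the seed, so only the shape is checked):
        >>> ln = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=mindspore.float32)
        >>> samples = ln.sample((2, 3))
        >>> print(samples.shape)
        (2, 3, 1)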
    """

    def __init__(self,
                 loc=None,
                 scale=None,
                 seed=0,
                 dtype=mstype.float32,
                 name="LogNormal"):
        """
        Constructor of LogNormal distribution.
        """
        super(LogNormal, self).__init__(distribution=msd.Normal(loc, scale, dtype=dtype),
                                        bijector=msb.Exp(),
                                        seed=seed, name=name)

        # overwrite default_parameters and parameter_names
        self._reset_parameters()
        self._loc = self._add_parameter(loc, 'loc')
        self._scale = self._add_parameter(scale, 'scale')

        self.log_2pi = np.log(2 * np.pi)

        # ops needed for the class
        self.dtypeop = P.DType()
        self.exp = exp_generic
        self.expm1 = P.Expm1()
        self.log = log_generic
        self.const = P.ScalarToArray()
        self.erf = P.Erf()
        self.fill = P.Fill()
        self.greater = P.Greater()
        self.select = P.Select()
        self.shape = P.Shape()
        self.sq = P.Square()
        self.sqrt = P.Sqrt()
        self.cast = P.Cast()
        self.squeeze = P.Squeeze(0)
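
    # A minimal commented sketch (illustrative, not executed): because LogNormal
    # composes a Normal distribution with the Exp bijector, its density obeys
    # prob(x) = normal_pdf(log(x); loc, scale) / x for x > 0. For example:
    #     ln, n = msd.LogNormal(0.0, 1.0), msd.Normal(0.0, 1.0)
    #     x = Tensor([0.5, 1.0, 2.0], dtype=mstype.float32)
    #     # ln.prob(x) should match n.prob(log(x)) / x elementwise.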

    @property
    def loc(self):
        """
        Distribution parameter for the pre-transformed mean
        after casting to dtype.
        """
        return self._loc

    @property
    def scale(self):
        """
        Distribution parameter for the pre-transformed standard deviation
        after casting to dtype.
        """
        return self._scale

    def _get_dist_type(self):
        return "LogNormal"

    def _get_dist_args(self, loc=None, scale=None):
        """Return `loc` and `scale`, falling back to the stored parameters when not given."""
        if loc is not None:
            self.checktensor(loc, 'loc')
        else:
            loc = self.loc
        if scale is not None:
            self.checktensor(scale, 'scale')
        else:
            scale = self.scale
        return loc, scale

    def extend_repr(self):
        """Display instance object as string."""
        if self.is_scalar_batch:
            s = 'loc = {}, scale = {}'.format(self.loc, self.scale)
        else:
            s = 'batch_shape = {}'.format(self.broadcast_shape)
        return s

    def _mean(self, loc=None, scale=None):
        r"""
        The mean of the distribution.

        .. math::
            MEAN(Y) = \exp(\mu + 0.5 * \sigma^2)
        """
        mean, sd = self._check_param_type(loc, scale)
        var = self.distribution("var", mean=mean, sd=sd)
        return self.exp(mean + 0.5 * var)

    def _mode(self, loc=None, scale=None):
        r"""
        The mode of the distribution.

        .. math::
            MODE(Y) = \exp(\mu - \sigma^2)
        """
        mean, sd = self._check_param_type(loc, scale)
        var = self.distribution("var", mean=mean, sd=sd)
        return self.exp(mean - var)

    def _var(self, loc=None, scale=None):
        r"""
        The variance of the distribution.

        .. math::
            VAR(Y) = EXPM1(\sigma^2) * \exp(2 * \mu + \sigma^2)
        """
        mean, sd = self._check_param_type(loc, scale)
        var = self.distribution("var", mean=mean, sd=sd)
        return self.expm1(var) * self.exp(2. * mean + var)
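
    # A commented numerical sketch of the three moments above (illustrative
    # values; assumes direct PyNative-style calls): with loc = 0.0 and
    # scale = 0.5, so sigma^2 = 0.25, the closed forms give
    #     ln = msd.LogNormal(0.0, 0.5)
    #     ln.mean()  # exp(0.125)               ~ 1.1331
    #     ln.mode()  # exp(-0.25)               ~ 0.7788
    #     ln.var()   # expm1(0.25) * exp(0.25)  ~ 0.3647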

    def _entropy(self, loc=None, scale=None):
        r"""
        Evaluate entropy.

        .. math::
            H(X) = \mu + 0.5 + \log(\sigma) + 0.5 * \log(2\pi)
        """
        mean, sd = self._check_param_type(loc, scale)
        return mean + 0.5 + self.log(sd) + 0.5 * self.log_2pi
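
    # A commented check of the entropy above (standard lognormal entropy;
    # values illustrative): with mu = 0.0 and sigma = 1.0,
    #     H = 0.0 + 0.5 + log(1.0) + 0.5 * log(2 * pi) ~ 1.4189
    # which agrees with scipy.stats.lognorm(s=1.0).entropy() up to precision.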

    def _cdf(self, value, loc=None, scale=None):
        r"""
        Compute the cdf via the formula below, where g is the exp bijector
        and P is the cdf of the underlying normal distribution.

        .. math::
            Y = g(X)
            P(Y <= a) = P(X <= g^{-1}(a))
        """
        mean, sd = self._check_param_type(loc, scale)
        inverse_value = self.bijector("inverse", value)
        cdf = self.distribution("cdf", inverse_value, mean, sd)

        # to increase numerical stability, set cdf = 0 when value <= 0
        zeros = self.fill(self.dtypeop(cdf), self.shape(cdf), 0.0)

        return self.select(self.greater(value, 0.), cdf, zeros)
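
    # A commented reference computation (illustrative, using SciPy only as an
    # external check): for x > 0, cdf(x) of LogNormal(mu, sigma) equals the
    # standard normal cdf evaluated at (log(x) - mu) / sigma:
    #     from scipy.stats import norm
    #     mu, sigma, x = 0.3, 0.4, 1.5
    #     norm.cdf((np.log(x) - mu) / sigma)  # should match ln.cdf(x)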

    def _log_prob(self, value, loc=None, scale=None):
        r"""
        Compute the log prob via the formula below, where g is the exp bijector
        and P is the pdf of the underlying normal distribution.

        .. math::
            Y = g(X)
            P_Y(a) = P_X(g^{-1}(a)) * |(g^{-1})'(a)|
            \log(P_Y(a)) = \log(P_X(g^{-1}(a))) + \log(|(g^{-1})'(a)|)
        """
        mean, sd = self._check_param_type(loc, scale)
        inverse_value = self.bijector("inverse", value)
        unadjust_prob = self.distribution("log_prob", inverse_value, mean, sd)
        log_jacobian = self.bijector("inverse_log_jacobian", value)
        return unadjust_prob + log_jacobian
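
    # For the Exp bijector, g^{-1}(a) = log(a) and log|(g^{-1})'(a)| = -log(a),
    # so the jacobian adjustment above reduces to subtracting log(value):
    #     log_prob(x) = normal_log_pdf(log(x); mu, sigma) - log(x)
    # (normal_log_pdf is shorthand here, not a library function).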

    def _cross_entropy(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
        r"""
        Evaluate cross entropy between lognormal distributions.

        Args:
            dist (str): The type of the distributions. Should be "LogNormal" in this case.
            loc_b (Tensor): The loc of distribution b.
            scale_b (Tensor): The scale of distribution b.
            loc_a (Tensor): The loc of distribution a. Default: None.
            scale_a (Tensor): The scale of distribution a. Default: None.
        """
        check_distribution_name(dist, 'LogNormal')
        return self._entropy(loc_a, scale_a) + self._kl_loss(dist, loc_b, scale_b, loc_a, scale_a)
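
    # The return statement above uses the standard decomposition
    #     H(a, b) = H(a) + KL(a || b),
    # i.e. the cross entropy of a against b is a's entropy plus the KL
    # divergence from a to b.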

    def _kl_loss(self, dist, loc_b, scale_b, loc_a=None, scale_a=None):
        r"""
        Evaluate LogNormal-LogNormal kl divergence, i.e. KL(a||b).

        Args:
            dist (str): The type of the distributions. Should be "LogNormal" in this case.
            loc_b (Tensor): The loc of distribution b.
            scale_b (Tensor): The scale of distribution b.
            loc_a (Tensor): The loc of distribution a. Default: None.
            scale_a (Tensor): The scale of distribution a. Default: None.

        .. math::
            KL(a||b) = 0.5 * (\frac{\mu_a - \mu_b}{\sigma_b})^2 +
                       0.5 * EXPM1(2 * (\log(\sigma_a) - \log(\sigma_b))) -
                       (\log(\sigma_a) - \log(\sigma_b))
        """
        check_distribution_name(dist, 'LogNormal')
        return self.distribution("kl_loss", 'Normal', loc_b, scale_b, loc_a, scale_a)
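
    # KL divergence is invariant under the invertible exp transform, so
    # KL(LogNormal_a || LogNormal_b) = KL(Normal_a || Normal_b), which is why
    # the call above delegates to the underlying Normal's kl_loss. A commented
    # numerical check (illustrative): with (mu_a, sigma_a) = (0, 1) and
    # (mu_b, sigma_b) = (1, 2),
    #     KL = 0.5 * ((0 - 1) / 2)**2 + 0.5 * expm1(2 * log(0.5)) - log(0.5)
    #        ~ 0.125 - 0.375 + 0.6931 ~ 0.4431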

    def _sample(self, shape=(), loc=None, scale=None):
        r"""
        Generate samples by transforming samples drawn from the underlying
        normal distribution through the exp bijector.
        """
        shape = self.checktuple(shape, 'shape')
        mean, sd = self._check_param_type(loc, scale)
        if shape == ():
            sample_shape = (1,)
        else:
            sample_shape = shape
        org_sample = self.distribution("sample", sample_shape, mean, sd)
        org_sample = self.cast(org_sample, self.dtype)
        value = self.bijector("forward", org_sample)
        if shape == ():
            value = self.squeeze(value)
        return value
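
    # Sampling follows the defining construction: if Z ~ Normal(mu, sigma),
    # then exp(Z) ~ LogNormal(mu, sigma). A commented NumPy analogue
    # (illustrative only):
    #     rng = np.random.default_rng(0)
    #     z = rng.normal(mu, sigma, size=shape)
    #     y = np.exp(z)  # distributed as LogNormal(mu, sigma)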