# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for LogNormal distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

# All cases below run in graph mode on an Ascend device.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class Prob(nn.Cell):
    """
    Test class: wraps LogNormal.prob as a graph-mode cell.
    """
    def __init__(self):
        super(Prob, self).__init__()
        # One loc (0.3) against two scales (0.2, 0.4) so results broadcast over rows.
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.dist.prob(x_)
36
def test_pdf():
    """
    Test pdf against the scipy lognorm benchmark.
    """
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected = benchmark.pdf([1.0, 2.0]).astype(np.float32)
    result = Prob()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert (np.abs(result.asnumpy() - expected) < 1e-6).all()
47
class LogProb(nn.Cell):
    """
    Test class: wraps LogNormal.log_prob as a graph-mode cell.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        # Same parameterization as Prob: loc 0.3, scales 0.2 and 0.4.
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.dist.log_prob(x_)
58
def test_log_likelihood():
    """
    Test log_pdf against the scipy lognorm benchmark.
    """
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected = benchmark.logpdf([1.0, 2.0]).astype(np.float32)
    result = LogProb()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert (np.abs(result.asnumpy() - expected) < 1e-6).all()
69
class KL(nn.Cell):
    """
    Test class: kl_loss of LogNormal against another LogNormal.
    """
    def __init__(self):
        super(KL, self).__init__()
        # Fixed "a" distribution; the "b" parameters arrive through construct.
        self.dist = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.dist.kl_loss('LogNormal', x_, y_)
80
def test_kl_loss():
    """
    Test kl_loss against the closed-form KL divergence.
    """
    loc_a = np.array([0.3]).astype(np.float32)
    scale_a = np.array([0.4]).astype(np.float32)
    loc_b = np.array([1.0]).astype(np.float32)
    scale_b = np.array([1.0]).astype(np.float32)

    # KL between the two underlying normals in closed form.
    log_scale_ratio = np.log(scale_a) - np.log(scale_b)
    normalized_sq_diff = np.square(loc_a / scale_b - loc_b / scale_b)
    expected = 0.5 * normalized_sq_diff + 0.5 * np.expm1(2 * log_scale_ratio) - log_scale_ratio

    result = KL()(Tensor(loc_b, dtype=dtype.float32), Tensor(scale_b, dtype=dtype.float32))
    assert (np.abs(result.asnumpy() - expected) < 1e-6).all()
101
class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of LogNormal distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self):
        # Return all three statistics in one graph invocation.
        return self.dist.mean(), self.dist.sd(), self.dist.mode()
112
def test_basics():
    """
    Test mean/standard deviation/mode against scipy.
    """
    mean, sd, mode = Basics()()
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected_mean = benchmark.mean().astype(np.float32)
    expected_sd = benchmark.std().astype(np.float32)
    # Mode of LogNormal(mu, s) is exp(mu - s^2), i.e. median / exp(s^2).
    expected_mode = (benchmark.median() / np.exp(np.square([[0.2], [0.4]]))).astype(np.float32)
    tolerance = 1e-6
    assert (np.abs(mean.asnumpy() - expected_mean) < tolerance).all()
    assert (np.abs(mode.asnumpy() - expected_mode) < tolerance).all()
    assert (np.abs(sd.asnumpy() - expected_sd) < tolerance).all()
127
class Sampling(nn.Cell):
    """
    Test class: sample of LogNormal distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), seed=seed, dtype=dtype.float32)
        # Sample shape prepended to the broadcast shape of the parameters.
        self.shape = shape

    def construct(self, mean=None, sd=None):
        return self.dist.sample(self.shape, mean, sd)
139
def test_sample():
    """
    Test sample shape.
    """
    sampler = Sampling((2, 3), seed=10)
    loc = Tensor([2.0], dtype=dtype.float32)
    scale = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    drawn = sampler(loc, scale)
    # Sample shape (2, 3) prepended to the (3,) broadcast of loc and scale.
    assert drawn.shape == (2, 3, 3)
151
class CDF(nn.Cell):
    """
    Test class: cdf of LogNormal distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.dist.cdf(x_)
162
def test_cdf():
    """
    Test cdf against the scipy lognorm benchmark.
    """
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected = benchmark.cdf([1.0, 2.0]).astype(np.float32)
    result = CDF()(Tensor([1.0, 2.0], dtype=dtype.float32))
    # Looser tolerance than pdf: device erf implementation is less precise.
    assert (np.abs(result.asnumpy() - expected) < 2e-5).all()
173
class LogCDF(nn.Cell):
    """
    Test class: log_cdf of LogNormal distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.ln = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.log_cdf(x_)
184
def test_log_cdf():
    """
    Test log cdf against the scipy lognorm benchmark.
    """
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected = benchmark.logcdf([1.0, 2.0]).astype(np.float32)
    result = LogCDF()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert (np.abs(result.asnumpy() - expected) < 1e-4).all()
195
class SF(nn.Cell):
    """
    Test class: survival function of LogNormal distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.dist.survival_function(x_)
206
def test_survival():
    """
    Test survival function against the scipy lognorm benchmark.
    """
    lognorm_benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expect_survival = lognorm_benchmark.sf([1.0, 2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
217
class LogSF(nn.Cell):
    """
    Test class: log survival function of LogNormal distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.dist.log_survival(x_)
228
def test_log_survival():
    """
    Test log_survival against the scipy lognorm benchmark.
    """
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected = benchmark.logsf([1.0, 2.0]).astype(np.float32)
    result = LogSF()(Tensor([1.0, 2.0], dtype=dtype.float32))
    # Widest tolerance in the file: log of a small survival value amplifies error.
    assert (np.abs(result.asnumpy() - expected) < 5e-4).all()
239
class EntropyH(nn.Cell):
    """
    Test class: entropy of LogNormal distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([[0.2], [0.4]]), dtype=dtype.float32)

    def construct(self):
        return self.dist.entropy()
250
def test_entropy():
    """
    Test entropy against the scipy lognorm benchmark.
    """
    benchmark = stats.lognorm(s=np.array([[0.2], [0.4]]), scale=np.exp(np.array([0.3])))
    expected = benchmark.entropy().astype(np.float32)
    result = EntropyH()()
    assert (np.abs(result.asnumpy() - expected) < 1e-6).all()
261
class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between LogNormal distributions.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.dist = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=dtype.float32)

    def construct(self, x_, y_):
        # cross_entropy(p, q) should equal entropy(p) + KL(p || q);
        # return the residual so the caller can assert it is ~0.
        h_plus_kl = self.dist.entropy() + self.dist.kl_loss('LogNormal', x_, y_)
        return h_plus_kl - self.dist.cross_entropy('LogNormal', x_, y_)
276
def test_cross_entropy():
    """
    Test cross_entropy: entropy + kl_loss - cross_entropy must vanish.
    """
    loc = Tensor([1.0], dtype=dtype.float32)
    scale = Tensor([1.0], dtype=dtype.float32)
    residual = CrossEntropy()(loc, scale)
    assert (np.abs(residual.asnumpy() - np.zeros(residual.shape)) < 1e-6).all()
287
class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.dist = msd.LogNormal(0., 1., dtype=dtype.float32)

    def construct(self, x_, y_):
        # Two distinct graph expansions of the same instance:
        # a kl_loss against (x_, y_) fed into prob.
        divergence = self.dist.kl_loss('LogNormal', x_, y_)
        return self.dist.prob(divergence)
302
def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    net = Net()
    loc_a = np.array([0.0]).astype(np.float32)
    scale_a = np.array([1.0]).astype(np.float32)
    loc_b = np.array([1.0]).astype(np.float32)
    scale_b = np.array([1.0]).astype(np.float32)
    answer = net(Tensor(loc_b), Tensor(scale_b))

    # Closed-form KL(N(0,1) || N(1,1)), then the benchmark pdf evaluated at it.
    log_scale_ratio = np.log(scale_a) - np.log(scale_b)
    sq_term = np.square(loc_a / scale_b - loc_b / scale_b)
    expected_kl = 0.5 * sq_term + 0.5 * np.expm1(2 * log_scale_ratio) - log_scale_ratio
    benchmark = stats.lognorm(s=np.array([1.]), scale=np.exp(np.array([0.])))
    expected_prob = benchmark.pdf(expected_kl).astype(np.float32)

    assert (np.abs(answer.asnumpy() - expected_prob) < 1e-6).all()