• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test cases for Normal distribution"""
16import numpy as np
17from scipy import stats
18import mindspore.context as context
19import mindspore.nn as nn
20import mindspore.nn.probability.distribution as msd
21from mindspore import Tensor
22from mindspore import dtype
23
24context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
25
class Prob(nn.Cell):
    """
    Test class: probability of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the density of the wrapped Normal at x_.
        return self.dist.prob(x_)
36
def test_pdf():
    """
    Test pdf against the scipy benchmark.
    """
    benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expected = benchmark.pdf([1.0, 2.0]).astype(np.float32)
    output = Prob()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - expected) < 1e-6)
47
class LogProb(nn.Cell):
    """
    Test class: log probability of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the log-density of the wrapped Normal at x_.
        return self.dist.log_prob(x_)
58
def test_log_likelihood():
    """
    Test log_pdf against the scipy benchmark.
    """
    benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expected = benchmark.logpdf([1.0, 2.0]).astype(np.float32)
    output = LogProb()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - expected) < 1e-6)
69
70
class KL(nn.Cell):
    """
    Test class: kl_loss of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        # KL divergence from self.dist to a Normal parameterized by (x_, y_).
        return self.dist.kl_loss('Normal', x_, y_)
81
82
def test_kl_loss():
    """
    Test kl_loss against a hand-computed closed form.
    """
    mean_p = np.array([3.0]).astype(np.float32)
    sd_p = np.array([4.0]).astype(np.float32)
    mean_q = np.array([1.0]).astype(np.float32)
    sd_q = np.array([1.0]).astype(np.float32)

    # Closed-form KL(N(mean_p, sd_p) || N(mean_q, sd_q)).
    log_scale_gap = np.log(sd_p) - np.log(sd_q)
    standardized_gap = np.square(mean_p / sd_q - mean_q / sd_q)
    expected = 0.5 * standardized_gap + 0.5 * np.expm1(2 * log_scale_gap) - log_scale_gap

    output = KL()(Tensor(mean_q, dtype=dtype.float32), Tensor(sd_q, dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - expected) < 1e-6)
103
class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)

    def construct(self):
        # Return the three summary statistics in a fixed order.
        return self.dist.mean(), self.dist.sd(), self.dist.mode()
114
def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    mean, sd, mode = Basics()()
    expected_mean = [3.0, 3.0]
    expected_sd = [2.0, 4.0]
    tol = 1e-6
    # For a Normal distribution the mode coincides with the mean.
    assert np.all(np.abs(mean.asnumpy() - expected_mean) < tol)
    assert np.all(np.abs(mode.asnumpy() - expected_mean) < tol)
    assert np.all(np.abs(sd.asnumpy() - expected_sd) < tol)
127
class Sampling(nn.Cell):
    """
    Test class: sample of Normal distribution.
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), seed=seed, dtype=dtype.float32)
        self.sample_shape = shape

    def construct(self, mean=None, sd=None):
        # Optional (mean, sd) override the distribution's own parameters.
        return self.dist.sample(self.sample_shape, mean, sd)
139
def test_sample():
    """
    Test sample output shape.
    """
    sampler = Sampling((2, 3), seed=10)
    loc = Tensor([2.0], dtype=dtype.float32)
    scale = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    drawn = sampler(loc, scale)
    # Sample shape (2, 3) broadcast against the (3,) parameter batch.
    assert drawn.shape == (2, 3, 3)
151
class CDF(nn.Cell):
    """
    Test class: cdf of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self, x_):
        # Cumulative distribution function at x_.
        return self.dist.cdf(x_)
162
163
def test_cdf():
    """
    Test cdf against the scipy benchmark.
    """
    benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expected = benchmark.cdf([1.0, 2.0]).astype(np.float32)
    output = CDF()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - expected) < 2e-5)
174
class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Normal distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.n = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self, x_):
        # Log of the cumulative distribution function at x_.
        return self.n.log_cdf(x_)
185
def test_log_cdf():
    """
    Test log cdf against the scipy benchmark.
    """
    benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expected = benchmark.logcdf([1.0, 2.0]).astype(np.float32)
    output = LogCDF()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - expected) < 5e-5)
196
class SF(nn.Cell):
    """
    Test class: survival function of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self, x_):
        # Survival function (1 - cdf) at x_.
        return self.dist.survival_function(x_)
207
def test_survival():
    """
    Test survival function.
    """
    norm_benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expect_survival = norm_benchmark.sf([1.0, 2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
218
class LogSF(nn.Cell):
    """
    Test class: log survival function of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self, x_):
        # Log of the survival function at x_.
        return self.dist.log_survival(x_)
229
def test_log_survival():
    """
    Test log_survival against the scipy benchmark.
    """
    benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expected = benchmark.logsf([1.0, 2.0]).astype(np.float32)
    output = LogSF()(Tensor([1.0, 2.0], dtype=dtype.float32))
    assert np.all(np.abs(output.asnumpy() - expected) < 2e-5)
240
class EntropyH(nn.Cell):
    """
    Test class: entropy of Normal distribution.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]), dtype=dtype.float32)

    def construct(self):
        # Differential entropy of the wrapped Normal.
        return self.dist.entropy()
251
def test_entropy():
    """
    Test entropy against the scipy benchmark.
    """
    benchmark = stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))
    expected = benchmark.entropy().astype(np.float32)
    output = EntropyH()()
    assert np.all(np.abs(output.asnumpy() - expected) < 1e-6)
262
class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Normal distributions.
    """
    def __init__(self):
        super().__init__()
        self.dist = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        # Cross entropy should equal entropy + KL, so the difference is ~0.
        entropy_term = self.dist.entropy()
        kl_term = self.dist.kl_loss('Normal', x_, y_)
        combined = entropy_term + kl_term
        direct = self.dist.cross_entropy('Normal', x_, y_)
        return combined - direct
277
def test_cross_entropy():
    """
    Test cross_entropy via the identity H(p, q) = H(p) + KL(p || q).
    """
    net = CrossEntropy()
    loc = Tensor([1.0], dtype=dtype.float32)
    scale = Tensor([1.0], dtype=dtype.float32)
    diff = net(loc, scale)
    assert np.all(np.abs(diff.asnumpy() - np.zeros(diff.shape)) < 1e-6)
288
class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.
    """

    def __init__(self):
        super().__init__()
        self.normal = msd.Normal(0., 1., dtype=dtype.float32)

    def construct(self, x_, y_):
        # Feed the kl_loss output of one graph straight into prob of another
        # graph built from the same distribution instance.
        divergence = self.normal.kl_loss('Normal', x_, y_)
        return self.normal.prob(divergence)
303
def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    mean_p = np.array([0.0]).astype(np.float32)
    sd_p = np.array([1.0]).astype(np.float32)
    mean_q = np.array([1.0]).astype(np.float32)
    sd_q = np.array([1.0]).astype(np.float32)

    answer = Net()(Tensor(mean_q), Tensor(sd_q))

    # Closed-form KL(N(mean_p, sd_p) || N(mean_q, sd_q)).
    log_scale_gap = np.log(sd_p) - np.log(sd_q)
    standardized_gap = np.square(mean_p / sd_q - mean_q / sd_q)
    expected_kl = (0.5 * standardized_gap
                   + 0.5 * np.expm1(2 * log_scale_gap)
                   - log_scale_gap)

    expected_prob = stats.norm(np.array([0.0]), np.array([1.0])).pdf(expected_kl).astype(np.float32)
    assert np.all(np.abs(answer.asnumpy() - expected_prob) < 1e-6)
325