# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Bernoulli distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Prob(nn.Cell):
    """
    Test class: probability of Bernoulli distribution.
    """

    def __init__(self):
        super(Prob, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.prob(x_)


def test_pmf():
    """
    Test pmf.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32)
    pmf = Prob()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32),
                dtype=dtype.float32)
    output = pmf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
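

# Illustrative sketch: for k in {0, 1}, the Bernoulli pmf has the closed form
# P(X = k) = p**k * (1 - p)**(1 - k), which is what the scipy benchmark above
# evaluates. The helper below is purely for cross-checking and is not used by
# the tests.
def _np_bernoulli_pmf(k, p=0.7):
    """Closed-form Bernoulli pmf, for illustration only."""
    k = np.asarray(k, dtype=np.float32)
    return p ** k * (1.0 - p) ** (1.0 - k)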


class LogProb(nn.Cell):
    """
    Test class: log probability of Bernoulli distribution.
    """

    def __init__(self):
        super(LogProb, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pmf.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_logpmf = bernoulli_benchmark.logpmf(
        [0, 1, 0, 1, 1]).astype(np.float32)
    logprob = LogProb()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32),
                dtype=dtype.float32)
    output = logprob(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss between Bernoulli distributions.
    """

    def __init__(self):
        super(KL, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.kl_loss('Bernoulli', x_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    probs1_a = 0.7
    probs1_b = 0.5
    probs0_a = 1 - probs1_a
    probs0_b = 1 - probs1_b
    expect_kl_loss = (probs1_a * np.log(probs1_a / probs1_b)
                      + probs0_a * np.log(probs0_a / probs0_b))
    kl_loss = KL()
    output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
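
# Reference: the expected value above is the closed-form KL divergence
#     KL(Bernoulli(p_a) || Bernoulli(p_b))
#         = p_a * log(p_a / p_b) + (1 - p_a) * log((1 - p_a) / (1 - p_b)),
# evaluated here at p_a = 0.7 and p_b = 0.5.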


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Bernoulli distribution.
    """

    def __init__(self):
        super(Basics, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)

    def construct(self):
        return self.b.mean(), self.b.sd(), self.b.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    expect_mean = [0.3, 0.5, 0.7]
    expect_sd = np.sqrt(np.multiply([0.7, 0.5, 0.3], [0.3, 0.5, 0.7]))
    expect_mode = [0.0, 0.0, 1.0]
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
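
# Reference: for Bernoulli(p), mean = p, sd = sqrt(p * (1 - p)), and the mode
# is 1 when p > 0.5 and 0 otherwise (at p = 0.5 both outcomes are modes; the
# expected values above use 0), giving [0.0, 0.0, 1.0] for probs
# [0.3, 0.5, 0.7].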


class Sampling(nn.Cell):
    """
    Test class: sample of Bernoulli distribution.
    """

    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
        self.shape = shape

    def construct(self, probs=None):
        return self.b.sample(self.shape, probs)


def test_sample():
    """
    Test sample.
    """
    shape = (2, 3)
    sample = Sampling(shape)
    output = sample()
    assert output.shape == (2, 3, 2)
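
# Reference: the distribution is batched over two probabilities [0.7, 0.5], so
# drawing with sample shape (2, 3) yields an output of shape
# sample_shape + batch_shape = (2, 3) + (2,) = (2, 3, 2).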


class CDF(nn.Cell):
    """
    Test class: cdf of Bernoulli distribution.
    """

    def __init__(self):
        super(CDF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32)
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32),
                dtype=dtype.float32)
    cdf = CDF()
    output = cdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
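
# Reference: the Bernoulli(p) cdf is the step function F(x) = 0 for x < 0,
# F(x) = 1 - p for 0 <= x < 1, and F(x) = 1 for x >= 1, so with p = 0.7 the
# expected values above are 0.3 at x = 0 and 1.0 at x = 1.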


class LogCDF(nn.Cell):
    """
    Test class: log cdf of Bernoulli distribution.
    """

    def __init__(self):
        super(LogCDF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.log_cdf(x_)


def test_logcdf():
    """
    Test log_cdf.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_logcdf = bernoulli_benchmark.logcdf(
        [0, 0, 1, 0, 1]).astype(np.float32)
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32),
                dtype=dtype.float32)
    logcdf = LogCDF()
    output = logcdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
    """
    Test class: survival function of Bernoulli distribution.
    """

    def __init__(self):
        super(SF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.survival_function(x_)


def test_survival():
    """
    Test survival function.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_survival = bernoulli_benchmark.sf(
        [0, 1, 1, 0, 0]).astype(np.float32)
    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32),
                dtype=dtype.float32)
    sf = SF()
    output = sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
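
# Reference: the survival function is S(x) = 1 - F(x), i.e. P(X > x); with
# p = 0.7 this gives 0.7 at x = 0 and 0.0 at x = 1.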


class LogSF(nn.Cell):
    """
    Test class: log survival function of Bernoulli distribution.
    """

    def __init__(self):
        super(LogSF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.log_survival(x_)


def test_log_survival():
    """
    Test log survival function.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_logsurvival = bernoulli_benchmark.logsf(
        [-1, 0.9, 0, 0, 0]).astype(np.float32)
    x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]).astype(np.float32),
                dtype=dtype.float32)
    log_sf = LogSF()
    output = log_sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()


class EntropyH(nn.Cell):
    """
    Test class: entropy of Bernoulli distribution.
    """

    def __init__(self):
        super(EntropyH, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self):
        return self.b.entropy()


def test_entropy():
    """
    Test entropy.
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_entropy = bernoulli_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
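

# Illustrative sketch: the entropy of Bernoulli(p) is
# H(p) = -p * log(p) - (1 - p) * log(1 - p) (in nats), which is what the scipy
# benchmark above returns. The helper below is a cross-check only and is not
# used by the tests.
def _np_bernoulli_entropy(p=0.7):
    """Closed-form Bernoulli entropy in nats, for illustration only."""
    return -p * np.log(p) - (1.0 - p) * np.log(1.0 - p)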


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Bernoulli distributions.
    """

    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        entropy = self.b.entropy()
        kl_loss = self.b.kl_loss('Bernoulli', x_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.b.cross_entropy('Bernoulli', x_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    prob = Tensor([0.3], dtype=dtype.float32)
    diff = cross_entropy(prob)
    tol = 1e-6
    assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
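
# Reference: cross_entropy(p, q) = H(p) + KL(p || q) for distributions p and q,
# so the difference returned by CrossEntropy.construct above is expected to be
# numerically zero.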