• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Tests for nn.probability.distribution.Bernoulli.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor
24
25
def test_arguments():
    """
    Constructor arguments: a Bernoulli built with defaults and one built with
    explicit probs and an int32 dtype are both Distribution instances.
    """
    default_dist = msd.Bernoulli()
    assert isinstance(default_dist, msd.Distribution)
    explicit_dist = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
    assert isinstance(explicit_dist, msd.Distribution)
34
35
def test_type():
    """
    Constructing a Bernoulli with a boolean dtype must raise TypeError.
    """
    pytest.raises(TypeError, msd.Bernoulli, [0.1], dtype=dtype.bool_)
39
40
def test_name():
    """
    Constructing a Bernoulli with a non-string name (a float) must raise TypeError.
    """
    pytest.raises(TypeError, msd.Bernoulli, [0.1], name=1.0)
44
45
def test_seed():
    """
    Constructing a Bernoulli with a string seed must raise TypeError.
    """
    pytest.raises(TypeError, msd.Bernoulli, [0.1], seed='seed')
49
50
def test_prob():
    """
    Invalid probability: values below 0, above 1, and exactly 0 or 1 must
    each raise ValueError at construction time.
    """
    for invalid_probs in ([-0.1], [1.1], [0.0], [1.0]):
        # Each value gets its own raises-context so every case is checked.
        with pytest.raises(ValueError):
            msd.Bernoulli(invalid_probs, dtype=dtype.int32)
63
64
class BernoulliProb(nn.Cell):
    """
    Bernoulli distribution with probs (0.5) fixed at construction.

    construct(value) evaluates prob, log_prob, cdf, log_cdf,
    survival_function and log_survival at `value` and returns their sum.
    """

    def __init__(self):
        super().__init__()
        self.b = msd.Bernoulli(0.5, dtype=dtype.int32)

    def construct(self, value):
        # Summed in the same left-to-right order as the calls are listed.
        return (self.b.prob(value)
                + self.b.log_prob(value)
                + self.b.cdf(value)
                + self.b.log_cdf(value)
                + self.b.survival_function(value)
                + self.b.log_survival(value))
82
83
def test_bernoulli_prob():
    """
    Probability functions with preset probs: feed a value through construct
    and check that the combined result is a Tensor.
    """
    cell = BernoulliProb()
    zeros = Tensor([0] * 5, dtype=dtype.float32)
    output = cell(zeros)
    assert isinstance(output, Tensor)
92
93
class BernoulliProb1(nn.Cell):
    """
    Bernoulli distribution created without probs.

    Because no probs are set at construction, construct(value, probs) passes
    `probs` to every probability function and returns the sum of prob,
    log_prob, cdf, log_cdf, survival_function and log_survival.
    """

    def __init__(self):
        super().__init__()
        self.b = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, value, probs):
        # Accumulate in the same order as the original chained sum.
        total = self.b.prob(value, probs) + self.b.log_prob(value, probs)
        total = total + self.b.cdf(value, probs) + self.b.log_cdf(value, probs)
        total = total + self.b.survival_function(value, probs) + self.b.log_survival(value, probs)
        return total
111
112
def test_bernoulli_prob1():
    """
    Probability functions with probs supplied at call time: pass both value
    and probs through construct and check that the result is a Tensor.
    """
    cell = BernoulliProb1()
    zeros = Tensor([0] * 5, dtype=dtype.float32)
    half = Tensor([0.5], dtype=dtype.float32)
    output = cell(zeros, half)
    assert isinstance(output, Tensor)
122
123
class BernoulliKl(nn.Cell):
    """
    Test class: kl_loss between Bernoulli distributions, once from a
    distribution with preset probs (0.7) and once from one without probs.
    """

    def __init__(self):
        super().__init__()
        self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
        self.b2 = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, probs_b, probs_a):
        # b2 was built without probs, so its call also passes probs_a.
        return (self.b1.kl_loss('Bernoulli', probs_b)
                + self.b2.kl_loss('Bernoulli', probs_b, probs_a))
138
139
def test_kl():
    """
    Test kl_loss: the summed KL divergences come back as a Tensor.
    """
    kl_cell = BernoulliKl()
    other_probs = Tensor([0.3], dtype=dtype.float32)
    own_probs = Tensor([0.7], dtype=dtype.float32)
    result = kl_cell(other_probs, own_probs)
    assert isinstance(result, Tensor)
149
150
class BernoulliCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Bernoulli distributions, once from a
    distribution with preset probs (0.7) and once from one without probs.
    """

    def __init__(self):
        super().__init__()
        self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
        self.b2 = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, probs_b, probs_a):
        # b2 was built without probs, so its call also passes probs_a.
        return (self.b1.cross_entropy('Bernoulli', probs_b)
                + self.b2.cross_entropy('Bernoulli', probs_b, probs_a))
165
166
def test_cross_entropy():
    """
    Test cross_entropy between Bernoulli distributions: the summed result
    comes back as a Tensor.
    """
    ce_cell = BernoulliCrossEntropy()
    other_probs = Tensor([0.3], dtype=dtype.float32)
    own_probs = Tensor([0.7], dtype=dtype.float32)
    result = ce_cell(other_probs, own_probs)
    assert isinstance(result, Tensor)
176
177
class BernoulliConstruct(nn.Cell):
    """
    Bernoulli distribution invoked through the cell call interface, i.e.
    `dist('prob', ...)`, rather than by calling `dist.prob(...)` directly.
    """

    def __init__(self):
        super().__init__()
        self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
        self.b1 = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, value, probs):
        # self.b is called with and without explicit probs; self.b1 has no
        # preset probs, so it is always given `probs`.
        return (self.b('prob', value)
                + self.b('prob', value, probs)
                + self.b1('prob', value, probs))
193
194
def test_bernoulli_construct():
    """
    Probability through the cell call interface: the summed result comes
    back as a Tensor.
    """
    cell = BernoulliConstruct()
    zeros = Tensor([0] * 5, dtype=dtype.float32)
    half = Tensor([0.5], dtype=dtype.float32)
    output = cell(zeros, half)
    assert isinstance(output, Tensor)
204
205
class BernoulliMean(nn.Cell):
    """
    Test class: mean of a Bernoulli distribution with probs [0.3, 0.5].
    """

    def __init__(self):
        super(BernoulliMean, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.mean()
218
219
def test_mean():
    """
    Test that the mean of a Bernoulli distribution evaluates to a Tensor.
    """
    net = BernoulliMean()
    ans = net()
    assert isinstance(ans, Tensor)
227
228
class BernoulliSd(nn.Cell):
    """
    Test class: standard deviation of a Bernoulli distribution with probs [0.3, 0.5].
    """

    def __init__(self):
        super(BernoulliSd, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.sd()
241
242
def test_sd():
    """
    Test that the standard deviation of a Bernoulli distribution evaluates to a Tensor.
    """
    net = BernoulliSd()
    ans = net()
    assert isinstance(ans, Tensor)
250
251
class BernoulliVar(nn.Cell):
    """
    Test class: variance of a Bernoulli distribution with probs [0.3, 0.5].
    """

    def __init__(self):
        super(BernoulliVar, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.var()
264
265
def test_var():
    """
    Test that the variance of a Bernoulli distribution evaluates to a Tensor.
    """
    net = BernoulliVar()
    ans = net()
    assert isinstance(ans, Tensor)
273
274
class BernoulliMode(nn.Cell):
    """
    Test class: mode of a Bernoulli distribution with probs [0.3, 0.5].
    """

    def __init__(self):
        super(BernoulliMode, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.mode()
287
288
def test_mode():
    """
    Test that the mode of a Bernoulli distribution evaluates to a Tensor.
    """
    net = BernoulliMode()
    ans = net()
    assert isinstance(ans, Tensor)
296
297
class BernoulliEntropy(nn.Cell):
    """
    Test class: entropy of a Bernoulli distribution with probs [0.3, 0.5].
    """

    def __init__(self):
        super(BernoulliEntropy, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.entropy()
310
311
def test_entropy():
    """
    Test that the entropy of a Bernoulli distribution evaluates to a Tensor.
    """
    net = BernoulliEntropy()
    ans = net()
    assert isinstance(ans, Tensor)
319