# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15"""test cases for cat distribution"""
import numpy as np
import pytest
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Prob(nn.Cell):
    """
    Test class: probability of categorical distribution.
    """

    def __init__(self):
        super(Prob, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.prob(x_)


def test_pmf():
    """
    Test pmf.
    """
    expect_pmf = [0.7, 0.3, 0.7, 0.3, 0.3]
    pmf = Prob()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
    output = pmf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()
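

# For reference, prob(x) of a categorical distribution is simply a lookup into
# the probability vector, so the expected values above are probs[x]. A minimal
# NumPy sketch of that cross-check (illustrative only, not part of the
# original test suite):
def _reference_pmf_sketch():
    probs = np.array([0.7, 0.3])
    x = np.array([0, 1, 0, 1, 1])
    return probs[x]  # -> [0.7, 0.3, 0.7, 0.3, 0.3], matching expect_pmf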


class LogProb(nn.Cell):
    """
    Test class: log probability of categorical distribution.
    """

    def __init__(self):
        super(LogProb, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pmf.
    """
    expect_logpmf = np.log([0.7, 0.3, 0.7, 0.3, 0.3])
    logprob = LogProb()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(np.int32), dtype=dtype.float32)
    output = logprob(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss between categorical distributions.
    """

    def __init__(self):
        super(KL, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.kl_loss('Categorical', x_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    kl_loss = KL()
    output = kl_loss(Tensor([0.7, 0.3], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy()) < tol).all()
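

# Note: kl_loss above is evaluated against a second categorical distribution
# with the same probabilities, so the expected result is ~0, because
# KL(p || q) = sum_i p_i * log(p_i / q_i) vanishes when q == p. A minimal
# NumPy sketch of that identity (illustrative only, not part of the original
# test suite):
def _reference_kl_sketch():
    p = np.array([0.7, 0.3])
    q = np.array([0.7, 0.3])
    return np.sum(p * np.log(p / q))  # -> 0.0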


class Sampling(nn.Cell):
    """
    Test class: sampling of categorical distribution.
    """

    def __init__(self):
        super(Sampling, self).__init__()
        self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
        self.shape = (2, 3)

    def construct(self):
        return self.c.sample(self.shape)


def test_sample():
    """
    Test sample.
    """
    # Sampling is expected to be unsupported in this configuration and should
    # raise NotImplementedError.
    with pytest.raises(NotImplementedError):
        sample = Sampling()
        sample()


class Basics(nn.Cell):
    """
    Test class: mean/var/mode of categorical distribution.
    """

    def __init__(self):
        super(Basics, self).__init__()
        self.c = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)

    def construct(self):
        return self.c.mean(), self.c.var(), self.c.mode()


def test_basics():
    """
    Test mean/variance/mode.
    """
    basics = Basics()
    mean, var, mode = basics()
    expect_mean = 0 * 0.2 + 1 * 0.1 + 2 * 0.7
    expect_var = 0 * 0.2 + 1 * 0.1 + 4 * 0.7 - (expect_mean * expect_mean)
    expect_mode = 2
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(var.asnumpy() - expect_var) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
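

# For reference, the expected values above follow from E[X] = sum_i i * p_i,
# Var(X) = E[X^2] - E[X]^2, and the mode being the index of the largest
# probability. A minimal NumPy sketch of the same arithmetic (illustrative
# only, not part of the original test suite):
def _reference_moments_sketch():
    probs = np.array([0.2, 0.1, 0.7])
    vals = np.arange(3)
    mean = np.sum(vals * probs)                   # 1.5
    var = np.sum(vals ** 2 * probs) - mean ** 2   # 0.65
    mode = np.argmax(probs)                       # 2
    return mean, var, mode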


class CDF(nn.Cell):
    """
    Test class: cdf of categorical distributions.
    """

    def __init__(self):
        super(CDF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    expect_cdf = [0.7, 0.7, 1, 0.7, 1]
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
    cdf = CDF()
    output = cdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()
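

# For reference, the CDF of a categorical distribution at index x is the
# cumulative sum of the probabilities up to and including x, so F(0) = 0.7 and
# F(1) = 1.0 here. A minimal NumPy sketch (illustrative only, not part of the
# original test suite):
def _reference_cdf_sketch():
    probs = np.array([0.7, 0.3])
    x = np.array([0, 0, 1, 0, 1])
    return np.cumsum(probs)[x]  # -> [0.7, 0.7, 1.0, 0.7, 1.0], matching expect_cdf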


class LogCDF(nn.Cell):
    """
    Test class: log cdf of categorical distributions.
    """

    def __init__(self):
        super(LogCDF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.log_cdf(x_)


def test_logcdf():
    """
    Test log_cdf.
    """
    expect_logcdf = np.log([0.7, 0.7, 1, 0.7, 1])
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(np.int32), dtype=dtype.float32)
    logcdf = LogCDF()
    output = logcdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
    """
    Test class: survival function of categorical distributions.
    """

    def __init__(self):
        super(SF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.survival_function(x_)


def test_survival():
    """
    Test survival function.
    """
    expect_survival = [0.3, 0., 0., 0.3, 0.3]
    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(np.int32), dtype=dtype.float32)
    sf = SF()
    output = sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
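

# For reference, the survival function is the complement of the CDF,
# S(x) = 1 - F(x), so S(0) = 0.3 and S(1) = 0.0 here. A minimal NumPy sketch
# (illustrative only, not part of the original test suite):
def _reference_survival_sketch():
    probs = np.array([0.7, 0.3])
    x = np.array([0, 1, 1, 0, 0])
    return 1.0 - np.cumsum(probs)[x]  # -> [0.3, 0.0, 0.0, 0.3, 0.3]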


class LogSF(nn.Cell):
    """
    Test class: log survival function of categorical distributions.
    """

    def __init__(self):
        super(LogSF, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.c.log_survival(x_)


def test_log_survival():
    """
    Test log survival function.
    """
    # The expected values assume that an input below the support (-2) has
    # survival 1, and that a non-integer input (0.5) behaves like the integer
    # below it, i.e. log_survival(0.5) == log_survival(0) == log(0.3).
    expect_logsurvival = np.log([1., 0.3, 0.3, 0.3, 0.3])
    x_ = Tensor(np.array([-2, 0, 0, 0.5, 0.5]).astype(np.float32), dtype=dtype.float32)
    log_sf = LogSF()
    output = log_sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()


class EntropyH(nn.Cell):
    """
    Test class: entropy of categorical distributions.
    """

    def __init__(self):
        super(EntropyH, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self):
        return self.c.entropy()


def test_entropy():
    """
    Test entropy.
    """
    cat_benchmark = stats.multinomial(n=1, p=[0.7, 0.3])
    expect_entropy = cat_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
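

# For reference, a multinomial with n=1 is a single categorical draw, and its
# entropy reduces to H = -sum_i p_i * log(p_i). A minimal NumPy sketch of the
# same benchmark value (illustrative only, not part of the original test
# suite):
def _reference_entropy_sketch():
    probs = np.array([0.7, 0.3])
    return -np.sum(probs * np.log(probs))  # ~= 0.6109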


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between categorical distributions.
    """

    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.c = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        entropy = self.c.entropy()
        kl_loss = self.c.kl_loss('Categorical', x_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.c.cross_entropy('Categorical', x_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    prob = Tensor([0.7, 0.3], dtype=dtype.float32)
    diff = cross_entropy(prob)
    tol = 1e-6
    assert (np.abs(diff.asnumpy()) < tol).all()
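

# For reference, the test above relies on the identity
# H(p, q) = H(p) + KL(p || q), so entropy + kl_loss - cross_entropy should be
# ~0. A minimal NumPy sketch of that identity (illustrative only, not part of
# the original test suite):
def _reference_cross_entropy_sketch():
    p = np.array([0.7, 0.3])
    q = np.array([0.7, 0.3])
    cross_entropy = -np.sum(p * np.log(q))
    entropy = -np.sum(p * np.log(p))
    kl = np.sum(p * np.log(p / q))
    return entropy + kl - cross_entropy  # -> 0.0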