• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test cases for the MindSpore Uniform distribution, mostly checked against scipy.stats."""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

# Every Cell in this module is compiled in graph mode for an Ascend device.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
25
class Prob(nn.Cell):
    """
    Test class: probability of Uniform distribution.

    Holds a Uniform built with low [0.0] and high [[1.0], [2.0]] and returns
    its density evaluated at the input values.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.prob(x_)
36
def test_pdf():
    """
    Test pdf.

    Compares Prob against scipy.stats.uniform(loc=[0.0], scale=[[1.0], [2.0]]).
    Because loc is 0, scipy's scale equals the Uniform high bound. The sample
    points lie outside, on the boundary of, and inside the supports.
    """
    x = np.array([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0], dtype=np.float32)
    uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]])
    expect_pdf = uniform_benchmark.pdf(x).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor(x, dtype=dtype.float32))
    tol = 1e-6
    # assert_allclose also fails on a shape mismatch, which a broadcasting
    # element-wise comparison would silently accept.
    np.testing.assert_allclose(output.asnumpy(), expect_pdf, rtol=0, atol=tol)
48
class LogProb(nn.Cell):
    """
    Test class: log probability of Uniform distribution.

    Holds a Uniform built with low [0.0] and high [[1.0], [2.0]] and returns
    its log density evaluated at the input values.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.log_prob(x_)
59
def test_log_likelihood():
    """
    Test log_pdf.

    Evaluates LogProb at 0.5, which lies inside every support, and compares the
    result with scipy.stats.uniform(loc=[0.0], scale=[[1.0], [2.0]]).logpdf.
    """
    x = np.array([0.5], dtype=np.float32)
    uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]])
    expect_logpdf = uniform_benchmark.logpdf(x).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor(x, dtype=dtype.float32))
    tol = 1e-6
    # assert_allclose checks shapes as well as values (no silent broadcasting).
    np.testing.assert_allclose(output.asnumpy(), expect_logpdf, rtol=0, atol=tol)
71
class KL(nn.Cell):
    """
    Test class: kl_loss between Uniform distributions.

    Holds a Uniform with low [0.0] and high [1.5]. construct(x_, y_) returns
    the KL loss against another 'Uniform' whose low and high are x_ and y_.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.u.kl_loss('Uniform', x_, y_)
82
def test_kl_loss():
    """
    Test kl_loss.

    When [low_a, high_a] lies inside [low_b, high_b] (here [0, 1.5] inside
    [-1, 2]), KL(U(low_a, high_a) || U(low_b, high_b)) reduces to
    log(high_b - low_b) - log(high_a - low_a).
    """
    # low_a/high_a must match the parameters KL() builds its distribution with.
    low_a = 0.0
    high_a = 1.5
    low_b = -1.0
    high_b = 2.0
    expect_kl_loss = np.log(high_b - low_b) - np.log(high_a - low_a)
    kl = KL()
    output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
    tol = 1e-6
    np.testing.assert_allclose(output.asnumpy(), expect_kl_loss, rtol=0, atol=tol)
96
class Basics(nn.Cell):
    """
    Test class: mean/sd of Uniform distribution.

    Holds a Uniform with low [0.0] and high [3.0] and returns the pair
    (mean, standard deviation).
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)

    def construct(self):
        return self.u.mean(), self.u.sd()
107
def test_basics():
    """
    Test mean/standard deviation.

    For U(0, 3): mean = (0 + 3) / 2 = 1.5 and variance = (3 - 0) ** 2 / 12 = 0.75,
    so the standard deviation is sqrt(0.75).
    """
    basics = Basics()
    mean, sd = basics()
    expect_mean = [1.5]
    expect_sd = np.sqrt([0.75])
    tol = 1e-6
    np.testing.assert_allclose(mean.asnumpy(), expect_mean, rtol=0, atol=tol)
    np.testing.assert_allclose(sd.asnumpy(), expect_sd, rtol=0, atol=tol)
119
class Sampling(nn.Cell):
    """
    Test class: sample of Uniform distribution.

    Args:
        shape (tuple): sample shape passed to Uniform.sample.
        seed (int): seed for the underlying Uniform. Default: 0.

    construct(low, high) forwards the optional low/high Tensors to sample().
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, low=None, high=None):
        return self.u.sample(self.shape, low, high)
131
def test_sample():
    """
    Test sample.

    Passes explicit low/high Tensors to sample(). low of shape (1,) and high of
    shape (3,) are expected to broadcast to a batch shape of (3,), so a (2, 3)
    sample should come back with shape (2, 3, 3).
    """
    shape = (2, 3)
    seed = 10
    low_np = np.array([1.0], dtype=np.float32)
    high_np = np.array([2.0, 3.0, 4.0], dtype=np.float32)
    low = Tensor(low_np, dtype=dtype.float32)
    high = Tensor(high_np, dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(low, high)
    assert output.shape == (2, 3, 3)
    # The last axis is the batch axis, so low/high broadcast against it: each
    # draw must lie inside its own [low, high] interval.
    values = output.asnumpy()
    assert (values >= low_np).all()
    assert (values <= high_np).all()
143
class CDF(nn.Cell):
    """
    Test class: cdf of Uniform distribution.

    Holds a Uniform with low [0.0] and high [1.0] and returns its cumulative
    distribution function evaluated at the input values.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.cdf(x_)
154
def test_cdf():
    """
    Test cdf.

    Compares CDF against scipy.stats.uniform(loc=[0.0], scale=[1.0]).cdf at points
    below (-1.0), inside (0.5), on the upper bound of (1.0), and above (2.0) the
    support.
    """
    x = np.array([-1.0, 0.5, 1.0, 2.0], dtype=np.float32)
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_cdf = uniform_benchmark.cdf(x).astype(np.float32)
    cdf = CDF()
    output = cdf(Tensor(x, dtype=dtype.float32))
    tol = 1e-6
    np.testing.assert_allclose(output.asnumpy(), expect_cdf, rtol=0, atol=tol)
166
class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Uniform distribution.

    Holds a Uniform with low [0.0] and high [1.0] and returns the log of its
    cumulative distribution function at the input values.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.log_cdf(x_)
177
class SF(nn.Cell):
    """
    Test class: survival function of Uniform distribution.

    Holds a Uniform with low [0.0] and high [1.0] and returns its survival
    function evaluated at the input values.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.survival_function(x_)
188
class LogSF(nn.Cell):
    """
    Test class: log survival function of Uniform distribution.

    Holds a Uniform with low [0.0] and high [1.0] and returns the log of its
    survival function at the input values.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.log_survival(x_)
199
class EntropyH(nn.Cell):
    """
    Test class: entropy of Uniform distribution.

    Holds a Uniform built with low [0.0] and high [1.0, 2.0] and returns its
    entropy. It takes no inputs.
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)

    def construct(self):
        return self.u.entropy()
210
def test_entropy():
    """
    Test entropy.

    Compares EntropyH against scipy.stats.uniform(loc=[0.0], scale=[1.0, 2.0]).entropy().
    """
    uniform_benchmark = stats.uniform([0.0], [1.0, 2.0])
    expect_entropy = uniform_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    np.testing.assert_allclose(output.asnumpy(), expect_entropy, rtol=0, atol=tol)
221
class CrossEntropy(nn.Cell):
    """
    Test class: cross_entropy between Uniform distributions.

    Holds a Uniform with low [0.0] and high [1.5]. construct(x_, y_) compares it
    against another 'Uniform' with low x_ and high y_. It returns
    (entropy + kl_loss) - cross_entropy, which should be zero because
    H(p, q) = H(p) + KL(p || q).
    """
    def __init__(self):
        super().__init__()
        self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)

    def construct(self, x_, y_):
        entropy = self.u.entropy()
        kl_loss = self.u.kl_loss('Uniform', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.u.cross_entropy('Uniform', x_, y_)
        return h_sum_kl - cross_entropy
236
def test_log_cdf():
    """
    Test log_cdf.

    Compares LogCDF against scipy.stats.uniform(loc=[0.0], scale=[1.0]).logcdf at
    points inside (0.5, 0.8) and above (2.0) the support. Points below the
    support are avoided because their log cdf is -inf.
    """
    x = np.array([0.5, 0.8, 2.0], dtype=np.float32)
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_logcdf = uniform_benchmark.logcdf(x).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(Tensor(x, dtype=dtype.float32))
    tol = 1e-6
    np.testing.assert_allclose(output.asnumpy(), expect_logcdf, rtol=0, atol=tol)
248
def test_survival():
    """
    Test survival function.

    Compares SF against scipy.stats.uniform(loc=[0.0], scale=[1.0]).sf at points
    below (-1.0), inside (0.5), on the upper bound of (1.0), and above (2.0) the
    support.
    """
    x = np.array([-1.0, 0.5, 1.0, 2.0], dtype=np.float32)
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_survival = uniform_benchmark.sf(x).astype(np.float32)
    survival = SF()
    output = survival(Tensor(x, dtype=dtype.float32))
    tol = 1e-6
    np.testing.assert_allclose(output.asnumpy(), expect_survival, rtol=0, atol=tol)
260
def test_log_survival():
    """
    Test log survival function.

    Compares LogSF against scipy.stats.uniform(loc=[0.0], scale=[1.0]).logsf at
    points inside (0.5, 0.8) and below (-2.0) the support. Points above the
    support are avoided because their log survival is -inf.
    """
    x = np.array([0.5, 0.8, -2.0], dtype=np.float32)
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_logsurvival = uniform_benchmark.logsf(x).astype(np.float32)
    logsurvival = LogSF()
    output = logsurvival(Tensor(x, dtype=dtype.float32))
    tol = 1e-6
    np.testing.assert_allclose(output.asnumpy(), expect_logsurvival, rtol=0, atol=tol)
272
def test_cross_entropy():
    """
    Test cross_entropy.

    CrossEntropy returns (entropy + kl_loss) - cross_entropy against
    U(-1.0, 2.0). The identity H(p, q) = H(p) + KL(p || q) means this should
    be zero.
    """
    cross_entropy = CrossEntropy()
    low_b = -1.0
    high_b = 2.0
    diff = cross_entropy(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
    tol = 1e-6
    diff_np = diff.asnumpy()
    np.testing.assert_allclose(diff_np, np.zeros_like(diff_np), rtol=0, atol=tol)
283