• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test cases for Gumbel distribution"""
16import numpy as np
17from scipy import stats
18from scipy import special
19import mindspore.context as context
20import mindspore.nn as nn
21import mindspore.nn.probability.distribution as msd
22from mindspore import Tensor
23from mindspore import dtype
24
# Every test below runs in graph mode on an Ascend device.
context.set_context(
    mode=context.GRAPH_MODE,
    device_target="Ascend",
)
26
class Prob(nn.Cell):
    """
    Test class: density of a Gumbel distribution.

    Wraps Gumbel.prob for a distribution with loc [0.0] and
    scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the density at the points in x_.
        return self.gum.prob(x_)
37
def test_pdf():
    """
    Compare Gumbel.prob with scipy's gumbel_r pdf.
    """
    points = np.array([1.0, 2.0]).astype(np.float32)
    reference = stats.gumbel_r(np.array([0.0]).astype(np.float32),
                               np.array([[1.0], [2.0]]).astype(np.float32))
    expected = reference.pdf(points).astype(np.float32)
    actual = Prob()(Tensor(points, dtype=dtype.float32)).asnumpy()
    assert np.all(np.abs(actual - expected) < 1e-6)
51
class LogProb(nn.Cell):
    """
    Test class: log density of a Gumbel distribution.

    Wraps Gumbel.log_prob for a distribution with loc [0.0] and
    scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the log density at the points in x_.
        return self.gum.log_prob(x_)
62
def test_log_likelihood():
    """
    Compare Gumbel.log_prob with scipy's gumbel_r logpdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    points = [1.0, 2.0]
    expected = stats.gumbel_r(loc, scale).logpdf(points).astype(np.float32)
    net = LogProb()
    actual = net(Tensor(points, dtype=dtype.float32)).asnumpy()
    tolerance = 1e-6
    assert np.all(np.abs(actual - expected) < tolerance)
75
class KL(nn.Cell):
    """
    Test class: KL divergence from this Gumbel distribution to another Gumbel.

    Distribution "a" has loc [0.0] and scale [1.0, 2.0] in float32; the
    loc/scale of distribution "b" are supplied when the cell is called.
    """
    def __init__(self):
        super().__init__()
        loc_a = np.array([0.0])
        scale_a = np.array([1.0, 2.0])
        self.gum = msd.Gumbel(loc_a, scale_a, dtype=dtype.float32)

    def construct(self, loc_b, scale_b):
        # The leading string selects the distribution type of "b".
        return self.gum.kl_loss('Gumbel', loc_b, scale_b)
86
def test_kl_loss():
    """
    Test kl_loss against the closed-form Gumbel-Gumbel KL divergence.

    For a = Gumbel(loc, scale) and b = Gumbel(loc_b, scale_b):

        KL(a || b) = log(scale_b) - log(scale)
                     + euler_gamma * (scale / scale_b - 1)
                     + (loc - loc_b) / scale_b
                     + expm1((loc_b - loc) / scale_b + lgamma(scale / scale_b + 1))

    The (loc - loc_b) / scale_b term comes from E_a[(X - loc_b) / scale_b].
    Without it the reference is not a valid divergence. For equal scales
    and loc_b < loc it would evaluate to expm1(negative) < 0.
    """
    # Parameters of "a" must mirror the ones hard-coded in KL.__init__.
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([1.0, 2.0]).astype(np.float32)

    loc_b = np.array([1.0]).astype(np.float32)
    scale_b = np.array([1.0, 2.0]).astype(np.float32)

    expect_kl_loss = np.log(scale_b) - np.log(scale) + \
        np.euler_gamma * (scale / scale_b - 1.) + \
        (loc - loc_b) / scale_b + \
        np.expm1((loc_b - loc) / scale_b + special.loggamma(scale / scale_b + 1.))

    kl_loss = KL()
    # Keep the numpy arrays intact; wrap them in separate Tensor names.
    loc_b_tensor = Tensor(loc_b, dtype=dtype.float32)
    scale_b_tensor = Tensor(scale_b, dtype=dtype.float32)
    output = kl_loss(loc_b_tensor, scale_b_tensor)
    tol = 1e-5
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()
107
class Basics(nn.Cell):
    """
    Test class: mean, standard deviation and mode of a Gumbel distribution.

    The distribution has loc [0.0] and scale [[1.0], [2.0]] in float32;
    calling the cell returns the tuple (mean, sd, mode).
    """
    def __init__(self):
        super().__init__()
        self.gum = msd.Gumbel(np.array([0.0]),
                              np.array([[1.0], [2.0]]),
                              dtype=dtype.float32)

    def construct(self):
        mean = self.gum.mean()
        sd = self.gum.sd()
        mode = self.gum.mode()
        return mean, sd, mode
118
def test_basics():
    """
    Check mean and standard deviation against scipy, and mode against loc.
    """
    mean, sd, mode = Basics()()

    reference = stats.gumbel_r(np.array([0.0]).astype(np.float32),
                               np.array([[1.0], [2.0]]).astype(np.float32))
    tolerance = 1e-6
    # Expected mode is loc (0.0) repeated over the (2, 1) scale shape.
    checks = (
        (mean, reference.mean().astype(np.float32)),
        (mode, np.array([[0.0], [0.0]]).astype(np.float32)),
        (sd, reference.std().astype(np.float32)),
    )
    for actual, expected in checks:
        assert np.all(np.abs(actual.asnumpy() - expected) < tolerance)
136
class Sampling(nn.Cell):
    """
    Test class: draws samples from a seeded Gumbel distribution.

    The distribution has loc [0.0] and scale [1.0, 2.0, 3.0] in float32.
    `shape` is the sample shape requested on every call.
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        scales = np.array([1.0, 2.0, 3.0])
        self.gum = msd.Gumbel(np.array([0.0]), scales, dtype=dtype.float32, seed=seed)
        self.shape = shape

    def construct(self):
        return self.gum.sample(self.shape)
148
def test_sample():
    """
    Check the shape of the drawn samples.
    """
    requested = (2, 3)
    net = Sampling(requested, seed=10)
    samples = net()
    # Requested sample shape followed by the (3,) shape of the scale vector.
    assert samples.shape == requested + (3,)
158
class CDF(nn.Cell):
    """
    Test class: cumulative distribution function of a Gumbel distribution.

    Wraps Gumbel.cdf for a distribution with loc [0.0] and
    scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the CDF at the points in x_.
        return self.gum.cdf(x_)
169
def test_cdf():
    """
    Compare Gumbel.cdf with scipy's gumbel_r cdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    points = [1.0, 2.0]
    expected = stats.gumbel_r(loc, scale).cdf(points).astype(np.float32)
    actual = CDF()(Tensor(points, dtype=dtype.float32)).asnumpy()
    assert np.all(np.abs(actual - expected) < 2e-5)
182
class LogCDF(nn.Cell):
    """
    Test class: log of the cumulative distribution function of a Gumbel.

    Wraps Gumbel.log_cdf for a distribution with loc [0.0] and
    scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the log CDF at the points in x_.
        return self.gum.log_cdf(x_)
193
def test_log_cdf():
    """
    Compare Gumbel.log_cdf with scipy's gumbel_r logcdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    points = [1.0, 2.0]
    expected = stats.gumbel_r(loc, scale).logcdf(points).astype(np.float32)
    actual = LogCDF()(Tensor(points, dtype=dtype.float32)).asnumpy()
    assert np.all(np.abs(actual - expected) < 1e-4)
206
class SF(nn.Cell):
    """
    Test class: survival function of a Gumbel distribution.

    Wraps Gumbel.survival_function for a distribution with loc [0.0] and
    scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the survival function at the points in x_.
        return self.gum.survival_function(x_)
217
def test_survival():
    """
    Test survival_function against scipy's gumbel_r.sf.

    loc (1,), scale (2, 1) and the two query points broadcast to a
    (2, 2) grid of reference values on the scipy side.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_survival = gumbel_benchmark.sf([1.0, 2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
230
class LogSF(nn.Cell):
    """
    Test class: log of the survival function of a Gumbel distribution.

    Wraps Gumbel.log_survival for a distribution with loc [0.0] and
    scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        # Evaluate the log survival function at the points in x_.
        return self.gum.log_survival(x_)
241
def test_log_survival():
    """
    Compare Gumbel.log_survival with scipy's gumbel_r logsf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    points = [1.0, 2.0]
    expected = stats.gumbel_r(loc, scale).logsf(points).astype(np.float32)
    actual = LogSF()(Tensor(points, dtype=dtype.float32)).asnumpy()
    assert np.all(np.abs(actual - expected) < 5e-4)
254
class EntropyH(nn.Cell):
    """
    Test class: entropy of a Gumbel distribution.

    The distribution has loc [0.0] and scale [[1.0], [2.0]] in float32.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self):
        return self.gum.entropy()
265
def test_entropy():
    """
    Compare Gumbel.entropy with scipy's gumbel_r entropy.
    """
    reference = stats.gumbel_r(np.array([0.0]).astype(np.float32),
                               np.array([[1.0], [2.0]]).astype(np.float32))
    expected = reference.entropy().astype(np.float32)
    actual = EntropyH()().asnumpy()
    assert np.all(np.abs(actual - expected) < 1e-6)
278
class CrossEntropy(nn.Cell):
    """
    Test class: residual of entropy + kl_loss - cross_entropy between Gumbels.

    Distribution "a" has loc [0.0] and scale [[1.0], [2.0]] in float32;
    the loc/scale of "b" are supplied when the cell is called.
    """
    def __init__(self):
        super().__init__()
        loc = np.array([0.0])
        scale = np.array([[1.0], [2.0]])
        self.gum = msd.Gumbel(loc, scale, dtype=dtype.float32)

    def construct(self, x_, y_):
        # The test expects entropy(a) + KL(a || b) to equal cross_entropy(a, b).
        h_plus_kl = self.gum.entropy() + self.gum.kl_loss('Gumbel', x_, y_)
        return h_plus_kl - self.gum.cross_entropy('Gumbel', x_, y_)
293
def test_cross_entropy():
    """
    Check that entropy + kl_loss - cross_entropy is (numerically) zero.
    """
    net = CrossEntropy()
    loc_b = Tensor([1.0], dtype=dtype.float32)
    scale_b = Tensor([1.0], dtype=dtype.float32)
    residual = net(loc_b, scale_b).asnumpy()
    assert np.all(np.abs(residual - np.zeros(residual.shape)) < 1e-6)
304