# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Gamma distribution"""
import numpy as np
from scipy import stats
from scipy import special
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Prob(nn.Cell):
    """
    Test class: probability of Gamma distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.prob(x_)

def test_pdf():
    """
    Test pdf.
    """
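    # scipy's stats.gamma defaults to scale=1.0 (i.e. rate=1.0), so passing only the
    # concentration matches msd.Gamma(3.0, 1.0) used in the Prob cell above.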
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_pdf = gamma_benchmark.pdf([1.0, 2.0]).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()

class LogProb(nn.Cell):
    """
    Test class: log probability of Gamma distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.log_prob(x_)

def test_log_likelihood():
    """
    Test log_pdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_logpdf = gamma_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss of Gamma distribution.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.g.kl_loss('Gamma', x_, y_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    concentration_a = np.array([3.0]).astype(np.float32)
    rate_a = np.array([4.0]).astype(np.float32)

    concentration_b = np.array([1.0]).astype(np.float32)
    rate_b = np.array([1.0]).astype(np.float32)

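    # Closed-form KL divergence between Gamma distributions in the
    # concentration/rate parameterization, as computed below:
    #   KL(Gamma(a, b) || Gamma(a', b')) = (a - a') * digamma(a)
    #       + gammaln(a') - gammaln(a) + a' * (log(b) - log(b')) + a * (b' / b - 1)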
    expect_kl_loss = (concentration_a - concentration_b) * special.digamma(concentration_a) \
                     + special.gammaln(concentration_b) - special.gammaln(concentration_a) \
                     + concentration_b * np.log(rate_a) - concentration_b * np.log(rate_b) \
                     + concentration_a * (rate_b / rate_a - 1.)

    kl_loss = KL()
    concentration = Tensor(concentration_b, dtype=dtype.float32)
    rate = Tensor(rate_b, dtype=dtype.float32)
    output = kl_loss(concentration, rate)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()

class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Gamma distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self):
        return self.g.mean(), self.g.sd(), self.g.mode()

def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_mean = gamma_benchmark.mean().astype(np.float32)
    expect_sd = gamma_benchmark.std().astype(np.float32)
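    # For concentration >= 1 the Gamma mode is (concentration - 1) / rate,
    # i.e. (3.0 - 1.0) / 1.0 = 2.0 here.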
    expect_mode = [2.0]
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()

class Sampling(nn.Cell):
    """
    Test class: sample of Gamma distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, concentration=None, rate=None):
        return self.g.sample(self.shape, concentration, rate)

def test_sample():
    """
    Test sample.
    """
    shape = (2, 3)
    seed = 10
    concentration = Tensor([2.0], dtype=dtype.float32)
    rate = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(concentration, rate)
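    # The sample shape (2, 3) is prepended to the broadcast shape of
    # concentration (1,) and rate (3,), so the output shape is (2, 3, 3).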
    assert output.shape == (2, 3, 3)

class CDF(nn.Cell):
    """
    Test class: cdf of Gamma distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_cdf = gamma_benchmark.cdf([2.0]).astype(np.float32)
    cdf = CDF()
    output = cdf(Tensor([2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()

class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Gamma distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.log_cdf(x_)

def test_log_cdf():
    """
    Test log_cdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_logcdf = gamma_benchmark.logcdf([2.0]).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(Tensor([2.0], dtype=dtype.float32))
    tol = 5e-5
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()

class SF(nn.Cell):
    """
    Test class: survival function of Gamma distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.survival_function(x_)

def test_survival():
    """
    Test survival function.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_survival = gamma_benchmark.sf([2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()

class LogSF(nn.Cell):
    """
    Test class: log survival function of Gamma distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.log_survival(x_)

def test_log_survival():
    """
    Test log_survival.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_log_survival = gamma_benchmark.logsf([2.0]).astype(np.float32)
    log_survival = LogSF()
    output = log_survival(Tensor([2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all()

class EntropyH(nn.Cell):
    """
    Test class: entropy of Gamma distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self):
        return self.g.entropy()

def test_entropy():
    """
    Test entropy.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_entropy = gamma_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()

class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Gamma distributions.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

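    # cross_entropy(p, q) should equal entropy(p) + kl_loss(p, q), so the
    # difference returned by construct is expected to be close to zero.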
    def construct(self, x_, y_):
        entropy = self.g.entropy()
        kl_loss = self.g.kl_loss('Gamma', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.g.cross_entropy('Gamma', x_, y_)
        return h_sum_kl - cross_entropy

def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    concentration = Tensor([3.0], dtype=dtype.float32)
    rate = Tensor([2.0], dtype=dtype.float32)
    diff = cross_entropy(concentration, rate)
    tol = 1e-6
    assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()

class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        kl = self.g.kl_loss('Gamma', x_, y_)
        prob = self.g.prob(kl)
        return prob

def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    prob = Net()
    concentration_a = np.array([3.0]).astype(np.float32)
    rate_a = np.array([1.0]).astype(np.float32)
    concentration_b = np.array([2.0]).astype(np.float32)
    rate_b = np.array([1.0]).astype(np.float32)
    ans = prob(Tensor(concentration_b), Tensor(rate_b))

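    # Same closed-form Gamma-vs-Gamma KL divergence as in test_kl_loss above;
    # the resulting KL value is then passed through the pdf of Gamma(3.0, 1.0).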
    expect_kl_loss = (concentration_a - concentration_b) * special.digamma(concentration_a) \
                     + special.gammaln(concentration_b) - special.gammaln(concentration_a) \
                     + concentration_b * np.log(rate_a) - concentration_b * np.log(rate_b) \
                     + concentration_a * (rate_b / rate_a - 1.)

    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_prob = gamma_benchmark.pdf(expect_kl_loss).astype(np.float32)

    tol = 1e-6
    assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()