• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test cases for Poisson distribution"""
16import numpy as np
17from scipy import stats
18import mindspore.context as context
19import mindspore.nn as nn
20import mindspore.nn.probability.distribution as msd
21from mindspore import Tensor
22from mindspore import dtype
23
# All tests in this module run in graph mode on an Ascend device.
context.set_context(
    device_target="Ascend",
    mode=context.GRAPH_MODE,
)
25
class Prob(nn.Cell):
    """
    Cell wrapping ``prob`` of a float32 Poisson distribution built with [0.5].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [0.5]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self, x_):
        # Forward the input values to the distribution's prob method.
        probability = self.p.prob(x_)
        return probability
36
def test_pdf():
    """
    Compare Prob() output with scipy's Poisson(mu=0.5) pmf at -1, 0 and 1.
    """
    points = [-1.0, 0.0, 1.0]
    max_abs_err = 1e-6

    # Reference values are computed by scipy, then cast to float32.
    reference = stats.poisson(mu=0.5).pmf(points).astype(np.float32)

    net = Prob()
    inputs = Tensor(np.array(points, dtype=np.float32), dtype=dtype.float32)
    actual = net(inputs).asnumpy()

    assert np.all(np.abs(actual - reference) < max_abs_err)
48
class LogProb(nn.Cell):
    """
    Cell wrapping ``log_prob`` of a float32 Poisson distribution built with [0.5].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [0.5]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self, x_):
        # Forward the input values to the distribution's log_prob method.
        log_probability = self.p.log_prob(x_)
        return log_probability
59
def test_log_likelihood():
    """
    Compare LogProb() output with scipy's Poisson(mu=0.5) logpmf at 1 and 2.
    """
    points = [1.0, 2.0]
    max_abs_err = 1e-6

    # Reference values are computed by scipy, then cast to float32.
    reference = stats.poisson(mu=0.5).logpmf(points).astype(np.float32)

    net = LogProb()
    inputs = Tensor(np.array(points, dtype=np.float32), dtype=dtype.float32)
    actual = net(inputs).asnumpy()

    assert np.all(np.abs(actual - reference) < max_abs_err)
71
class Basics(nn.Cell):
    """
    Cell returning mean, sd and mode of a float32 Poisson distribution built with [1.44].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [1.44]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self):
        # Evaluate the three statistics in the order they are returned.
        mean = self.p.mean()
        std_dev = self.p.sd()
        mode = self.p.mode()
        return mean, std_dev, mode
82
def test_basics():
    """
    Test mean, standard deviation and mode of the Basics() distribution.
    """
    tolerance = 1e-6
    mean, sd, mode = Basics()()

    # Each output is paired with the value it is expected to match.
    checks = (
        (mean, 1.44),
        (sd, 1.2),
        (mode, 1),
    )
    for actual, expected in checks:
        assert (np.abs(actual.asnumpy() - expected) < tolerance).all()
96
class Sampling(nn.Cell):
    """
    Cell wrapping ``sample`` of a float32 Poisson distribution built with [[1.0], [0.5]].

    ``shape`` is stored at construction and passed to every ``sample`` call;
    ``seed`` is forwarded to the distribution constructor.
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        self.shape = shape
        default_rate = [[1.0], [0.5]]
        self.p = msd.Poisson(default_rate, seed=seed, dtype=dtype.float32)

    def construct(self, rate=None):
        # ``rate`` is handed straight to sample(); None is passed through as-is.
        drawn = self.p.sample(self.shape, rate)
        return drawn
108
def test_sample():
    """
    Check the shape of samples drawn with an explicit rate tensor.
    """
    sampler = Sampling((2, 3), seed=10)
    rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32)
    drawn = sampler(rate)
    # Only the shape is verified; presumably sample shape (2, 3) followed by
    # the length-3 rate's shape -- confirm against the Poisson.sample docs.
    expected_shape = (2, 3, 3)
    assert drawn.shape == expected_shape
119
class CDF(nn.Cell):
    """
    Cell wrapping ``cdf`` of a float32 Poisson distribution built with [0.5].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [0.5]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self, x_):
        # Forward the input values to the distribution's cdf method.
        cumulative = self.p.cdf(x_)
        return cumulative
130
def test_cdf():
    """
    Compare CDF() output with scipy's Poisson(mu=0.5) cdf at -1, 0 and 1.
    """
    points = [-1.0, 0.0, 1.0]
    max_abs_err = 1e-6

    # Reference values are computed by scipy, then cast to float32.
    reference = stats.poisson(mu=0.5).cdf(points).astype(np.float32)

    net = CDF()
    inputs = Tensor(np.array(points, dtype=np.float32), dtype=dtype.float32)
    actual = net(inputs).asnumpy()

    assert np.all(np.abs(actual - reference) < max_abs_err)
142
class LogCDF(nn.Cell):
    """
    Cell wrapping ``log_cdf`` of a float32 Poisson distribution built with [0.5].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [0.5]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self, x_):
        # Forward the input values to the distribution's log_cdf method.
        log_cumulative = self.p.log_cdf(x_)
        return log_cumulative
153
def test_log_cdf():
    """
    Compare LogCDF() output with scipy's Poisson(mu=0.5) logcdf at 0.5, 1 and 2.5.
    """
    points = [0.5, 1.0, 2.5]
    max_abs_err = 1e-6

    # Reference values are computed by scipy, then cast to float32.
    reference = stats.poisson(mu=0.5).logcdf(points).astype(np.float32)

    net = LogCDF()
    inputs = Tensor(np.array(points, dtype=np.float32), dtype=dtype.float32)
    actual = net(inputs).asnumpy()

    assert np.all(np.abs(actual - reference) < max_abs_err)
165
class SF(nn.Cell):
    """
    Cell wrapping ``survival_function`` of a float32 Poisson distribution built with [0.5].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [0.5]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self, x_):
        # Forward the input values to the distribution's survival_function method.
        survival = self.p.survival_function(x_)
        return survival
176
def test_survival():
    """
    Compare SF() output with scipy's Poisson(mu=0.5) sf at -1, 0 and 1.
    """
    points = [-1.0, 0.0, 1.0]
    max_abs_err = 1e-6

    # Reference values are computed by scipy, then cast to float32.
    reference = stats.poisson(mu=0.5).sf(points).astype(np.float32)

    net = SF()
    inputs = Tensor(np.array(points, dtype=np.float32), dtype=dtype.float32)
    actual = net(inputs).asnumpy()

    assert np.all(np.abs(actual - reference) < max_abs_err)
188
class LogSF(nn.Cell):
    """
    Cell wrapping ``log_survival`` of a float32 Poisson distribution built with [0.5].
    """
    def __init__(self):
        super().__init__()
        # Distribution under test; created once and reused on every call.
        rate = [0.5]
        self.p = msd.Poisson(rate, dtype=dtype.float32)

    def construct(self, x_):
        # Forward the input values to the distribution's log_survival method.
        log_surv = self.p.log_survival(x_)
        return log_surv
199
def test_log_survival():
    """
    Compare LogSF() output with scipy's Poisson(mu=0.5) logsf at -1, 0 and 1.
    """
    points = [-1.0, 0.0, 1.0]
    max_abs_err = 1e-6

    # Reference values are computed by scipy, then cast to float32.
    reference = stats.poisson(mu=0.5).logsf(points).astype(np.float32)

    net = LogSF()
    inputs = Tensor(np.array(points, dtype=np.float32), dtype=dtype.float32)
    actual = net(inputs).asnumpy()

    assert np.all(np.abs(actual - reference) < max_abs_err)
211