# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Poisson distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


def _assert_close(actual, expected, tol=1e-6):
    """
    Assert that a cell output matches the reference values element-wise.

    Args:
        actual (Tensor): value returned by the cell under test.
        expected: scalar or numpy array holding the reference values.
        tol (float): absolute tolerance of the comparison. Default: 1e-6.

    Raises:
        AssertionError: if any element differs by more than `tol`; the
            message reports the mismatching values.
    """
    # rtol=0 keeps the comparison purely absolute, like the former
    # `np.abs(actual - expected) < tol` check it replaces.
    np.testing.assert_allclose(actual.asnumpy(), expected, rtol=0, atol=tol)


class Prob(nn.Cell):
    """
    Test class: probability of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([0.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.p.prob(x_)


def test_pdf():
    """
    Test prob against the scipy pmf, including a point outside the support.
    """
    points = [-1.0, 0.0, 1.0]
    poisson_benchmark = stats.poisson(mu=0.5)
    expect_pdf = poisson_benchmark.pmf(points).astype(np.float32)
    pdf = Prob()
    x_ = Tensor(np.array(points).astype(np.float32), dtype=dtype.float32)
    output = pdf(x_)
    _assert_close(output, expect_pdf)


class LogProb(nn.Cell):
    """
    Test class: log probability of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([0.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.p.log_prob(x_)


def test_log_likelihood():
    """
    Test log_prob against the scipy logpmf.
    """
    points = [1.0, 2.0]
    poisson_benchmark = stats.poisson(mu=0.5)
    expect_logpdf = poisson_benchmark.logpmf(points).astype(np.float32)
    logprob = LogProb()
    x_ = Tensor(np.array(points).astype(np.float32), dtype=dtype.float32)
    output = logprob(x_)
    _assert_close(output, expect_logpdf)


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([1.44], dtype=dtype.float32)

    def construct(self):
        return self.p.mean(), self.p.sd(), self.p.mode()


def test_basics():
    """
    Test mean, standard deviation and mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    # Reference values for the distribution built with 1.44:
    # mean = 1.44, sd = sqrt(1.44) = 1.2, mode = floor(1.44) = 1.
    expect_mean = 1.44
    expect_sd = 1.2
    expect_mode = 1
    _assert_close(mean, expect_mean)
    _assert_close(sd, expect_sd)
    _assert_close(mode, expect_mode)


class Sampling(nn.Cell):
    """
    Test class: sample of Poisson distribution.
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        self.p = msd.Poisson([[1.0], [0.5]], seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, rate=None):
        return self.p.sample(self.shape, rate)


def test_sample():
    """
    Test sample: only the shape of the drawn samples is checked.
    """
    shape = (2, 3)
    seed = 10
    rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(rate)
    # Requested sample shape (2, 3) followed by the shape (3,) of `rate`.
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([0.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.p.cdf(x_)


def test_cdf():
    """
    Test cdf against the scipy cdf.
    """
    points = [-1.0, 0.0, 1.0]
    poisson_benchmark = stats.poisson(mu=0.5)
    expect_cdf = poisson_benchmark.cdf(points).astype(np.float32)
    cdf = CDF()
    x_ = Tensor(np.array(points).astype(np.float32), dtype=dtype.float32)
    output = cdf(x_)
    _assert_close(output, expect_cdf)


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([0.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.p.log_cdf(x_)


def test_log_cdf():
    """
    Test log_cdf against the scipy logcdf, including non-integer points.
    """
    points = [0.5, 1.0, 2.5]
    poisson_benchmark = stats.poisson(mu=0.5)
    expect_logcdf = poisson_benchmark.logcdf(points).astype(np.float32)
    logcdf = LogCDF()
    x_ = Tensor(np.array(points).astype(np.float32), dtype=dtype.float32)
    output = logcdf(x_)
    _assert_close(output, expect_logcdf)


class SF(nn.Cell):
    """
    Test class: survival function of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([0.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.p.survival_function(x_)


def test_survival():
    """
    Test survival function against the scipy sf.
    """
    points = [-1.0, 0.0, 1.0]
    poisson_benchmark = stats.poisson(mu=0.5)
    expect_survival = poisson_benchmark.sf(points).astype(np.float32)
    survival = SF()
    x_ = Tensor(np.array(points).astype(np.float32), dtype=dtype.float32)
    output = survival(x_)
    _assert_close(output, expect_survival)


class LogSF(nn.Cell):
    """
    Test class: log survival function of Poisson distribution.
    """
    def __init__(self):
        super().__init__()
        self.p = msd.Poisson([0.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.p.log_survival(x_)


def test_log_survival():
    """
    Test log survival function against the scipy logsf.
    """
    points = [-1.0, 0.0, 1.0]
    poisson_benchmark = stats.poisson(mu=0.5)
    expect_logsurvival = poisson_benchmark.logsf(points).astype(np.float32)
    logsurvival = LogSF()
    x_ = Tensor(np.array(points).astype(np.float32), dtype=dtype.float32)
    output = logsurvival(x_)
    _assert_close(output, expect_logsurvival)