# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Poisson.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_arguments():
    """
    Args passing during initialization.

    A Poisson can be built with no rate at all, or with a list of rates and
    an explicit float dtype; both must yield a Distribution instance.
    """
    p = msd.Poisson()
    assert isinstance(p, msd.Distribution)
    p = msd.Poisson([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
    assert isinstance(p, msd.Distribution)


def test_type():
    """
    A boolean dtype is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Poisson([0.1], dtype=dtype.bool_)


def test_name():
    """
    A non-string name is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Poisson([0.1], name=1.0)


def test_seed():
    """
    A non-integer seed is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Poisson([0.1], seed='seed')


def test_rate():
    """
    Invalid rate: negative and zero rates are rejected with ValueError.
    """
    with pytest.raises(ValueError):
        msd.Poisson([-0.1], dtype=dtype.float32)
    with pytest.raises(ValueError):
        msd.Poisson([0.0], dtype=dtype.float32)


def test_scalar():
    """
    Scalar rate combined with an invalid seed raises TypeError.
    """
    # NOTE(review): seed='seed' alone already raises TypeError (see test_seed),
    # so this test cannot tell whether a scalar rate is itself rejected.
    # Confirm the intended check before relying on it.
    with pytest.raises(TypeError):
        msd.Poisson(0.1, seed='seed')


class PoissonProb(nn.Cell):
    """
    Poisson distribution: initialize with rate.

    The rate (five entries of 0.5) is fixed at construction. `construct`
    evaluates prob, log_prob, cdf, log_cdf, survival_function and
    log_survival at `value` and returns their element-wise sum.
    """
    def __init__(self):
        super(PoissonProb, self).__init__()
        self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)

    def construct(self, value):
        prob = self.p.prob(value)
        log_prob = self.p.log_prob(value)
        cdf = self.p.cdf(value)
        log_cdf = self.p.log_cdf(value)
        sf = self.p.survival_function(value)
        log_sf = self.p.log_survival(value)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_poisson_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = PoissonProb()
    value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class PoissonProb1(nn.Cell):
    """
    Poisson distribution: initialize without rate.

    The rate is supplied on every call. `construct` evaluates the same six
    probability functions as PoissonProb, passing `rate` explicitly, and
    returns their element-wise sum.
    """
    def __init__(self):
        super(PoissonProb1, self).__init__()
        self.p = msd.Poisson(dtype=dtype.float32)

    def construct(self, value, rate):
        prob = self.p.prob(value, rate)
        log_prob = self.p.log_prob(value, rate)
        cdf = self.p.cdf(value, rate)
        log_cdf = self.p.log_cdf(value, rate)
        sf = self.p.survival_function(value, rate)
        log_sf = self.p.log_survival(value, rate)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_poisson_prob1():
    """
    Test probability functions: passing value/rate through construct.
    """
    net = PoissonProb1()
    value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32)
    rate = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
    ans = net(value, rate)
    assert isinstance(ans, Tensor)


class PoissonBasics(nn.Cell):
    """
    Test class: basic mean/sd/var functions.

    `construct` returns the element-wise sum of mean, sd and var for a
    Poisson with rates [2.3, 2.5].
    """
    def __init__(self):
        super(PoissonBasics, self).__init__()
        self.p = msd.Poisson([2.3, 2.5], dtype=dtype.float32)

    def construct(self):
        mean = self.p.mean()
        sd = self.p.sd()
        var = self.p.var()
        return mean + sd + var


def test_bascis():
    """
    Test mean/sd/var functionality of Poisson distribution.
    """
    # Name kept as in the original file ('bascis' is a typo for 'basics').
    net = PoissonBasics()
    ans = net()
    assert isinstance(ans, Tensor)


class PoissonConstruct(nn.Cell):
    """
    Poisson distribution: going through construct.

    Calls `prob` through the distribution's own __call__ dispatch (passing
    the method name as a string) in three ways:
    - the fixed-rate instance without a rate;
    - the fixed-rate instance with an explicit `rate`;
    - the rate-less instance with an explicit `rate`.
    Returns the element-wise sum of the three results.
    """
    def __init__(self):
        super(PoissonConstruct, self).__init__()
        self.p = msd.Poisson([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
        self.p1 = msd.Poisson(dtype=dtype.float32)

    def construct(self, value, rate):
        prob = self.p('prob', value)
        prob1 = self.p('prob', value, rate)
        prob2 = self.p1('prob', value, rate)
        return prob + prob1 + prob2


def test_poisson_construct():
    """
    Test probability function going through construct.
    """
    net = PoissonConstruct()
    value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
    # Poisson rate argument (the original local name `probs` was misleading).
    rate = Tensor([0.5, 0.5, 0.5, 0.5, 0.5], dtype=dtype.float32)
    ans = net(value, rate)
    assert isinstance(ans, Tensor)