# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Bernoulli.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_arguments():
    """
    Construct Bernoulli without arguments and with a list of probs.
    """
    default_dist = msd.Bernoulli()
    assert isinstance(default_dist, msd.Distribution)
    list_dist = msd.Bernoulli([0.1, 0.3, 0.5, 0.9], dtype=dtype.int32)
    assert isinstance(list_dist, msd.Distribution)


def test_type():
    """
    Expect TypeError when dtype is bool_.
    """
    with pytest.raises(TypeError):
        msd.Bernoulli([0.1], dtype=dtype.bool_)


def test_name():
    """
    Expect TypeError when name is a float.
    """
    with pytest.raises(TypeError):
        msd.Bernoulli([0.1], name=1.0)


def test_seed():
    """
    Expect TypeError when seed is a string.
    """
    with pytest.raises(TypeError):
        msd.Bernoulli([0.1], seed='seed')


def test_prob():
    """
    Invalid probability: each of -0.1, 1.1, 0.0 and 1.0 must raise ValueError.
    """
    # Checked in the same order as the original stand-alone blocks.
    for invalid_prob in (-0.1, 1.1, 0.0, 1.0):
        with pytest.raises(ValueError):
            msd.Bernoulli([invalid_prob], dtype=dtype.int32)


class BernoulliProb(nn.Cell):
    """
    Bernoulli distribution: probs supplied at initialization.
    """

    def __init__(self):
        super(BernoulliProb, self).__init__()
        self.b = msd.Bernoulli(0.5, dtype=dtype.int32)

    def construct(self, value):
        # Call every probability-style method once and sum the results.
        pmf = self.b.prob(value)
        log_pmf = self.b.log_prob(value)
        cdf_val = self.b.cdf(value)
        log_cdf_val = self.b.log_cdf(value)
        survival = self.b.survival_function(value)
        log_survival = self.b.log_survival(value)
        return pmf + log_pmf + cdf_val + log_cdf_val + survival + log_survival


def test_bernoulli_prob():
    """
    Probability functions with only value passed through construct.
    """
    net = BernoulliProb()
    value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
    output = net(value)
    assert isinstance(output, Tensor)


class BernoulliProb1(nn.Cell):
    """
    Bernoulli distribution: no probs at initialization, passed per call.
    """

    def __init__(self):
        super(BernoulliProb1, self).__init__()
        self.b = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, value, probs):
        # Same six methods as BernoulliProb, with probs given explicitly.
        pmf = self.b.prob(value, probs)
        log_pmf = self.b.log_prob(value, probs)
        cdf_val = self.b.cdf(value, probs)
        log_cdf_val = self.b.log_cdf(value, probs)
        survival = self.b.survival_function(value, probs)
        log_survival = self.b.log_survival(value, probs)
        return pmf + log_pmf + cdf_val + log_cdf_val + survival + log_survival


def test_bernoulli_prob1():
    """
    Probability functions with value and probs passed through construct.
    """
    net = BernoulliProb1()
    value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
    probs = Tensor([0.5], dtype=dtype.float32)
    output = net(value, probs)
    assert isinstance(output, Tensor)


class BernoulliKl(nn.Cell):
    """
    Test class: kl_loss between Bernoulli distributions.
    """

    def __init__(self):
        super(BernoulliKl, self).__init__()
        self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
        self.b2 = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, probs_b, probs_a):
        # b1 uses its own probs; b2 receives probs_a at call time.
        kl_from_init = self.b1.kl_loss('Bernoulli', probs_b)
        kl_from_args = self.b2.kl_loss('Bernoulli', probs_b, probs_a)
        return kl_from_init + kl_from_args


def test_kl():
    """
    Run kl_loss through a Cell and check a Tensor comes back.
    """
    ber_net = BernoulliKl()
    probs_b = Tensor([0.3], dtype=dtype.float32)
    probs_a = Tensor([0.7], dtype=dtype.float32)
    output = ber_net(probs_b, probs_a)
    assert isinstance(output, Tensor)


class BernoulliCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Bernoulli distribution.
    """

    def __init__(self):
        super(BernoulliCrossEntropy, self).__init__()
        self.b1 = msd.Bernoulli(0.7, dtype=dtype.int32)
        self.b2 = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, probs_b, probs_a):
        # b1 uses its own probs; b2 receives probs_a at call time.
        h_from_init = self.b1.cross_entropy('Bernoulli', probs_b)
        h_from_args = self.b2.cross_entropy('Bernoulli', probs_b, probs_a)
        return h_from_init + h_from_args


def test_cross_entropy():
    """
    Run cross_entropy through a Cell and check a Tensor comes back.
    """
    net = BernoulliCrossEntropy()
    probs_b = Tensor([0.3], dtype=dtype.float32)
    probs_a = Tensor([0.7], dtype=dtype.float32)
    output = net(probs_b, probs_a)
    assert isinstance(output, Tensor)


class BernoulliConstruct(nn.Cell):
    """
    Bernoulli distribution: invoking the distribution object by method name.
    """

    def __init__(self):
        super(BernoulliConstruct, self).__init__()
        self.b = msd.Bernoulli(0.5, dtype=dtype.int32)
        self.b1 = msd.Bernoulli(dtype=dtype.int32)

    def construct(self, value, probs):
        # Three call shapes: stored probs, overriding probs, probs required.
        with_stored = self.b('prob', value)
        with_override = self.b('prob', value, probs)
        with_passed = self.b1('prob', value, probs)
        return with_stored + with_override + with_passed


def test_bernoulli_construct():
    """
    Call 'prob' by name on the distribution objects inside construct.
    """
    net = BernoulliConstruct()
    value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
    probs = Tensor([0.5], dtype=dtype.float32)
    output = net(value, probs)
    assert isinstance(output, Tensor)


class BernoulliMean(nn.Cell):
    """
    Test class: mean() of a Bernoulli distribution.
    """

    def __init__(self):
        super(BernoulliMean, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.mean()


def test_mean():
    """
    Check that mean() returns a Tensor.
    """
    net = BernoulliMean()
    output = net()
    assert isinstance(output, Tensor)


class BernoulliSd(nn.Cell):
    """
    Test class: sd() of a Bernoulli distribution.
    """

    def __init__(self):
        super(BernoulliSd, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.sd()


def test_sd():
    """
    Check that sd() returns a Tensor.
    """
    net = BernoulliSd()
    output = net()
    assert isinstance(output, Tensor)


class BernoulliVar(nn.Cell):
    """
    Test class: var() of a Bernoulli distribution.
    """

    def __init__(self):
        super(BernoulliVar, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.var()


def test_var():
    """
    Check that var() returns a Tensor.
    """
    net = BernoulliVar()
    output = net()
    assert isinstance(output, Tensor)


class BernoulliMode(nn.Cell):
    """
    Test class: mode() of a Bernoulli distribution.
    """

    def __init__(self):
        super(BernoulliMode, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.mode()


def test_mode():
    """
    Check that mode() returns a Tensor.
    """
    net = BernoulliMode()
    output = net()
    assert isinstance(output, Tensor)


class BernoulliEntropy(nn.Cell):
    """
    Test class: entropy() of a Bernoulli distribution.
    """

    def __init__(self):
        super(BernoulliEntropy, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5], dtype=dtype.int32)

    def construct(self):
        return self.b.entropy()


def test_entropy():
    """
    Check that entropy() returns a Tensor.
    """
    net = BernoulliEntropy()
    output = net()
    assert isinstance(output, Tensor)