# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Bernoulli distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Prob(nn.Cell):
    """
    Test class: probability of Bernoulli distribution.
    """

    def __init__(self):
        super(Prob, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.prob(x_)


def test_pmf():
    """
    Test pmf against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_pmf = bernoulli_benchmark.pmf([0, 1, 0, 1, 1]).astype(np.float32)
    pmf = Prob()
    # Integer outcomes are built as int32 and handed to the cell as a
    # float32 Tensor.
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
        np.int32), dtype=dtype.float32)
    output = pmf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pmf) < tol).all()


class LogProb(nn.Cell):
    """
    Test class: log probability of Bernoulli distribution.
    """

    def __init__(self):
        super(LogProb, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pmf against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_logpmf = bernoulli_benchmark.logpmf(
        [0, 1, 0, 1, 1]).astype(np.float32)
    logprob = LogProb()
    x_ = Tensor(np.array([0, 1, 0, 1, 1]).astype(
        np.int32), dtype=dtype.float32)
    output = logprob(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpmf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss between Bernoulli distributions.
    """

    def __init__(self):
        super(KL, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.kl_loss('Bernoulli', x_)


def test_kl_loss():
    """
    Test kl_loss of Bernoulli(0.7) against Bernoulli(0.5).
    """
    probs1_a = 0.7
    probs1_b = 0.5
    probs0_a = 1 - probs1_a
    probs0_b = 1 - probs1_b
    # Closed form: KL(a || b) = p_a*log(p_a/p_b) + (1-p_a)*log((1-p_a)/(1-p_b)).
    expect_kl_loss = probs1_a * \
        np.log(probs1_a / probs1_b) + probs0_a * np.log(probs0_a / probs0_b)
    kl_loss = KL()
    output = kl_loss(Tensor([probs1_b], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Bernoulli distribution.
    """

    def __init__(self):
        super(Basics, self).__init__()
        self.b = msd.Bernoulli([0.3, 0.5, 0.7], dtype=dtype.int32)

    def construct(self):
        return self.b.mean(), self.b.sd(), self.b.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    expect_mean = [0.3, 0.5, 0.7]
    # sd = sqrt(p * (1 - p)) for each probability in the batch.
    expect_sd = np.sqrt(np.multiply([0.7, 0.5, 0.3], [0.3, 0.5, 0.7]))
    # The tie case p == 0.5 is expected to report mode 0.
    expect_mode = [0.0, 0.0, 1.0]
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()


class Sampling(nn.Cell):
    """
    Test class: sampling of Bernoulli distribution.
    """

    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.b = msd.Bernoulli([0.7, 0.5], seed=seed, dtype=dtype.int32)
        self.shape = shape

    def construct(self, probs=None):
        return self.b.sample(self.shape, probs)


def test_sample():
    """
    Test sample output shape.
    """
    shape = (2, 3)
    sample = Sampling(shape)
    output = sample()
    # Requested sample shape (2, 3) followed by the batch of 2 probabilities.
    assert output.shape == (2, 3, 2)


class CDF(nn.Cell):
    """
    Test class: cdf of Bernoulli distributions.
    """

    def __init__(self):
        super(CDF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.cdf(x_)


def test_cdf():
    """
    Test cdf against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_cdf = bernoulli_benchmark.cdf([0, 0, 1, 0, 1]).astype(np.float32)
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
        np.int32), dtype=dtype.float32)
    cdf = CDF()
    output = cdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()


class LogCDF(nn.Cell):
    """
    Test class: log cdf of Bernoulli distributions.
    """

    def __init__(self):
        super(LogCDF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.log_cdf(x_)


def test_logcdf():
    """
    Test log_cdf against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_logcdf = bernoulli_benchmark.logcdf(
        [0, 0, 1, 0, 1]).astype(np.float32)
    x_ = Tensor(np.array([0, 0, 1, 0, 1]).astype(
        np.int32), dtype=dtype.float32)
    logcdf = LogCDF()
    output = logcdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
    """
    Test class: survival function of Bernoulli distributions.
    """

    def __init__(self):
        super(SF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.survival_function(x_)


def test_survival():
    """
    Test survival function against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_survival = bernoulli_benchmark.sf(
        [0, 1, 1, 0, 0]).astype(np.float32)
    x_ = Tensor(np.array([0, 1, 1, 0, 0]).astype(
        np.int32), dtype=dtype.float32)
    sf = SF()
    output = sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()


class LogSF(nn.Cell):
    """
    Test class: log survival function of Bernoulli distributions.
    """

    def __init__(self):
        super(LogSF, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        return self.b.log_survival(x_)


def test_log_survival():
    """
    Test log survival function against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    # Unlike the other tests, the inputs include values outside {0, 1}
    # (-1 and 0.9), so they are passed as float32 directly.
    expect_logsurvival = bernoulli_benchmark.logsf(
        [-1, 0.9, 0, 0, 0]).astype(np.float32)
    x_ = Tensor(np.array([-1, 0.9, 0, 0, 0]
                         ).astype(np.float32), dtype=dtype.float32)
    log_sf = LogSF()
    output = log_sf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()


class EntropyH(nn.Cell):
    """
    Test class: entropy of Bernoulli distributions.
    """

    def __init__(self):
        super(EntropyH, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self):
        return self.b.entropy()


def test_entropy():
    """
    Test entropy against scipy.stats.bernoulli(0.7).
    """
    bernoulli_benchmark = stats.bernoulli(0.7)
    expect_entropy = bernoulli_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Bernoulli distributions.
    """

    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.b = msd.Bernoulli(0.7, dtype=dtype.int32)

    def construct(self, x_):
        entropy = self.b.entropy()
        kl_loss = self.b.kl_loss('Bernoulli', x_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.b.cross_entropy('Bernoulli', x_)
        # Returns H(p) + KL(p || q) - CE(p, q), which should be zero.
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy via the identity CE(p, q) = H(p) + KL(p || q).
    """
    cross_entropy = CrossEntropy()
    prob = Tensor([0.3], dtype=dtype.float32)
    diff = cross_entropy(prob)
    tol = 1e-6
    assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()