# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for the Categorical distribution in mindspore.nn.probability."""
import numpy as np
import pytest
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Absolute tolerance shared by every numeric comparison in this module.
_TOL = 1e-6


def _assert_close(actual, expected, tol=_TOL):
    """
    Assert that every element of the Tensor ``actual`` differs from
    ``expected`` (a scalar or an array-like of matching length) by strictly
    less than ``tol``.
    """
    gap = np.abs(actual.asnumpy() - expected)
    assert (gap < tol).all()


class Prob(nn.Cell):
    """
    Cell exposing ``prob`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.prob(x_)


def test_pmf():
    """
    prob() must return the configured probability of each outcome.
    """
    values = Tensor(np.array([0, 1, 0, 1, 1], dtype=np.int32),
                    dtype=dtype.float32)
    _assert_close(Prob()(values), [0.7, 0.3, 0.7, 0.3, 0.3])


class LogProb(nn.Cell):
    """
    Cell exposing ``log_prob`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.log_prob(x_)


def test_log_likelihood():
    """
    log_prob() must equal the natural log of each outcome's probability.
    """
    values = Tensor(np.array([0, 1, 0, 1, 1], dtype=np.int32),
                    dtype=dtype.float32)
    _assert_close(LogProb()(values), np.log([0.7, 0.3, 0.7, 0.3, 0.3]))


class KL(nn.Cell):
    """
    Cell computing ``kl_loss`` between a Categorical with probs [0.7, 0.3]
    and a 'Categorical' distribution parameterized by the input probs.
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.kl_loss('Categorical', x_)


def test_kl_loss():
    """
    The KL divergence between two identical distributions must be ~0.
    """
    same_probs = Tensor([0.7, 0.3], dtype=dtype.float32)
    _assert_close(KL()(same_probs), 0)


class Sampling(nn.Cell):
    """
    Cell requesting samples of shape (2, 3) from a Categorical with
    probs [0.2, 0.1, 0.7].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)
        self.sample_shape = (2, 3)

    def construct(self):
        return self.cat_dist.sample(self.sample_shape)


def test_sample():
    """
    Building and running the sampling cell is expected to raise
    NotImplementedError.
    """
    with pytest.raises(NotImplementedError):
        Sampling()()


class Basics(nn.Cell):
    """
    Cell returning (mean, var, mode) of a Categorical with
    probs [0.2, 0.1, 0.7].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.2, 0.1, 0.7], dtype=dtype.int32)

    def construct(self):
        return self.cat_dist.mean(), self.cat_dist.var(), self.cat_dist.mode()


def test_basics():
    """
    mean(), var() and mode() must match the moments computed by hand
    from the probabilities of outcomes 0, 1 and 2.
    """
    probs = (0.2, 0.1, 0.7)
    # E[X] = sum_k k * p_k
    expect_mean = sum(k * p for k, p in enumerate(probs))
    # Var[X] = E[X^2] - E[X]^2
    second_moment = sum(k * k * p for k, p in enumerate(probs))
    expect_var = second_moment - expect_mean * expect_mean
    mean, var, mode = Basics()()
    _assert_close(mean, expect_mean)
    _assert_close(var, expect_var)
    # Outcome 2 has the largest probability (0.7).
    _assert_close(mode, 2)


class CDF(nn.Cell):
    """
    Cell exposing ``cdf`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.cdf(x_)


def test_cdf():
    """
    cdf() must return the probability mass up to and including each outcome.
    """
    values = Tensor(np.array([0, 0, 1, 0, 1], dtype=np.int32),
                    dtype=dtype.float32)
    _assert_close(CDF()(values), [0.7, 0.7, 1, 0.7, 1])


class LogCDF(nn.Cell):
    """
    Cell exposing ``log_cdf`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.log_cdf(x_)


def test_logcdf():
    """
    log_cdf() must equal the natural log of the cumulative probabilities.
    """
    values = Tensor(np.array([0, 0, 1, 0, 1], dtype=np.int32),
                    dtype=dtype.float32)
    _assert_close(LogCDF()(values), np.log([0.7, 0.7, 1, 0.7, 1]))


class SF(nn.Cell):
    """
    Cell exposing ``survival_function`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.survival_function(x_)


def test_survival():
    """
    survival_function() must return the mass strictly above each outcome.
    """
    values = Tensor(np.array([0, 1, 1, 0, 0], dtype=np.int32),
                    dtype=dtype.float32)
    _assert_close(SF()(values), [0.3, 0., 0., 0.3, 0.3])


class LogSF(nn.Cell):
    """
    Cell exposing ``log_survival`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        return self.cat_dist.log_survival(x_)


def test_log_survival():
    """
    log_survival() on non-integer inputs: the expected values treat -2 as
    having survival 1 and 0.5 the same as outcome 0.
    """
    values = Tensor(np.array([-2, 0, 0, 0.5, 0.5], dtype=np.float32),
                    dtype=dtype.float32)
    _assert_close(LogSF()(values), np.log([1., 0.3, 0.3, 0.3, 0.3]))


class EntropyH(nn.Cell):
    """
    Cell exposing ``entropy`` of a Categorical with probs [0.7, 0.3].
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self):
        return self.cat_dist.entropy()


def test_entropy():
    """
    entropy() must agree with scipy; a multinomial with n=1 is a
    categorical distribution, so its entropy serves as the reference.
    """
    reference = stats.multinomial(n=1, p=[0.7, 0.3]).entropy().astype(np.float32)
    _assert_close(EntropyH()(), reference)


class CrossEntropy(nn.Cell):
    """
    Cell returning H(p) + KL(p, q) - CE(p, q) for p = Categorical([0.7, 0.3])
    and q a 'Categorical' parameterized by the input probs; the identity
    H + KL == CE means the result should be zero.
    """

    def __init__(self):
        super().__init__()
        self.cat_dist = msd.Categorical([0.7, 0.3], dtype=dtype.int32)

    def construct(self, x_):
        h_plus_kl = self.cat_dist.entropy() + self.cat_dist.kl_loss('Categorical', x_)
        return h_plus_kl - self.cat_dist.cross_entropy('Categorical', x_)


def test_cross_entropy():
    """
    cross_entropy() must equal entropy() + kl_loss() for the same pair.
    """
    probs = Tensor([0.7, 0.3], dtype=dtype.float32)
    _assert_close(CrossEntropy()(probs), 0)