# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Gamma distribution"""
import numpy as np
from scipy import stats
from scipy import special
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# NOTE: scipy.stats.gamma is parametrized by shape ``a`` and ``scale``
# (scale = 1 / rate). ``stats.gamma(np.array([3.0]))`` uses the default
# scale=1, i.e. it is the reference for msd.Gamma(concentration=3.0, rate=1.0).


def _expected_gamma_kl(concentration_a, rate_a, concentration_b, rate_b):
    """
    Reference value of KL(Gamma(concentration_a, rate_a) || Gamma(concentration_b, rate_b)).

    Closed-form expression evaluated with NumPy/SciPy; the first distribution
    plays the role of the MindSpore instance ``self.g`` in the cells below.
    """
    return (concentration_a - concentration_b) * special.digamma(concentration_a) \
        + special.gammaln(concentration_b) - special.gammaln(concentration_a) \
        + concentration_b * np.log(rate_a) - concentration_b * np.log(rate_b) \
        + concentration_a * (rate_b / rate_a - 1.)


class Prob(nn.Cell):
    """
    Test class: probability of Gamma distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.prob(x_)


def test_pdf():
    """
    Test pdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_pdf = gamma_benchmark.pdf([1.0, 2.0]).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()


class LogProb(nn.Cell):
    """
    Test class: log probability of Gamma distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_logpdf = gamma_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss of Gamma distribution.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.g.kl_loss('Gamma', x_, y_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    # Distribution "a" must match the parameters of KL.g (concentration 3, rate 4).
    concentration_a = np.array([3.0]).astype(np.float32)
    rate_a = np.array([4.0]).astype(np.float32)

    concentration_b = np.array([1.0]).astype(np.float32)
    rate_b = np.array([1.0]).astype(np.float32)

    expect_kl_loss = _expected_gamma_kl(concentration_a, rate_a, concentration_b, rate_b)

    kl_loss = KL()
    concentration = Tensor(concentration_b, dtype=dtype.float32)
    rate = Tensor(rate_b, dtype=dtype.float32)
    output = kl_loss(concentration, rate)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Gamma distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self):
        return self.g.mean(), self.g.sd(), self.g.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_mean = gamma_benchmark.mean().astype(np.float32)
    expect_sd = gamma_benchmark.std().astype(np.float32)
    # Mode of Gamma(concentration, rate) is (concentration - 1) / rate = (3 - 1) / 1.
    expect_mode = [2.0]
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()


class Sampling(nn.Cell):
    """
    Test class: sample of Gamma distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, concentration=None, rate=None):
        return self.g.sample(self.shape, concentration, rate)


def test_sample():
    """
    Test sample.
    """
    shape = (2, 3)
    seed = 10
    concentration = Tensor([2.0], dtype=dtype.float32)
    rate = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(concentration, rate)
    # Expected: sample shape (2, 3) followed by the broadcast parameter shape (3,).
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of Gamma distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_cdf = gamma_benchmark.cdf([2.0]).astype(np.float32)
    cdf = CDF()
    output = cdf(Tensor([2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Gamma distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.log_cdf(x_)


def test_log_cdf():
    """
    Test log cdf.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_logcdf = gamma_benchmark.logcdf([2.0]).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(Tensor([2.0], dtype=dtype.float32))
    tol = 5e-5
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
    """
    Test class: survival function of Gamma distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.survival_function(x_)


def test_survival():
    """
    Test survival function.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_survival = gamma_benchmark.sf([2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()


class LogSF(nn.Cell):
    """
    Test class: log survival function of Gamma distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.g.log_survival(x_)


def test_log_survival():
    """
    Test log_survival.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_log_survival = gamma_benchmark.logsf([2.0]).astype(np.float32)
    log_survival = LogSF()
    output = log_survival(Tensor([2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all()


class EntropyH(nn.Cell):
    """
    Test class: entropy of Gamma distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self):
        return self.g.entropy()


def test_entropy():
    """
    Test entropy.
    """
    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_entropy = gamma_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Gamma distributions.

    Returns entropy + kl_loss - cross_entropy, which should be zero because
    H(p, q) = H(p) + KL(p || q).
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        entropy = self.g.entropy()
        kl_loss = self.g.kl_loss('Gamma', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.g.cross_entropy('Gamma', x_, y_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    concentration = Tensor([3.0], dtype=dtype.float32)
    rate = Tensor([2.0], dtype=dtype.float32)
    diff = cross_entropy(concentration, rate)
    tol = 1e-6
    assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()


class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Must be stored as ``self.g``: construct() reads ``self.g``.
        self.g = msd.Gamma(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        kl = self.g.kl_loss('Gamma', x_, y_)
        prob = self.g.prob(kl)
        return prob


def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    prob = Net()
    # Distribution "a" must match the parameters of Net.g (concentration 3, rate 1).
    concentration_a = np.array([3.0]).astype(np.float32)
    rate_a = np.array([1.0]).astype(np.float32)
    concentration_b = np.array([2.0]).astype(np.float32)
    rate_b = np.array([1.0]).astype(np.float32)
    ans = prob(Tensor(concentration_b), Tensor(rate_b))

    expect_kl_loss = _expected_gamma_kl(concentration_a, rate_a, concentration_b, rate_b)

    gamma_benchmark = stats.gamma(np.array([3.0]))
    expect_prob = gamma_benchmark.pdf(expect_kl_loss).astype(np.float32)

    tol = 1e-6
    assert (np.abs(ans.asnumpy() - expect_prob) < tol).all()