# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test cases for the Exponential distribution."""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Batched rate shared by the Cells compared against scipy below. Keeping it in
# one place guarantees the scipy benchmark describes the same distribution.
_RATE = [[1.0], [0.5]]

# Absolute tolerance used by every numeric comparison in this file.
_TOL = 1e-6


def _expon_benchmark():
    """
    Return the scipy reference distribution matching ``_RATE``.

    scipy parameterizes the exponential by ``scale = 1 / rate``, so the rate
    [[1.0], [0.5]] becomes scale [[1.0], [2.0]].
    """
    return stats.expon(scale=1.0 / np.array(_RATE))


def _as_tensor(values):
    """
    Wrap a sequence of floats as a float32 MindSpore Tensor.
    """
    return Tensor(np.array(values, dtype=np.float32), dtype=dtype.float32)


def _assert_close(actual, expected, tol=_TOL):
    """
    Assert every element of ``actual`` is strictly within ``tol`` of ``expected``.

    ``expected`` may be an array or a scalar; numpy broadcasting applies.
    """
    assert (np.abs(actual - expected) < tol).all()


class Prob(nn.Cell):
    """
    Test class: probability of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self, x_):
        return self.e.prob(x_)


def test_pdf():
    """
    Test pdf, including a point outside the support (x < 0).
    """
    points = [-1.0, 0.0, 1.0]
    expect_pdf = _expon_benchmark().pdf(points).astype(np.float32)
    pdf = Prob()
    output = pdf(_as_tensor(points))
    _assert_close(output.asnumpy(), expect_pdf)


class LogProb(nn.Cell):
    """
    Test class: log probability of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self, x_):
        return self.e.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pdf.
    """
    points = [0.5, 1.0, 2.0]
    expect_logpdf = _expon_benchmark().logpdf(points).astype(np.float32)
    logprob = LogProb()
    output = logprob(_as_tensor(points))
    _assert_close(output.asnumpy(), expect_logpdf)


class KL(nn.Cell):
    """
    Test class: kl_loss between Exponential distributions.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential([1.5], dtype=dtype.float32)

    def construct(self, x_):
        return self.e.kl_loss('Exponential', x_)


def test_kl_loss():
    """
    Test kl_loss.
    """
    rate_a = 1.5
    rate_b = np.array([0.5, 2.0]).astype(np.float32)
    # Closed form: KL(Exp(a) || Exp(b)) = log(a) - log(b) + b / a - 1.
    expect_kl_loss = np.log(rate_a) - np.log(rate_b) + rate_b / rate_a - 1.0
    kl = KL()
    output = kl(Tensor(rate_b, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_kl_loss)


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential([0.5], dtype=dtype.float32)

    def construct(self):
        return self.e.mean(), self.e.sd(), self.e.mode()


def test_basics():
    """
    Test mean, standard deviation and mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    # For rate 0.5: mean = sd = 1 / rate = 2, and the mode is 0.
    expect_mean = 2.
    expect_sd = 2.
    expect_mode = 0.
    _assert_close(mean.asnumpy(), expect_mean)
    _assert_close(sd.asnumpy(), expect_sd)
    _assert_close(mode.asnumpy(), expect_mode)


class Sampling(nn.Cell):
    """
    Test class: sample of Exponential distribution.
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        self.e = msd.Exponential(_RATE, seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, rate=None):
        return self.e.sample(self.shape, rate)


def test_sample():
    """
    Test sample with a rate passed at call time.
    """
    shape = (2, 3)
    seed = 10
    rate = Tensor([1.0, 2.0, 3.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(rate)
    # Expected shape is the sample shape (2, 3) followed by rate's shape (3,).
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self, x_):
        return self.e.cdf(x_)


def test_cdf():
    """
    Test cdf, including a point outside the support (x < 0).
    """
    points = [-1.0, 0.0, 1.0]
    expect_cdf = _expon_benchmark().cdf(points).astype(np.float32)
    cdf = CDF()
    output = cdf(_as_tensor(points))
    _assert_close(output.asnumpy(), expect_cdf)


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self, x_):
        return self.e.log_cdf(x_)


def test_log_cdf():
    """
    Test log_cdf.
    """
    points = [0.5, 1.0, 2.5]
    expect_logcdf = _expon_benchmark().logcdf(points).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(_as_tensor(points))
    _assert_close(output.asnumpy(), expect_logcdf)


class SF(nn.Cell):
    """
    Test class: survival function of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self, x_):
        return self.e.survival_function(x_)


def test_survival():
    """
    Test survival function, including a point outside the support (x < 0).
    """
    points = [-1.0, 0.0, 1.0]
    expect_survival = _expon_benchmark().sf(points).astype(np.float32)
    survival = SF()
    output = survival(_as_tensor(points))
    _assert_close(output.asnumpy(), expect_survival)


class LogSF(nn.Cell):
    """
    Test class: log survival function of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self, x_):
        return self.e.log_survival(x_)


def test_log_survival():
    """
    Test log survival function, including a point outside the support (x < 0).
    """
    points = [-1.0, 0.0, 1.0]
    expect_logsurvival = _expon_benchmark().logsf(points).astype(np.float32)
    logsurvival = LogSF()
    output = logsurvival(_as_tensor(points))
    _assert_close(output.asnumpy(), expect_logsurvival)


class EntropyH(nn.Cell):
    """
    Test class: entropy of Exponential distribution.
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential(_RATE, dtype=dtype.float32)

    def construct(self):
        return self.e.entropy()


def test_entropy():
    """
    Test entropy.
    """
    expect_entropy = _expon_benchmark().entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    _assert_close(output.asnumpy(), expect_entropy)


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Exponential distributions.

    Returns ``entropy + kl_loss - cross_entropy``, which should be zero
    because H(p, q) = H(p) + KL(p || q).
    """
    def __init__(self):
        super().__init__()
        self.e = msd.Exponential([1.0], dtype=dtype.float32)

    def construct(self, x_):
        entropy = self.e.entropy()
        kl_loss = self.e.kl_loss('Exponential', x_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.e.cross_entropy('Exponential', x_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy against entropy + kl_loss.
    """
    cross_entropy = CrossEntropy()
    rate = Tensor([0.5], dtype=dtype.float32)
    diff = cross_entropy(rate)
    _assert_close(diff.asnumpy(), 0.0)