# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Gumbel distribution

Each test wraps a method of ``msd.Gumbel`` in an ``nn.Cell`` and compares its
output against a reference value, mostly from ``scipy.stats.gumbel_r``.
"""
import numpy as np
from scipy import stats
from scipy import special
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class Prob(nn.Cell):
    """
    Test class: probability of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.prob(x_)


def test_pdf():
    """
    Test pdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    value = np.array([1.0, 2.0]).astype(np.float32)
    # loc (1,), scale (2, 1) and value (2,) broadcast to a (2, 2) reference.
    expect_pdf = gumbel_benchmark.pdf(value).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor(value, dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()


class LogProb(nn.Cell):
    """
    Test class: log probability of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_logpdf = gumbel_benchmark.logpdf([1.0, 2.0]).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()


class KL(nn.Cell):
    """
    Test class: kl_loss of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2,).
    """
    def __init__(self):
        super(KL, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([1.0, 2.0]), dtype=dtype.float32)

    def construct(self, loc_b, scale_b):
        return self.gum.kl_loss('Gumbel', loc_b, scale_b)


def test_kl_loss():
    """
    Test kl_loss.

    The expected value is the closed-form KL divergence between two Gumbel
    distributions, evaluated in numpy/scipy.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([1.0, 2.0]).astype(np.float32)

    loc_b = np.array([1.0]).astype(np.float32)
    scale_b = np.array([1.0, 2.0]).astype(np.float32)

    expect_kl_loss = np.log(scale_b) - np.log(scale) +\
        np.euler_gamma * (scale / scale_b - 1.) +\
        np.expm1((loc_b - loc) / scale_b + special.loggamma(scale / scale_b + 1.))

    kl_loss = KL()
    loc_b = Tensor(loc_b, dtype=dtype.float32)
    scale_b = Tensor(scale_b, dtype=dtype.float32)
    output = kl_loss(loc_b, scale_b)
    tol = 1e-5
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self):
        return self.gum.mean(), self.gum.sd(), self.gum.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()

    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_mean = gumbel_benchmark.mean().astype(np.float32)
    expect_sd = gumbel_benchmark.std().astype(np.float32)
    # The mode of a Gumbel distribution is its loc, which is 0 here.
    expect_mode = np.array([[0.0], [0.0]]).astype(np.float32)
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(mode.asnumpy() - expect_mode) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()


class Sampling(nn.Cell):
    """
    Test class: sample of Gumbel distribution.

    Args:
        shape (tuple): sample shape passed to ``sample``.
        seed (int): seed forwarded to the Gumbel distribution. Default: 0.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([1.0, 2.0, 3.0]), dtype=dtype.float32, seed=seed)
        self.shape = shape

    def construct(self):
        return self.gum.sample(self.shape)


def test_sample():
    """
    Test sample.

    Only the output shape is checked: the sample shape (2, 3) followed by
    the batch shape (3,) of the distribution.
    """
    shape = (2, 3)
    seed = 10
    sample = Sampling(shape, seed=seed)
    output = sample()
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_cdf = gumbel_benchmark.cdf([1.0, 2.0]).astype(np.float32)
    cdf = CDF()
    output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.log_cdf(x_)


def test_log_cdf():
    """
    Test log cdf.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_logcdf = gumbel_benchmark.logcdf([1.0, 2.0]).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-4
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()


class SF(nn.Cell):
    """
    Test class: survival function of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(SF, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.survival_function(x_)


def test_survival():
    """
    Test survival_function.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_survival = gumbel_benchmark.sf([1.0, 2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()


class LogSF(nn.Cell):
    """
    Test class: log survival function of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_):
        return self.gum.log_survival(x_)


def test_log_survival():
    """
    Test log_survival.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_log_survival = gumbel_benchmark.logsf([1.0, 2.0]).astype(np.float32)
    log_survival = LogSF()
    output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 5e-4
    assert (np.abs(output.asnumpy() - expect_log_survival) < tol).all()


class EntropyH(nn.Cell):
    """
    Test class: entropy of Gumbel distribution.

    The distribution has loc of shape (1,) and scale of shape (2, 1).
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self):
        return self.gum.entropy()


def test_entropy():
    """
    Test entropy.
    """
    loc = np.array([0.0]).astype(np.float32)
    scale = np.array([[1.0], [2.0]]).astype(np.float32)
    gumbel_benchmark = stats.gumbel_r(loc, scale)
    expect_entropy = gumbel_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Gumbel distributions.

    ``construct`` returns ``entropy + kl_loss - cross_entropy``, which should
    be zero because cross entropy equals entropy plus KL divergence.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.gum = msd.Gumbel(np.array([0.0]), np.array([[1.0], [2.0]]), dtype=dtype.float32)

    def construct(self, x_, y_):
        entropy = self.gum.entropy()
        kl_loss = self.gum.kl_loss('Gumbel', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.gum.cross_entropy('Gumbel', x_, y_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.

    Checks the identity cross_entropy == entropy + kl_loss elementwise.
    """
    cross_entropy = CrossEntropy()
    loc = Tensor([1.0], dtype=dtype.float32)
    scale = Tensor([1.0], dtype=dtype.float32)
    diff = cross_entropy(loc, scale)
    tol = 1e-6
    assert (np.abs(diff.asnumpy()) < tol).all()