# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for LogNormal distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Parameters of the underlying normal shared by the batched LogNormal cells and
# their scipy benchmark: loc is a (1,) array and scale a (2, 1) array.
# Kept as plain lists so every consumer builds its own fresh numpy array.
LOC = [0.3]
SCALE = [[0.2], [0.4]]


def _scipy_benchmark():
    """
    Return the scipy.stats.lognorm counterpart of LogNormal(LOC, SCALE).

    scipy parameterizes the log-normal by shape ``s`` (the standard deviation
    of the underlying normal) and ``scale`` (exp of the normal's mean).
    """
    return stats.lognorm(s=np.array(SCALE), scale=np.exp(np.array(LOC)))


def _assert_close(actual, expected, tol):
    """
    Assert every element of the MindSpore tensor ``actual`` lies strictly
    within ``tol`` of ``expected`` (numpy broadcasting applies).
    """
    err = np.abs(actual.asnumpy() - expected)
    assert (err < tol).all(), \
        "max abs error {} is not below tol {}".format(np.max(err), tol)


def _normal_kl(mean_a, sd_a, mean_b, sd_b):
    """
    Closed-form KL(N(mean_a, sd_a) || N(mean_b, sd_b)).

    KL divergence is invariant under the bijection exp, so this is also the KL
    between the LogNormals built on those underlying normals.
    """
    diff_log_scale = np.log(sd_a) - np.log(sd_b)
    squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
    return 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale


class Prob(nn.Cell):
    """
    Test class: probability of LogNormal distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.prob(x_)


def test_pdf():
    """
    Test pdf against scipy.stats.lognorm.
    """
    expect_pdf = _scipy_benchmark().pdf([1.0, 2.0]).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    _assert_close(output, expect_pdf, tol)


class LogProb(nn.Cell):
    """
    Test class: log probability of LogNormal distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pdf against scipy.stats.lognorm.
    """
    expect_logpdf = _scipy_benchmark().logpdf([1.0, 2.0]).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-6
    _assert_close(output, expect_logpdf, tol)


class KL(nn.Cell):
    """
    Test class: kl_loss of LogNormal distribution.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.ln = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.ln.kl_loss('LogNormal', x_, y_)


def test_kl_loss():
    """
    Test kl_loss against the closed-form KL of the underlying normals.
    """
    # Parameters of the distribution held by KL().
    mean_a = np.array([0.3]).astype(np.float32)
    sd_a = np.array([0.4]).astype(np.float32)

    # Parameters of the distribution it is compared against.
    mean_b = np.array([1.0]).astype(np.float32)
    sd_b = np.array([1.0]).astype(np.float32)

    expect_kl_loss = _normal_kl(mean_a, sd_a, mean_b, sd_b)

    kl_loss = KL()
    mean = Tensor(mean_b, dtype=dtype.float32)
    sd = Tensor(sd_b, dtype=dtype.float32)
    output = kl_loss(mean, sd)
    tol = 1e-6
    _assert_close(output, expect_kl_loss, tol)


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of LogNormal distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self):
        return self.ln.mean(), self.ln.sd(), self.ln.mode()


def test_basics():
    """
    Test mean/standard deviation/mode against scipy.stats.lognorm.
    """
    basics = Basics()
    mean, sd, mode = basics()
    lognorm_benchmark = _scipy_benchmark()
    expect_mean = lognorm_benchmark.mean().astype(np.float32)
    expect_sd = lognorm_benchmark.std().astype(np.float32)
    # scipy offers no mode(); for a LogNormal it is exp(loc - scale**2), i.e.
    # the median exp(loc) divided by exp(scale**2).
    expect_mode = (lognorm_benchmark.median() / np.exp(np.square(SCALE))).astype(np.float32)
    tol = 1e-6
    _assert_close(mean, expect_mean, tol)
    _assert_close(mode, expect_mode, tol)
    _assert_close(sd, expect_sd, tol)


class Sampling(nn.Cell):
    """
    Test class: sample of LogNormal distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, mean=None, sd=None):
        return self.ln.sample(self.shape, mean, sd)


def test_sample():
    """
    Test sample shape when mean/sd are overridden at call time.
    """
    shape = (2, 3)
    seed = 10
    mean = Tensor([2.0], dtype=dtype.float32)
    sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(mean, sd)
    # Expected: sample shape (2, 3) followed by the broadcast parameter shape (3,).
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of LogNormal distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.cdf(x_)


def test_cdf():
    """
    Test cdf against scipy.stats.lognorm.
    """
    expect_cdf = _scipy_benchmark().cdf([1.0, 2.0]).astype(np.float32)
    cdf = CDF()
    output = cdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    _assert_close(output, expect_cdf, tol)


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of LogNormal distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.log_cdf(x_)


def test_log_cdf():
    """
    Test log cdf against scipy.stats.lognorm.
    """
    expect_logcdf = _scipy_benchmark().logcdf([1.0, 2.0]).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 1e-4
    _assert_close(output, expect_logcdf, tol)


class SF(nn.Cell):
    """
    Test class: survival function of LogNormal distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.survival_function(x_)


def test_survival():
    """
    Test survival function against scipy.stats.lognorm.
    """
    expect_survival = _scipy_benchmark().sf([1.0, 2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    _assert_close(output, expect_survival, tol)


class LogSF(nn.Cell):
    """
    Test class: log survival function of LogNormal distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self, x_):
        return self.ln.log_survival(x_)


def test_log_survival():
    """
    Test log_survival against scipy.stats.lognorm.
    """
    expect_log_survival = _scipy_benchmark().logsf([1.0, 2.0]).astype(np.float32)
    log_survival = LogSF()
    output = log_survival(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 5e-4
    _assert_close(output, expect_log_survival, tol)


class EntropyH(nn.Cell):
    """
    Test class: entropy of LogNormal distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.ln = msd.LogNormal(np.array(LOC), np.array(SCALE), dtype=dtype.float32)

    def construct(self):
        return self.ln.entropy()


def test_entropy():
    """
    Test entropy against scipy.stats.lognorm.
    """
    expect_entropy = _scipy_benchmark().entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    _assert_close(output, expect_entropy, tol)


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between LogNormal distributions.

    Returns entropy + kl_loss - cross_entropy, which should be zero because
    cross entropy H(p, q) equals H(p) + KL(p || q).
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.ln = msd.LogNormal(np.array([0.3]), np.array([0.4]), dtype=dtype.float32)

    def construct(self, x_, y_):
        entropy = self.ln.entropy()
        kl_loss = self.ln.kl_loss('LogNormal', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.ln.cross_entropy('LogNormal', x_, y_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy is consistent with entropy + kl_loss.
    """
    cross_entropy = CrossEntropy()
    mean = Tensor([1.0], dtype=dtype.float32)
    sd = Tensor([1.0], dtype=dtype.float32)
    diff = cross_entropy(mean, sd)
    tol = 1e-6
    _assert_close(diff, 0.0, tol)


class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.

    The same LogNormal(0, 1) instance is used twice: first for kl_loss against
    the caller-supplied parameters, then to evaluate prob at that KL value.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.ln = msd.LogNormal(0., 1., dtype=dtype.float32)

    def construct(self, x_, y_):
        kl = self.ln.kl_loss('LogNormal', x_, y_)
        prob = self.ln.prob(kl)
        return prob


def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    net = Net()
    # Parameters of the distribution held by Net (LogNormal(0, 1)).
    mean_a = np.array([0.0]).astype(np.float32)
    sd_a = np.array([1.0]).astype(np.float32)
    mean_b = np.array([1.0]).astype(np.float32)
    sd_b = np.array([1.0]).astype(np.float32)
    ans = net(Tensor(mean_b), Tensor(sd_b))

    expect_kl_loss = _normal_kl(mean_a, sd_a, mean_b, sd_b)
    lognorm_benchmark = stats.lognorm(s=np.array([1.]), scale=np.exp(np.array([0.])))
    expect_prob = lognorm_benchmark.pdf(expect_kl_loss).astype(np.float32)

    tol = 1e-6
    _assert_close(ans, expect_prob, tol)