# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Normal distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Points at which prob/log_prob/cdf/log_cdf/survival/log_survival are evaluated.
_EVAL_POINTS = [1.0, 2.0]


def _batched_normal(**kwargs):
    """
    Build the batched Normal shared by most test cells.

    The mean [3.0] is broadcast against the sd [[2.0], [4.0]]. Fresh numpy
    arrays are created on every call so no two cells share a buffer.
    Extra keyword arguments (e.g. ``seed``) are forwarded to ``msd.Normal``.
    """
    return msd.Normal(np.array([3.0]), np.array([[2.0], [4.0]]),
                      dtype=dtype.float32, **kwargs)


def _batched_benchmark():
    """
    Return the scipy reference with the same parameters as ``_batched_normal``.
    """
    return stats.norm(np.array([3.0]), np.array([[2.0], [4.0]]))


def _assert_close(actual, expected, tol):
    """
    Assert every element of ``actual`` is strictly within ``tol`` of ``expected``.
    """
    assert (np.abs(actual - expected) < tol).all()


def _expected_normal_kl(mean_a, sd_a, mean_b, sd_b):
    """
    Numpy reference for KL(N(mean_a, sd_a) || N(mean_b, sd_b)).

    Equivalent to log(sd_b / sd_a) + (sd_a**2 + (mean_a - mean_b)**2) / (2 * sd_b**2) - 1/2,
    written with expm1 for accuracy when sd_a is close to sd_b.
    """
    diff_log_scale = np.log(sd_a) - np.log(sd_b)
    squared_diff = np.square(mean_a / sd_b - mean_b / sd_b)
    return 0.5 * squared_diff + 0.5 * np.expm1(2 * diff_log_scale) - diff_log_scale


class Prob(nn.Cell):
    """
    Test class: probability of Normal distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.n = _batched_normal()

    def construct(self, x_):
        return self.n.prob(x_)


def test_pdf():
    """
    Test pdf.
    """
    expect_pdf = _batched_benchmark().pdf(_EVAL_POINTS).astype(np.float32)
    pdf = Prob()
    output = pdf(Tensor(_EVAL_POINTS, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_pdf, tol=1e-6)


class LogProb(nn.Cell):
    """
    Test class: log probability of Normal distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.n = _batched_normal()

    def construct(self, x_):
        return self.n.log_prob(x_)


def test_log_likelihood():
    """
    Test log_prob (log pdf).
    """
    expect_logpdf = _batched_benchmark().logpdf(_EVAL_POINTS).astype(np.float32)
    logprob = LogProb()
    output = logprob(Tensor(_EVAL_POINTS, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_logpdf, tol=1e-6)


class KL(nn.Cell):
    """
    Test class: kl_loss of Normal distribution.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.n.kl_loss('Normal', x_, y_)


def test_kl_loss():
    """
    Test kl_loss against the closed-form KL divergence between two Normals.
    """
    # mean_a / sd_a must match the parameters of the distribution built in KL.
    mean_a = np.array([3.0]).astype(np.float32)
    sd_a = np.array([4.0]).astype(np.float32)
    mean_b = np.array([1.0]).astype(np.float32)
    sd_b = np.array([1.0]).astype(np.float32)
    expect_kl_loss = _expected_normal_kl(mean_a, sd_a, mean_b, sd_b)

    kl_loss = KL()
    mean = Tensor(mean_b, dtype=dtype.float32)
    sd = Tensor(sd_b, dtype=dtype.float32)
    output = kl_loss(mean, sd)
    _assert_close(output.asnumpy(), expect_kl_loss, tol=1e-6)


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Normal distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.n = msd.Normal(np.array([3.0]), np.array([2.0, 4.0]), dtype=dtype.float32)

    def construct(self):
        return self.n.mean(), self.n.sd(), self.n.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    # The scalar mean [3.0] is broadcast against the two standard deviations.
    expect_mean = [3.0, 3.0]
    expect_sd = [2.0, 4.0]
    tol = 1e-6
    _assert_close(mean.asnumpy(), expect_mean, tol)
    _assert_close(mode.asnumpy(), expect_mean, tol)
    _assert_close(sd.asnumpy(), expect_sd, tol)


class Sampling(nn.Cell):
    """
    Test class: sample of Normal distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.n = _batched_normal(seed=seed)
        self.shape = shape

    def construct(self, mean=None, sd=None):
        return self.n.sample(self.shape, mean, sd)


def test_sample():
    """
    Test sample.
    """
    shape = (2, 3)
    seed = 10
    mean = Tensor([2.0], dtype=dtype.float32)
    sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(mean, sd)
    # Sample shape (2, 3) followed by the broadcast batch shape (3,) of mean/sd.
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of Normal distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.n = _batched_normal()

    def construct(self, x_):
        return self.n.cdf(x_)


def test_cdf():
    """
    Test cdf.
    """
    expect_cdf = _batched_benchmark().cdf(_EVAL_POINTS).astype(np.float32)
    cdf = CDF()
    output = cdf(Tensor(_EVAL_POINTS, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_cdf, tol=2e-5)


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Normal distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.n = _batched_normal()

    def construct(self, x_):
        return self.n.log_cdf(x_)


def test_log_cdf():
    """
    Test log cdf.
    """
    expect_logcdf = _batched_benchmark().logcdf(_EVAL_POINTS).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(Tensor(_EVAL_POINTS, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_logcdf, tol=5e-5)


class SF(nn.Cell):
    """
    Test class: survival function of Normal distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.n = _batched_normal()

    def construct(self, x_):
        return self.n.survival_function(x_)


def test_survival():
    """
    Test survival_function.
    """
    expect_survival = _batched_benchmark().sf(_EVAL_POINTS).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor(_EVAL_POINTS, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_survival, tol=2e-5)


class LogSF(nn.Cell):
    """
    Test class: log survival function of Normal distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.n = _batched_normal()

    def construct(self, x_):
        return self.n.log_survival(x_)


def test_log_survival():
    """
    Test log_survival.
    """
    expect_log_survival = _batched_benchmark().logsf(_EVAL_POINTS).astype(np.float32)
    log_survival = LogSF()
    output = log_survival(Tensor(_EVAL_POINTS, dtype=dtype.float32))
    _assert_close(output.asnumpy(), expect_log_survival, tol=2e-5)


class EntropyH(nn.Cell):
    """
    Test class: entropy of Normal distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.n = _batched_normal()

    def construct(self):
        return self.n.entropy()


def test_entropy():
    """
    Test entropy.
    """
    expect_entropy = _batched_benchmark().entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    _assert_close(output.asnumpy(), expect_entropy, tol=1e-6)


class CrossEntropy(nn.Cell):
    """
    Test class: cross entropy between Normal distributions.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.n = msd.Normal(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        # Checks the identity H(p) + KL(p || q) - H(p, q) == 0.
        entropy = self.n.entropy()
        kl_loss = self.n.kl_loss('Normal', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.n.cross_entropy('Normal', x_, y_)
        return h_sum_kl - cross_entropy


def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    mean = Tensor([1.0], dtype=dtype.float32)
    sd = Tensor([1.0], dtype=dtype.float32)
    diff = cross_entropy(mean, sd)
    _assert_close(diff.asnumpy(), np.zeros(diff.shape), tol=1e-6)


class Net(nn.Cell):
    """
    Test class: expand single distribution instance to multiple graphs
    by specifying the attributes.
    """

    def __init__(self):
        super(Net, self).__init__()
        self.normal = msd.Normal(0., 1., dtype=dtype.float32)

    def construct(self, x_, y_):
        kl = self.normal.kl_loss('Normal', x_, y_)
        prob = self.normal.prob(kl)
        return prob


def test_multiple_graphs():
    """
    Test multiple graphs case.
    """
    prob = Net()
    # mean_a / sd_a must match the standard Normal built in Net.
    mean_a = np.array([0.0]).astype(np.float32)
    sd_a = np.array([1.0]).astype(np.float32)
    mean_b = np.array([1.0]).astype(np.float32)
    sd_b = np.array([1.0]).astype(np.float32)
    ans = prob(Tensor(mean_b), Tensor(sd_b))

    expect_kl_loss = _expected_normal_kl(mean_a, sd_a, mean_b, sd_b)
    norm_benchmark = stats.norm(np.array([0.0]), np.array([1.0]))
    expect_prob = norm_benchmark.pdf(expect_kl_loss).astype(np.float32)

    _assert_close(ans.asnumpy(), expect_prob, tol=1e-6)