# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Uniform distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Prob(nn.Cell):
    """
    Test class: probability of Uniform distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.prob(x_)

def test_pdf():
    """
    Test pdf.
    """
    uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]])
    expect_pdf = uniform_benchmark.pdf([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32)
    pdf = Prob()
    x_ = Tensor(np.array([-1.0, 0.0, 0.5, 1.0, 1.5, 3.0]).astype(np.float32), dtype=dtype.float32)
    output = pdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_pdf) < tol).all()

class LogProb(nn.Cell):
    """
    Test class: log probability of Uniform distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.u = msd.Uniform([0.0], [[1.0], [2.0]], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.log_prob(x_)

def test_log_likelihood():
    """
    Test log_pdf.
    """
    uniform_benchmark = stats.uniform([0.0], [[1.0], [2.0]])
    expect_logpdf = uniform_benchmark.logpdf([0.5]).astype(np.float32)
    logprob = LogProb()
    x_ = Tensor(np.array([0.5]).astype(np.float32), dtype=dtype.float32)
    output = logprob(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logpdf) < tol).all()

class KL(nn.Cell):
    """
    Test class: kl_loss between Uniform distributions.
    """
    def __init__(self):
        super(KL, self).__init__()
        self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.u.kl_loss('Uniform', x_, y_)

def test_kl_loss():
    """
    Test kl_loss.
    """
    low_a = 0.0
    high_a = 1.5
    low_b = -1.0
    high_b = 2.0
    expect_kl_loss = np.log(high_b - low_b) - np.log(high_a - low_a)
    kl = KL()
    output = kl(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_kl_loss) < tol).all()

class Basics(nn.Cell):
    """
    Test class: mean/sd of Uniform distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.u = msd.Uniform([0.0], [3.0], dtype=dtype.float32)

    def construct(self):
        return self.u.mean(), self.u.sd()
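# Note (illustrative sketch, not part of the original test suite): the expected values
# in test_basics below come from the closed forms for Uniform(low, high),
# mean = (low + high) / 2 and sd = (high - low) / sqrt(12), so Uniform([0.0], [3.0])
# should give mean 1.5 and sd sqrt(0.75). A NumPy-only reference helper, added here
# purely as documentation of those constants:
def _uniform_mean_sd_reference(low, high):
    """Reference mean and standard deviation of a Uniform(low, high) distribution."""
    low, high = np.asarray(low), np.asarray(high)
    return (low + high) / 2.0, (high - low) / np.sqrt(12.0)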
def test_basics():
    """
    Test mean/standard deviation.
    """
    basics = Basics()
    mean, sd = basics()
    expect_mean = [1.5]
    expect_sd = np.sqrt([0.75])
    tol = 1e-6
    assert (np.abs(mean.asnumpy() - expect_mean) < tol).all()
    assert (np.abs(sd.asnumpy() - expect_sd) < tol).all()

class Sampling(nn.Cell):
    """
    Test class: sample of Uniform distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.u = msd.Uniform([0.0], [[1.0], [2.0]], seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, low=None, high=None):
        return self.u.sample(self.shape, low, high)

def test_sample():
    """
    Test sample.
    """
    shape = (2, 3)
    seed = 10
    low = Tensor([1.0], dtype=dtype.float32)
    high = Tensor([2.0, 3.0, 4.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(low, high)
    assert output.shape == (2, 3, 3)

class CDF(nn.Cell):
    """
    Test class: cdf of Uniform distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.cdf(x_)

def test_cdf():
    """
    Test cdf.
    """
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_cdf = uniform_benchmark.cdf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32)
    cdf = CDF()
    x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32)
    output = cdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_cdf) < tol).all()

class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Uniform distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.log_cdf(x_)

class SF(nn.Cell):
    """
    Test class: survival function of Uniform distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.survival_function(x_)

class LogSF(nn.Cell):
    """
    Test class: log survival function of Uniform distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.u = msd.Uniform([0.0], [1.0], dtype=dtype.float32)

    def construct(self, x_):
        return self.u.log_survival(x_)

class EntropyH(nn.Cell):
    """
    Test class: entropy of Uniform distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.u = msd.Uniform([0.0], [1.0, 2.0], dtype=dtype.float32)

    def construct(self):
        return self.u.entropy()

def test_entropy():
    """
    Test entropy.
    """
    uniform_benchmark = stats.uniform([0.0], [1.0, 2.0])
    expect_entropy = uniform_benchmark.entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_entropy) < tol).all()
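# Note (illustrative sketch, not exercised by the tests above): the entropy checked in
# test_entropy has the closed form H(Uniform(low, high)) = log(high - low), so the
# [0, 1] and [0, 2] distributions should report 0 and log(2). A NumPy-only reference,
# matching scipy's stats.uniform(loc, scale).entropy() for loc = low, scale = high - low:
def _uniform_entropy_reference(low, high):
    """Reference entropy of a Uniform(low, high) distribution."""
    return np.log(np.asarray(high) - np.asarray(low))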
class CrossEntropy(nn.Cell):
    """
    Test class: cross_entropy between Uniform distributions.
    """
    def __init__(self):
        super(CrossEntropy, self).__init__()
        self.u = msd.Uniform([0.0], [1.5], dtype=dtype.float32)

    def construct(self, x_, y_):
        entropy = self.u.entropy()
        kl_loss = self.u.kl_loss('Uniform', x_, y_)
        h_sum_kl = entropy + kl_loss
        cross_entropy = self.u.cross_entropy('Uniform', x_, y_)
        return h_sum_kl - cross_entropy

def test_log_cdf():
    """
    Test log_cdf.
    """
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_logcdf = uniform_benchmark.logcdf([0.5, 0.8, 2.0]).astype(np.float32)
    logcdf = LogCDF()
    x_ = Tensor(np.array([0.5, 0.8, 2.0]).astype(np.float32), dtype=dtype.float32)
    output = logcdf(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logcdf) < tol).all()

def test_survival():
    """
    Test survival function.
    """
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_survival = uniform_benchmark.sf([-1.0, 0.5, 1.0, 2.0]).astype(np.float32)
    survival = SF()
    x_ = Tensor(np.array([-1.0, 0.5, 1.0, 2.0]).astype(np.float32), dtype=dtype.float32)
    output = survival(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()

def test_log_survival():
    """
    Test log survival function.
    """
    uniform_benchmark = stats.uniform([0.0], [1.0])
    expect_logsurvival = uniform_benchmark.logsf([0.5, 0.8, -2.0]).astype(np.float32)
    logsurvival = LogSF()
    x_ = Tensor(np.array([0.5, 0.8, -2.0]).astype(np.float32), dtype=dtype.float32)
    output = logsurvival(x_)
    tol = 1e-6
    assert (np.abs(output.asnumpy() - expect_logsurvival) < tol).all()

def test_cross_entropy():
    """
    Test cross_entropy.
    """
    cross_entropy = CrossEntropy()
    low_b = -1.0
    high_b = 2.0
    diff = cross_entropy(Tensor(low_b, dtype=dtype.float32), Tensor(high_b, dtype=dtype.float32))
    tol = 1e-6
    assert (np.abs(diff.asnumpy() - np.zeros(diff.shape)) < tol).all()
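# Note (illustrative sketch, not part of the original file): test_cross_entropy relies on
# the identity H(p, q) = H(p) + KL(p || q). For two Uniform distributions where the support
# of p lies inside the support of q, this reduces to
#     H(p, q) = log(high_p - low_p) + log((high_q - low_q) / (high_p - low_p))
#             = log(high_q - low_q),
# which is why the difference returned by CrossEntropy.construct should be zero. A NumPy-only
# reference for that identity, using the same parameterization as the tests above:
def _uniform_cross_entropy_reference(low_p, high_p, low_q, high_q):
    """Reference cross entropy between Uniform(low_p, high_p) and Uniform(low_q, high_q)."""
    entropy_p = np.log(high_p - low_p)
    kl_p_q = np.log(high_q - low_q) - np.log(high_p - low_p)
    return entropy_p + kl_p_q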