# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Beta distribution"""
import numpy as np
from scipy import stats
from scipy import special
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


def _within_tol(actual, expected, tol=1e-6):
    """
    Report whether every element of |actual - expected| is strictly below tol.
    """
    return bool(np.all(np.abs(actual - expected) < tol))


def _reference_kl(c1_a, c0_a, c1_b, c0_b):
    """
    NumPy/SciPy reference value that the tests compare against `kl_loss`.

    Evaluates
        log B(c1_b, c0_b) - log B(c1_a, c0_a)
        - digamma(c1_a) * (c1_b - c1_a)
        - digamma(c0_a) * (c0_b - c0_a)
        + digamma(c1_a + c0_a) * ((c1_b + c0_b) - (c1_a + c0_a))
    where the `_a` arguments are the concentrations the cell's distribution
    was built with and the `_b` arguments are the ones passed to `kl_loss`.
    """
    total_a = c1_a + c0_a
    total_b = c1_b + c0_b
    log_norm_a = np.log(special.beta(c1_a, c0_a))
    log_norm_b = np.log(special.beta(c1_b, c0_b))
    return (log_norm_b - log_norm_a) \
        - (special.digamma(c1_a) * (c1_b - c1_a)) \
        - (special.digamma(c0_a) * (c0_b - c0_a)) \
        + (special.digamma(total_a) * (total_b - total_a))


class Prob(nn.Cell):
    """
    Cell wrapping `prob` of a Beta(3.0, 1.0) distribution.
    """
    def __init__(self):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.b.prob(x_)


def test_pdf():
    """
    Compare `prob` with scipy's Beta pdf at 0.25 and 0.75.
    """
    reference = stats.beta(np.array([3.0]), np.array([1.0]))
    expected = reference.pdf([0.25, 0.75]).astype(np.float32)
    net = Prob()
    result = net(Tensor([0.25, 0.75], dtype=dtype.float32))
    assert _within_tol(result.asnumpy(), expected)


class LogProb(nn.Cell):
    """
    Cell wrapping `log_prob` of a Beta(3.0, 1.0) distribution.
    """
    def __init__(self):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_):
        return self.b.log_prob(x_)


def test_log_likelihood():
    """
    Compare `log_prob` with scipy's Beta logpdf at 0.25 and 0.75.
    """
    reference = stats.beta(np.array([3.0]), np.array([1.0]))
    expected = reference.logpdf([0.25, 0.75]).astype(np.float32)
    net = LogProb()
    result = net(Tensor([0.25, 0.75], dtype=dtype.float32))
    assert _within_tol(result.asnumpy(), expected)


class KL(nn.Cell):
    """
    Cell wrapping `kl_loss` of a Beta(3.0, 4.0) distribution against 'Beta'.
    """
    def __init__(self):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.b.kl_loss('Beta', x_, y_)


def test_kl_loss():
    """
    Compare `kl_loss` with the NumPy/SciPy reference expression.
    """
    # Concentrations of the distribution held by the cell.
    c1_a = np.array([3.0]).astype(np.float32)
    c0_a = np.array([4.0]).astype(np.float32)
    # Concentrations passed to kl_loss at call time.
    c1_b = np.array([1.0]).astype(np.float32)
    c0_b = np.array([1.0]).astype(np.float32)

    expected = _reference_kl(c1_a, c0_a, c1_b, c0_b)

    net = KL()
    result = net(Tensor(c1_b, dtype=dtype.float32),
                 Tensor(c0_b, dtype=dtype.float32))
    assert _within_tol(result.asnumpy(), expected)


class Basics(nn.Cell):
    """
    Cell returning mean, sd and mode of a Beta(3.0, 3.0) distribution.
    """
    def __init__(self):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([3.0]), dtype=dtype.float32)

    def construct(self):
        return self.b.mean(), self.b.sd(), self.b.mode()


def test_basics():
    """
    Compare mean and sd with scipy, and mode with the literal 0.5.
    """
    net = Basics()
    mean, sd, mode = net()
    reference = stats.beta(np.array([3.0]), np.array([3.0]))
    expected_mean = reference.mean().astype(np.float32)
    expected_sd = reference.std().astype(np.float32)
    expected_mode = [0.5]
    assert _within_tol(mean.asnumpy(), expected_mean)
    assert _within_tol(mode.asnumpy(), expected_mode)
    assert _within_tol(sd.asnumpy(), expected_sd)


class Sampling(nn.Cell):
    """
    Cell wrapping `sample` of a seeded Beta(3.0, 1.0) distribution.
    """
    def __init__(self, shape, seed=0):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, concentration1=None, concentration0=None):
        return self.b.sample(self.shape, concentration1, concentration0)


def test_sample():
    """
    Check only the shape of the sampled output: (2, 3, 3) is expected for a
    sample shape of (2, 3) with concentrations of length 1 and 3.
    """
    net = Sampling((2, 3), seed=10)
    c1 = Tensor([2.0], dtype=dtype.float32)
    c0 = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    result = net(c1, c0)
    assert result.shape == (2, 3, 3)


class EntropyH(nn.Cell):
    """
    Cell wrapping `entropy` of a Beta(3.0, 1.0) distribution.
    """
    def __init__(self):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self):
        return self.b.entropy()


def test_entropy():
    """
    Compare `entropy` with scipy's Beta entropy.
    """
    reference = stats.beta(np.array([3.0]), np.array([1.0]))
    expected = reference.entropy().astype(np.float32)
    net = EntropyH()
    result = net()
    assert _within_tol(result.asnumpy(), expected)


class CrossEntropy(nn.Cell):
    """
    Cell returning (entropy + kl_loss) - cross_entropy for a Beta(3.0, 1.0)
    distribution; the accompanying test expects this difference to be zero.
    """
    def __init__(self):
        super().__init__()
        self.b = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        # Same call order as before: entropy, then kl_loss, then cross_entropy.
        return self.b.entropy() + self.b.kl_loss('Beta', x_, y_) \
            - self.b.cross_entropy('Beta', x_, y_)


def test_cross_entropy():
    """
    Check that entropy + kl_loss matches cross_entropy within tolerance.
    """
    net = CrossEntropy()
    c1 = Tensor([3.0], dtype=dtype.float32)
    c0 = Tensor([2.0], dtype=dtype.float32)
    residual = net(c1, c0)
    assert _within_tol(residual.asnumpy(), np.zeros(residual.shape))


class Net(nn.Cell):
    """
    Cell chaining two calls on one distribution instance: the `kl_loss`
    result is fed into `prob` of the same Beta(3.0, 1.0) distribution.
    """

    def __init__(self):
        super().__init__()
        self.beta = msd.Beta(np.array([3.0]), np.array([1.0]), dtype=dtype.float32)

    def construct(self, x_, y_):
        return self.beta.prob(self.beta.kl_loss('Beta', x_, y_))


def test_multiple_graphs():
    """
    Compare prob(kl_loss(...)) with the scipy pdf of the reference KL value.
    """
    net = Net()
    # Concentrations of the distribution held by the cell.
    c1_a = np.array([3.0]).astype(np.float32)
    c0_a = np.array([1.0]).astype(np.float32)
    # Concentrations passed to the cell at call time.
    c1_b = np.array([2.0]).astype(np.float32)
    c0_b = np.array([1.0]).astype(np.float32)
    result = net(Tensor(c1_b), Tensor(c0_b))

    expected_kl = _reference_kl(c1_a, c0_a, c1_b, c0_b)

    reference = stats.beta(np.array([3.0]), np.array([1.0]))
    expected = reference.pdf(expected_kl).astype(np.float32)

    assert _within_tol(result.asnumpy(), expected)