# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Gamma.
"""
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_gamma_shape_errpr():
    """
    Invalid shapes: concentration and rate that cannot be broadcast together.
    """
    # Shapes (2, 1) and (3, 1) are not broadcast-compatible.
    with pytest.raises(ValueError):
        msd.Gamma([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)


def test_type():
    """
    A non-float dtype is rejected with TypeError.
    """
    # Valid concentration/rate, so only the dtype can trigger the error.
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], dtype=dtype.int32)


def test_name():
    """
    A non-string name is rejected with TypeError.
    """
    # Valid concentration/rate, so only the name can trigger the error.
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], name=1.0)


def test_seed():
    """
    A non-integer seed is rejected with TypeError.
    """
    # Valid concentration/rate, so only the seed can trigger the error.
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], seed='seed')


def test_concentration():
    """
    Zero or negative concentration is rejected with ValueError.
    """
    # Valid rate, so only the concentration can trigger the error.
    with pytest.raises(ValueError):
        msd.Gamma([0.], [1.])
    with pytest.raises(ValueError):
        msd.Gamma([-1.], [1.])


def test_rate():
    """
    Zero or negative rate is rejected with ValueError.
    """
    # Valid concentration, so only the rate can trigger the error
    # (a zero concentration here would hide whether the rate is checked).
    with pytest.raises(ValueError):
        msd.Gamma([1.], [0.])
    with pytest.raises(ValueError):
        msd.Gamma([1.], [-1.])


def test_scalar():
    """
    Python scalars are rejected for concentration and rate with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Gamma(3., [4.])
    with pytest.raises(TypeError):
        msd.Gamma([3.], -4.)


def test_arguments():
    """
    args passing during initialization.
    """
    g = msd.Gamma()
    assert isinstance(g, msd.Distribution)
    g = msd.Gamma([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(g, msd.Distribution)


class GammaProb(nn.Cell):
    """
    Gamma distribution: initialize with concentration/rate.
    """
    def __init__(self):
        super(GammaProb, self).__init__()
        self.gamma = msd.Gamma([3.0, 4.0], [1.0, 1.0], dtype=dtype.float32)

    def construct(self, value):
        # Exercise every probability function using the stored parameters.
        prob = self.gamma.prob(value)
        log_prob = self.gamma.log_prob(value)
        cdf = self.gamma.cdf(value)
        log_cdf = self.gamma.log_cdf(value)
        sf = self.gamma.survival_function(value)
        log_sf = self.gamma.log_survival(value)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_gamma_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = GammaProb()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class GammaProb1(nn.Cell):
    """
    Gamma distribution: initialize without concentration/rate.
    """
    def __init__(self):
        super(GammaProb1, self).__init__()
        self.gamma = msd.Gamma()

    def construct(self, value, concentration, rate):
        # Parameters are supplied per call because none were set at init.
        prob = self.gamma.prob(value, concentration, rate)
        log_prob = self.gamma.log_prob(value, concentration, rate)
        cdf = self.gamma.cdf(value, concentration, rate)
        log_cdf = self.gamma.log_cdf(value, concentration, rate)
        sf = self.gamma.survival_function(value, concentration, rate)
        log_sf = self.gamma.log_survival(value, concentration, rate)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_gamma_prob1():
    """
    Test probability functions: passing concentration/rate, value through construct.
    """
    net = GammaProb1()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    concentration = Tensor([2.0, 3.0], dtype=dtype.float32)
    # A (1,)-shaped rate alongside a (2,)-shaped concentration; presumably
    # this exercises broadcasting inside the distribution.
    rate = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration, rate)
    assert isinstance(ans, Tensor)


class GammaKl(nn.Cell):
    """
    Test class: kl_loss of Gamma distribution.
    """
    def __init__(self):
        super(GammaKl, self).__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration_b, rate_b, concentration_a, rate_a):
        # g1 uses its stored parameters; g2 has none, so they are passed in.
        kl1 = self.g1.kl_loss('Gamma', concentration_b, rate_b)
        kl2 = self.g2.kl_loss('Gamma', concentration_b, rate_b, concentration_a, rate_a)
        return kl1 + kl2


def test_kl():
    """
    Test kl_loss.
    """
    net = GammaKl()
    concentration_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    rate_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
    rate_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
    ans = net(concentration_b, rate_b, concentration_a, rate_a)
    assert isinstance(ans, Tensor)


class GammaCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Gamma distribution.
    """
    def __init__(self):
        super(GammaCrossEntropy, self).__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration_b, rate_b, concentration_a, rate_a):
        # g1 uses its stored parameters; g2 has none, so they are passed in.
        h1 = self.g1.cross_entropy('Gamma', concentration_b, rate_b)
        h2 = self.g2.cross_entropy('Gamma', concentration_b, rate_b, concentration_a, rate_a)
        return h1 + h2


def test_cross_entropy():
    """
    Test cross entropy between Gamma distributions.
    """
    net = GammaCrossEntropy()
    concentration_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    rate_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
    rate_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
    ans = net(concentration_b, rate_b, concentration_a, rate_a)
    assert isinstance(ans, Tensor)


class GammaBasics(nn.Cell):
    """
    Test class: basic mean/sd/mode functions.
    """
    def __init__(self):
        super(GammaBasics, self).__init__()
        self.g = msd.Gamma(np.array([3.0, 4.0]), np.array([4.0, 6.0]), dtype=dtype.float32)

    def construct(self):
        mean = self.g.mean()
        sd = self.g.sd()
        mode = self.g.mode()
        return mean + sd + mode


def test_bascis():
    """
    Test mean/sd/mode functionality of Gamma.
    """
    net = GammaBasics()
    ans = net()
    assert isinstance(ans, Tensor)


class GammaConstruct(nn.Cell):
    """
    Gamma distribution: going through construct.
    """
    def __init__(self):
        super(GammaConstruct, self).__init__()
        self.gamma = msd.Gamma([3.0], [4.0])
        self.gamma1 = msd.Gamma()

    def construct(self, value, concentration, rate):
        # Call the distribution itself with a function name: once with the
        # stored parameters, once overriding them, and once on a
        # distribution that has no stored parameters.
        prob = self.gamma('prob', value)
        prob1 = self.gamma('prob', value, concentration, rate)
        prob2 = self.gamma1('prob', value, concentration, rate)
        return prob + prob1 + prob2


def test_gamma_construct():
    """
    Test probability function going through construct.
    """
    net = GammaConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    # Concentration must be strictly positive (see test_concentration);
    # 0.0 is outside the Gamma parameter space.
    concentration = Tensor([2.0], dtype=dtype.float32)
    rate = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration, rate)
    assert isinstance(ans, Tensor)