# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Gamma.
"""
import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


# NOTE(review): "errpr" is a typo for "error"; the name is kept so that
# selecting this test by name (e.g. `pytest -k`) keeps working.
def test_gamma_shape_errpr():
    """
    Invalid shapes: parameters whose shapes are incompatible raise ValueError.
    """
    with pytest.raises(ValueError):
        msd.Gamma([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)


def test_type():
    """
    An integer dtype is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Gamma([0.], [1.], dtype=dtype.int32)


def test_name():
    """
    A non-string (float) name is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Gamma([0.], [1.], name=1.0)


def test_seed():
    """
    A string seed is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Gamma([0.], [1.], seed='seed')


def test_concentration1():
    """
    A zero or negative first parameter raises ValueError.
    """
    with pytest.raises(ValueError):
        msd.Gamma([0.], [1.])
    with pytest.raises(ValueError):
        msd.Gamma([-1.], [1.])


def test_concentration0():
    """
    A zero or negative second parameter raises ValueError.
    """
    with pytest.raises(ValueError):
        msd.Gamma([1.], [0.])
    with pytest.raises(ValueError):
        msd.Gamma([1.], [-1.])


def test_scalar():
    """
    A bare Python scalar (instead of a list) for either parameter raises TypeError.
    """
    with pytest.raises(TypeError):
        msd.Gamma(3., [4.])
    with pytest.raises(TypeError):
        msd.Gamma([3.], -4.)


def test_arguments():
    """
    args passing during initialization.
    """
    g = msd.Gamma()
    assert isinstance(g, msd.Distribution)
    g = msd.Gamma([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(g, msd.Distribution)


class GammaProb(nn.Cell):
    """
    Gamma distribution: initialize with concentration1/concentration0.
    """
    def __init__(self):
        super(GammaProb, self).__init__()
        self.gamma = msd.Gamma([3.0, 4.0], [1.0, 1.0], dtype=dtype.float32)

    def construct(self, value):
        """
        Return prob(value) + log_prob(value) using the parameters fixed at init.
        """
        prob = self.gamma.prob(value)
        log_prob = self.gamma.log_prob(value)
        return prob + log_prob


def test_gamma_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = GammaProb()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class GammaProb1(nn.Cell):
    """
    Gamma distribution: initialize without concentration1/concentration0.
    """
    def __init__(self):
        super(GammaProb1, self).__init__()
        self.gamma = msd.Gamma()

    def construct(self, value, concentration1, concentration0):
        """
        Return prob + log_prob with the parameters supplied at call time.
        """
        prob = self.gamma.prob(value, concentration1, concentration0)
        log_prob = self.gamma.log_prob(value, concentration1, concentration0)
        return prob + log_prob


def test_gamma_prob1():
    """
    Test probability functions: passing concentration1/concentration0, value through construct.
    """
    net = GammaProb1()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    concentration1 = Tensor([2.0, 3.0], dtype=dtype.float32)
    # A length-1 tensor alongside length-2 tensors; presumably relies on
    # broadcasting inside the distribution — TODO confirm.
    concentration0 = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration1, concentration0)
    assert isinstance(ans, Tensor)


class GammaKl(nn.Cell):
    """
    Test class: kl_loss of Gamma distribution.
    """
    def __init__(self):
        super(GammaKl, self).__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration1_b, concentration0_b, concentration1_a, concentration0_a):
        """
        Sum kl_loss from an initialized distribution (g1) and from one whose
        own parameters are passed at call time (g2).
        """
        kl1 = self.g1.kl_loss('Gamma', concentration1_b, concentration0_b)
        kl2 = self.g2.kl_loss('Gamma', concentration1_b, concentration0_b, concentration1_a, concentration0_a)
        return kl1 + kl2


def test_kl():
    """
    Test kl_loss.
    """
    net = GammaKl()
    concentration1_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration0_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration1_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
    concentration0_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
    ans = net(concentration1_b, concentration0_b, concentration1_a, concentration0_a)
    assert isinstance(ans, Tensor)


class GammaCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Gamma distribution.
    """
    def __init__(self):
        super(GammaCrossEntropy, self).__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration1_b, concentration0_b, concentration1_a, concentration0_a):
        """
        Sum cross_entropy from an initialized distribution (g1) and from one
        whose own parameters are passed at call time (g2).
        """
        h1 = self.g1.cross_entropy('Gamma', concentration1_b, concentration0_b)
        h2 = self.g2.cross_entropy('Gamma', concentration1_b, concentration0_b, concentration1_a, concentration0_a)
        return h1 + h2


def test_cross_entropy():
    """
    Test cross entropy between Gamma distributions.
    """
    net = GammaCrossEntropy()
    concentration1_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration0_b = Tensor(np.array([1.0]).astype(np.float32), dtype=dtype.float32)
    concentration1_a = Tensor(np.array([2.0]).astype(np.float32), dtype=dtype.float32)
    concentration0_a = Tensor(np.array([3.0]).astype(np.float32), dtype=dtype.float32)
    ans = net(concentration1_b, concentration0_b, concentration1_a, concentration0_a)
    assert isinstance(ans, Tensor)


class GammaBasics(nn.Cell):
    """
    Test class: basic mean/sd/mode functions.
    """
    def __init__(self):
        super(GammaBasics, self).__init__()
        self.g = msd.Gamma(np.array([3.0, 4.0]), np.array([4.0, 6.0]), dtype=dtype.float32)

    def construct(self):
        """
        Return mean + sd + mode of the initialized distribution.
        """
        mean = self.g.mean()
        sd = self.g.sd()
        mode = self.g.mode()
        return mean + sd + mode


# NOTE(review): "bascis" is a typo for "basics"; the name is kept so that
# selecting this test by name (e.g. `pytest -k`) keeps working.
def test_bascis():
    """
    Test mean/sd/mode functionality of Gamma.
    """
    net = GammaBasics()
    ans = net()
    assert isinstance(ans, Tensor)


class GammaConstruct(nn.Cell):
    """
    Gamma distribution: going through construct.
    """
    def __init__(self):
        super(GammaConstruct, self).__init__()
        self.gamma = msd.Gamma([3.0], [4.0])
        self.gamma1 = msd.Gamma()

    def construct(self, value, concentration1, concentration0):
        """
        Call 'prob' through the distribution's own construct: with the
        parameters fixed at init, with overriding parameters, and on a
        distribution created without parameters.
        """
        prob = self.gamma('prob', value)
        prob1 = self.gamma('prob', value, concentration1, concentration0)
        prob2 = self.gamma1('prob', value, concentration1, concentration0)
        return prob + prob1 + prob2


def test_gamma_construct():
    """
    Test probability function going through construct.
    """
    net = GammaConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    # Must be strictly positive: test_concentration1 shows that a zero first
    # parameter is rejected, so 0.0 here would exercise an invalid distribution.
    concentration1 = Tensor([2.0], dtype=dtype.float32)
    concentration0 = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration1, concentration0)
    assert isinstance(ans, Tensor)