# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Exponential.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_arguments():
    """
    Valid initialization arguments: both a default-constructed Exponential and
    one built from an explicit list of rates with a float32 dtype are
    Distribution instances.
    """
    e = msd.Exponential()
    assert isinstance(e, msd.Distribution)
    e = msd.Exponential([0.1, 0.3, 0.5, 1.0], dtype=dtype.float32)
    assert isinstance(e, msd.Distribution)


def test_type():
    """
    An integer dtype (int32) is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Exponential([0.1], dtype=dtype.int32)


def test_name():
    """
    A float passed as `name` is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Exponential([0.1], name=1.0)


def test_seed():
    """
    A string passed as `seed` is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Exponential([0.1], seed='seed')


def test_rate():
    """
    Invalid rate: both a negative rate and a zero rate raise ValueError.
    """
    with pytest.raises(ValueError):
        msd.Exponential([-0.1], dtype=dtype.float32)
    with pytest.raises(ValueError):
        msd.Exponential([0.0], dtype=dtype.float32)


class ExponentialProb(nn.Cell):
    """
    Exponential distribution initialized with rate 0.5.

    `construct` evaluates prob, log_prob, cdf, log_cdf, survival_function and
    log_survival at `value` and returns their element-wise sum.
    """
    def __init__(self):
        super(ExponentialProb, self).__init__()
        self.e = msd.Exponential(0.5, dtype=dtype.float32)

    def construct(self, value):
        prob = self.e.prob(value)
        log_prob = self.e.log_prob(value)
        cdf = self.e.cdf(value)
        log_cdf = self.e.log_cdf(value)
        sf = self.e.survival_function(value)
        log_sf = self.e.log_survival(value)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_exponential_prob():
    """
    Probability functions of a distribution with a fixed rate, evaluated by
    passing only `value` through construct.
    """
    net = ExponentialProb()
    value = Tensor([0.2, 0.3, 5.0, 2, 3.9], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class ExponentialProb1(nn.Cell):
    """
    Exponential distribution created without a rate.

    The rate is therefore supplied on every call. `construct` returns the
    element-wise sum of the same six probability functions as ExponentialProb.
    """
    def __init__(self):
        super(ExponentialProb1, self).__init__()
        self.e = msd.Exponential(dtype=dtype.float32)

    def construct(self, value, rate):
        prob = self.e.prob(value, rate)
        log_prob = self.e.log_prob(value, rate)
        cdf = self.e.cdf(value, rate)
        log_cdf = self.e.log_cdf(value, rate)
        sf = self.e.survival_function(value, rate)
        log_sf = self.e.log_survival(value, rate)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_exponential_prob1():
    """
    Probability functions evaluated by passing both `value` and `rate`
    through construct.
    """
    net = ExponentialProb1()
    value = Tensor([0.2, 0.9, 1, 2, 3], dtype=dtype.float32)
    rate = Tensor([0.5], dtype=dtype.float32)
    ans = net(value, rate)
    assert isinstance(ans, Tensor)


class ExponentialKl(nn.Cell):
    """
    Test class: kl_loss between Exponential distributions.

    `e1` carries its own rate (0.7), so only the other distribution's rate is
    passed. `e2` has no rate, so `rate_a` is passed as well; presumably it
    supplies e2's own rate.
    """
    def __init__(self):
        super(ExponentialKl, self).__init__()
        self.e1 = msd.Exponential(0.7, dtype=dtype.float32)
        self.e2 = msd.Exponential(dtype=dtype.float32)

    def construct(self, rate_b, rate_a):
        kl1 = self.e1.kl_loss('Exponential', rate_b)
        kl2 = self.e2.kl_loss('Exponential', rate_b, rate_a)
        return kl1 + kl2


def test_kl():
    """
    kl_loss for a distribution with a fixed rate and one without a rate.
    """
    net = ExponentialKl()
    rate_b = Tensor([0.3], dtype=dtype.float32)
    rate_a = Tensor([0.7], dtype=dtype.float32)
    ans = net(rate_b, rate_a)
    assert isinstance(ans, Tensor)


class ExponentialCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Exponential distributions.

    Mirrors ExponentialKl: `e1` has rate 0.3; `e2` has no rate and
    additionally receives `rate_a`.
    """
    def __init__(self):
        super(ExponentialCrossEntropy, self).__init__()
        self.e1 = msd.Exponential(0.3, dtype=dtype.float32)
        self.e2 = msd.Exponential(dtype=dtype.float32)

    def construct(self, rate_b, rate_a):
        h1 = self.e1.cross_entropy('Exponential', rate_b)
        h2 = self.e2.cross_entropy('Exponential', rate_b, rate_a)
        return h1 + h2


def test_cross_entropy():
    """
    Test cross_entropy between Exponential distributions.
    """
    net = ExponentialCrossEntropy()
    rate_b = Tensor([0.3], dtype=dtype.float32)
    rate_a = Tensor([0.7], dtype=dtype.float32)
    ans = net(rate_b, rate_a)
    assert isinstance(ans, Tensor)


class ExponentialBasics(nn.Cell):
    """
    Test class: mean/sd/var/mode/entropy of an Exponential distribution with
    rates [0.3, 0.5]; `construct` returns their element-wise sum.
    """
    def __init__(self):
        super(ExponentialBasics, self).__init__()
        self.e = msd.Exponential([0.3, 0.5], dtype=dtype.float32)

    def construct(self):
        mean = self.e.mean()
        sd = self.e.sd()
        var = self.e.var()
        mode = self.e.mode()
        entropy = self.e.entropy()
        return mean + sd + var + mode + entropy


def test_bascis():
    """
    Test mean/sd/var/mode/entropy functionality of Exponential distribution.
    """
    net = ExponentialBasics()
    ans = net()
    assert isinstance(ans, Tensor)


class ExpConstruct(nn.Cell):
    """
    Exponential distribution: calling the distribution object directly with
    a function name ('prob') instead of calling the method.

    Covers three cases:
    - a fixed-rate distribution without an explicit rate;
    - the same distribution with a rate override;
    - a distribution without a rate, given the rate explicitly.
    """
    def __init__(self):
        super(ExpConstruct, self).__init__()
        self.e = msd.Exponential(0.5, dtype=dtype.float32)
        self.e1 = msd.Exponential(dtype=dtype.float32)

    def construct(self, value, rate):
        prob = self.e('prob', value)
        prob1 = self.e('prob', value, rate)
        prob2 = self.e1('prob', value, rate)
        return prob + prob1 + prob2


def test_exp_construct():
    """
    Test probability function going through construct.
    """
    net = ExpConstruct()
    # All-zero values sit at the lower edge of the Exponential support.
    value = Tensor([0, 0, 0, 0, 0], dtype=dtype.float32)
    # This tensor is forwarded as the `rate` argument, so name it accordingly.
    rate = Tensor([0.5], dtype=dtype.float32)
    ans = net(value, rate)
    assert isinstance(ans, Tensor)