# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.cauchy.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_cauchy_shape_errpr():
    """
    Invalid shapes: loc of shape (2, 1) and scale of shape (3, 1) must be
    rejected with ValueError.
    """
    with pytest.raises(ValueError):
        msd.Cauchy([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)


def test_type():
    """
    An integer dtype (int32) is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Cauchy(0., 1., dtype=dtype.int32)


def test_name():
    """
    A float passed as `name` is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Cauchy(0., 1., name=1.0)


def test_seed():
    """
    A string passed as `seed` is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Cauchy(0., 1., seed='seed')


def test_scale():
    """
    A zero or negative scale is rejected with ValueError.
    """
    with pytest.raises(ValueError):
        msd.Cauchy(0., 0.)
    with pytest.raises(ValueError):
        msd.Cauchy(0., -1.)


def test_arguments():
    """
    args passing during initialization.
    """
    # Without loc/scale: parameters are supplied later at call time.
    cauchy = msd.Cauchy()
    assert isinstance(cauchy, msd.Distribution)
    # With explicit loc/scale and dtype.
    cauchy = msd.Cauchy([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(cauchy, msd.Distribution)


class CauchyProb(nn.Cell):
    """
    Cauchy distribution: initialize with loc/scale.

    `construct` evaluates prob, log_prob, cdf, log_cdf, survival_function and
    log_survival at `value` and returns their elementwise sum.
    """
    def __init__(self):
        super(CauchyProb, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self, value):
        prob = self.cauchy.prob(value)
        log_prob = self.cauchy.log_prob(value)
        cdf = self.cauchy.cdf(value)
        log_cdf = self.cauchy.log_cdf(value)
        sf = self.cauchy.survival_function(value)
        log_sf = self.cauchy.log_survival(value)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_cauchy_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = CauchyProb()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class CauchyProb1(nn.Cell):
    """
    Cauchy distribution: initialize without loc/scale.

    `construct` passes loc (`mu`) and scale (`s`) explicitly to every
    probability function and returns the elementwise sum of the results.
    """
    def __init__(self):
        super(CauchyProb1, self).__init__()
        self.cauchy = msd.Cauchy()

    def construct(self, value, mu, s):
        prob = self.cauchy.prob(value, mu, s)
        log_prob = self.cauchy.log_prob(value, mu, s)
        cdf = self.cauchy.cdf(value, mu, s)
        log_cdf = self.cauchy.log_cdf(value, mu, s)
        sf = self.cauchy.survival_function(value, mu, s)
        log_sf = self.cauchy.log_survival(value, mu, s)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_cauchy_prob1():
    """
    Test probability functions: passing loc/scale, value through construct.
    """
    net = CauchyProb1()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, mu, s)
    assert isinstance(ans, Tensor)


class KL(nn.Cell):
    """
    Test kl_loss and cross entropy.

    Covers both a distribution built with fixed loc/scale (`self.cauchy`)
    and one built without them (`self.cauchy1`).
    """
    def __init__(self):
        super(KL, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0)
        self.cauchy1 = msd.Cauchy()

    def construct(self, mu, s, mu_a, s_a):
        kl = self.cauchy.kl_loss('Cauchy', mu, s)
        kl1 = self.cauchy1.kl_loss('Cauchy', mu, s, mu_a, s_a)
        cross_entropy = self.cauchy.cross_entropy('Cauchy', mu, s)
        # NOTE(review): unlike kl1, this calls self.cauchy (fixed loc/scale)
        # with explicit mu_a/s_a -- presumably to exercise overriding the
        # fixed parameters; confirm this is intended rather than self.cauchy1.
        cross_entropy1 = self.cauchy.cross_entropy('Cauchy', mu, s, mu_a, s_a)
        return kl + kl1 + cross_entropy + cross_entropy1


def test_kl_cross_entropy():
    """
    Test kl_loss and cross_entropy.
    """
    net = KL()
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    mu_a = Tensor([0.0], dtype=dtype.float32)
    s_a = Tensor([1.0], dtype=dtype.float32)
    ans = net(mu, s, mu_a, s_a)
    assert isinstance(ans, Tensor)


class CauchyBasics(nn.Cell):
    """
    Test class: sum of mode() and entropy() of Cauchy(3.0, 4.0).
    """
    def __init__(self):
        super(CauchyBasics, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        mode = self.cauchy.mode()
        entropy = self.cauchy.entropy()
        return mode + entropy


class CauchyMean(nn.Cell):
    """
    Test class: calls mean() of Cauchy(3.0, 4.0); test_bascis expects
    calling this cell to raise ValueError.
    """
    def __init__(self):
        super(CauchyMean, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        return self.cauchy.mean()


class CauchyVar(nn.Cell):
    """
    Test class: calls var() of Cauchy(3.0, 4.0); test_bascis expects
    calling this cell to raise ValueError.
    """
    def __init__(self):
        super(CauchyVar, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        return self.cauchy.var()


class CauchySd(nn.Cell):
    """
    Test class: calls sd() of Cauchy(3.0, 4.0); test_bascis expects
    calling this cell to raise ValueError.
    """
    def __init__(self):
        super(CauchySd, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        return self.cauchy.sd()


def test_bascis():
    """
    Test mean/sd/var/mode/entropy functionality of Cauchy.

    mode/entropy must produce a Tensor. The Cauchy distribution has no finite
    mean or variance, so calling mean/var/sd is expected to raise ValueError.
    """
    net = CauchyBasics()
    ans = net()
    assert isinstance(ans, Tensor)
    for cell_cls in (CauchyMean, CauchyVar, CauchySd):
        # Build the cell outside pytest.raises so that only the call itself
        # can satisfy the expected ValueError, not a failure in __init__.
        net = cell_cls()
        with pytest.raises(ValueError):
            net()


class CauchyConstruct(nn.Cell):
    """
    Cauchy distribution: going through construct.

    Calls 'prob' via the distribution's __call__ interface, both with the
    fixed loc/scale and with explicitly passed mu/s.
    """
    def __init__(self):
        super(CauchyConstruct, self).__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0)
        self.cauchy1 = msd.Cauchy()

    def construct(self, value, mu, s):
        prob = self.cauchy('prob', value)
        prob1 = self.cauchy('prob', value, mu, s)
        prob2 = self.cauchy1('prob', value, mu, s)
        return prob + prob1 + prob2


def test_cauchy_construct():
    """
    Test probability function going through construct.
    """
    net = CauchyConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, mu, s)
    assert isinstance(ans, Tensor)