# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.logistic.
"""
import pytest

import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import dtype
from mindspore import Tensor


def test_logistic_shape_errpr():
    """
    Invalid shapes: loc of shape (2, 1) together with scale of shape (3, 1)
    is expected to raise ValueError at construction time.
    """
    with pytest.raises(ValueError):
        msd.Logistic([[2.], [1.]], [[2.], [3.], [4.]], dtype=dtype.float32)


def test_type():
    """
    An integer dtype is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Logistic(0., 1., dtype=dtype.int32)


def test_name():
    """
    A float passed as `name` is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Logistic(0., 1., name=1.0)


def test_seed():
    """
    A string passed as `seed` is rejected with TypeError.
    """
    with pytest.raises(TypeError):
        msd.Logistic(0., 1., seed='seed')


def test_scale():
    """
    A zero or negative scale is rejected with ValueError.
    """
    with pytest.raises(ValueError):
        msd.Logistic(0., 0.)
    with pytest.raises(ValueError):
        msd.Logistic(0., -1.)


def test_arguments():
    """
    args passing during initialization: both the no-argument form and the
    explicit loc/scale/dtype form produce a Distribution instance.
    """
    dist = msd.Logistic()
    assert isinstance(dist, msd.Distribution)
    dist = msd.Logistic([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(dist, msd.Distribution)


class LogisticProb(nn.Cell):
    """
    logistic distribution: initialize with loc/scale.

    construct(value) sums prob, log_prob, cdf, log_cdf, survival_function
    and log_survival evaluated at `value` with the distribution's own
    loc=3.0 / scale=4.0.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32)

    def construct(self, value):
        prob = self.logistic.prob(value)
        log_prob = self.logistic.log_prob(value)
        cdf = self.logistic.cdf(value)
        log_cdf = self.logistic.log_cdf(value)
        sf = self.logistic.survival_function(value)
        log_sf = self.logistic.log_survival(value)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_logistic_prob():
    """
    Test probability functions: passing value through construct.
    """
    net = LogisticProb()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    ans = net(value)
    assert isinstance(ans, Tensor)


class LogisticProb1(nn.Cell):
    """
    logistic distribution: initialize without loc/scale.

    construct(value, mu, s) sums the same six probability functions as
    LogisticProb, but supplies loc (`mu`) and scale (`s`) at call time.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic()

    def construct(self, value, mu, s):
        prob = self.logistic.prob(value, mu, s)
        log_prob = self.logistic.log_prob(value, mu, s)
        cdf = self.logistic.cdf(value, mu, s)
        log_cdf = self.logistic.log_cdf(value, mu, s)
        sf = self.logistic.survival_function(value, mu, s)
        log_sf = self.logistic.log_survival(value, mu, s)
        return prob + log_prob + cdf + log_cdf + sf + log_sf


def test_logistic_prob1():
    """
    Test probability functions: passing loc/scale, value through construct.
    """
    net = LogisticProb1()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, mu, s)
    assert isinstance(ans, Tensor)


class KL(nn.Cell):
    """
    Test kl_loss against another 'Logistic' with loc `mu` / scale `s`.
    Expected to raise NotImplementedError when called.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0)

    def construct(self, mu, s):
        kl = self.logistic.kl_loss('Logistic', mu, s)
        return kl


class Crossentropy(nn.Cell):
    """
    Test cross_entropy against another 'Logistic' with loc `mu` / scale `s`.
    Expected to raise NotImplementedError when called.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0)

    def construct(self, mu, s):
        cross_entropy = self.logistic.cross_entropy('Logistic', mu, s)
        return cross_entropy


class LogisticBasics(nn.Cell):
    """
    Test class: basic loc/scale function.

    construct() returns mean + sd + mode + entropy of Logistic(3.0, 4.0).
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        mean = self.logistic.mean()
        sd = self.logistic.sd()
        mode = self.logistic.mode()
        entropy = self.logistic.entropy()
        return mean + sd + mode + entropy


def test_bascis():
    """
    Test mean/sd/mode/entropy functionality of logistic, and that
    kl_loss / cross_entropy raise NotImplementedError.
    """
    net = LogisticBasics()
    ans = net()
    assert isinstance(ans, Tensor)
    mu = Tensor(1.0, dtype=dtype.float32)
    s = Tensor(1.0, dtype=dtype.float32)
    # Build the cells outside pytest.raises so that only the call under test
    # can satisfy the expected NotImplementedError (a failing constructor
    # would otherwise make the test pass for the wrong reason).
    kl = KL()
    with pytest.raises(NotImplementedError):
        kl(mu, s)
    crossentropy = Crossentropy()
    with pytest.raises(NotImplementedError):
        crossentropy(mu, s)


class LogisticConstruct(nn.Cell):
    """
    logistic distribution: going through construct.

    Calls the distribution objects directly with the method name 'prob',
    both with the stored loc/scale and with loc/scale passed at call time.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0)
        self.logistic1 = msd.Logistic()

    def construct(self, value, mu, s):
        prob = self.logistic('prob', value)
        prob1 = self.logistic('prob', value, mu, s)
        prob2 = self.logistic1('prob', value, mu, s)
        return prob + prob1 + prob2


def test_logistic_construct():
    """
    Test probability function going through construct.
    """
    net = LogisticConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, mu, s)
    assert isinstance(ans, Tensor)