# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test cases for Logistic distribution"""
import numpy as np
from scipy import stats
import mindspore.context as context
import mindspore.nn as nn
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor
from mindspore import dtype

# All tests in this module run in graph mode on an Ascend device.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Distribution parameters shared by the MindSpore Cells under test and the
# SciPy reference, so the two can never drift apart. Kept as immutable tuples
# and converted to a fresh numpy array at every use (as the original inline
# literals were). loc has shape (1,), scale has shape (2, 1).
_LOC = (3.0,)
_SCALE = ((2.0,), (4.0,))
# Basics uses a 1-D scale instead, giving a batch shape of (2,).
_BASICS_SCALE = (2.0, 4.0)
# Points at which pdf/cdf/sf (and their logs) are evaluated.
_X = (1.0, 2.0)


def _logistic(scale=_SCALE, **kwargs):
    """Build a float32 ``msd.Logistic`` with loc ``_LOC`` and the given scale.

    Args:
        scale: nested tuple of scale values, converted with ``np.array``.
        **kwargs: forwarded to ``msd.Logistic`` (e.g. ``seed``); only what the
            caller passes is forwarded, so library defaults stay in effect.
    """
    return msd.Logistic(np.array(_LOC), np.array(scale),
                        dtype=dtype.float32, **kwargs)


def _benchmark():
    """Return the SciPy reference distribution with loc ``_LOC``, scale ``_SCALE``."""
    return stats.logistic(np.array(_LOC), np.array(_SCALE))


def _x_tensor():
    """Return the evaluation points ``_X`` as a float32 MindSpore Tensor."""
    return Tensor(list(_X), dtype=dtype.float32)


def _assert_close(actual, expected, tol):
    """Assert every element of ``actual`` is strictly within ``tol`` of ``expected``.

    Equivalent to ``(np.abs(actual - expected) < tol).all()``, but on failure the
    message reports the largest absolute deviation observed.
    """
    err = np.abs(actual - np.asarray(expected))
    assert (err < tol).all(), f"max abs error {err.max()} not below tolerance {tol}"


class Prob(nn.Cell):
    """
    Test class: probability of Logistic distribution.
    """
    def __init__(self):
        super(Prob, self).__init__()
        self.l = _logistic()

    def construct(self, x_):
        return self.l.prob(x_)


def test_pdf():
    """
    Test pdf against scipy.stats.logistic.pdf.
    """
    expect_pdf = _benchmark().pdf(list(_X)).astype(np.float32)
    pdf = Prob()
    output = pdf(_x_tensor())
    _assert_close(output.asnumpy(), expect_pdf, tol=1e-6)


class LogProb(nn.Cell):
    """
    Test class: log probability of Logistic distribution.
    """
    def __init__(self):
        super(LogProb, self).__init__()
        self.l = _logistic()

    def construct(self, x_):
        return self.l.log_prob(x_)


def test_log_likelihood():
    """
    Test log_pdf against scipy.stats.logistic.logpdf.
    """
    expect_logpdf = _benchmark().logpdf(list(_X)).astype(np.float32)
    logprob = LogProb()
    output = logprob(_x_tensor())
    _assert_close(output.asnumpy(), expect_logpdf, tol=1e-6)


class Basics(nn.Cell):
    """
    Test class: mean/sd/mode of Logistic distribution.
    """
    def __init__(self):
        super(Basics, self).__init__()
        self.l = _logistic(scale=_BASICS_SCALE)

    def construct(self):
        return self.l.mean(), self.l.sd(), self.l.mode()


def test_basics():
    """
    Test mean/standard deviation/mode.
    """
    basics = Basics()
    mean, sd, mode = basics()
    # The test expects mean and mode to both equal loc, broadcast over the batch.
    expect_mean = [3.0, 3.0]
    # Logistic standard deviation: scale * pi / sqrt(3).
    expect_sd = np.pi * np.array(_BASICS_SCALE) / np.sqrt(np.array([3.0]))
    tol = 1e-6
    _assert_close(mean.asnumpy(), expect_mean, tol)
    _assert_close(mode.asnumpy(), expect_mean, tol)
    _assert_close(sd.asnumpy(), expect_sd, tol)


class Sampling(nn.Cell):
    """
    Test class: sample of Logistic distribution.
    """
    def __init__(self, shape, seed=0):
        super(Sampling, self).__init__()
        self.l = _logistic(seed=seed)
        self.shape = shape

    def construct(self, mean=None, sd=None):
        # NOTE(review): for Logistic these positions are presumably loc/scale
        # overrides; the names mean/sd are kept so keyword callers still work.
        return self.l.sample(self.shape, mean, sd)


def test_sample():
    """
    Test that sample() returns the requested shape followed by the batch shape.
    """
    shape = (2, 3)
    seed = 10
    mean = Tensor([2.0], dtype=dtype.float32)
    sd = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    sample = Sampling(shape, seed=seed)
    output = sample(mean, sd)
    # Overrides of shapes (1,) and (3,) broadcast to (3,), appended to (2, 3).
    assert output.shape == (2, 3, 3)


class CDF(nn.Cell):
    """
    Test class: cdf of Logistic distribution.
    """
    def __init__(self):
        super(CDF, self).__init__()
        self.l = _logistic()

    def construct(self, x_):
        return self.l.cdf(x_)


def test_cdf():
    """
    Test cdf against scipy.stats.logistic.cdf.
    """
    expect_cdf = _benchmark().cdf(list(_X)).astype(np.float32)
    cdf = CDF()
    output = cdf(_x_tensor())
    _assert_close(output.asnumpy(), expect_cdf, tol=2e-5)


class LogCDF(nn.Cell):
    """
    Test class: log_cdf of Logistic distribution.
    """
    def __init__(self):
        super(LogCDF, self).__init__()
        self.l = _logistic()

    def construct(self, x_):
        return self.l.log_cdf(x_)


def test_log_cdf():
    """
    Test log cdf against scipy.stats.logistic.logcdf.
    """
    expect_logcdf = _benchmark().logcdf(list(_X)).astype(np.float32)
    logcdf = LogCDF()
    output = logcdf(_x_tensor())
    _assert_close(output.asnumpy(), expect_logcdf, tol=5e-5)


class SF(nn.Cell):
    """
    Test class: survival function of Logistic distribution.
    """
    def __init__(self):
        super(SF, self).__init__()
        self.l = _logistic()

    def construct(self, x_):
        return self.l.survival_function(x_)


def test_survival():
    """
    Test survival function against scipy.stats.logistic.sf.
    """
    expect_survival = _benchmark().sf(list(_X)).astype(np.float32)
    survival_function = SF()
    output = survival_function(_x_tensor())
    _assert_close(output.asnumpy(), expect_survival, tol=2e-5)


class LogSF(nn.Cell):
    """
    Test class: log survival function of Logistic distribution.
    """
    def __init__(self):
        super(LogSF, self).__init__()
        self.l = _logistic()

    def construct(self, x_):
        return self.l.log_survival(x_)


def test_log_survival():
    """
    Test log_survival against scipy.stats.logistic.logsf.
    """
    expect_log_survival = _benchmark().logsf(list(_X)).astype(np.float32)
    log_survival = LogSF()
    output = log_survival(_x_tensor())
    _assert_close(output.asnumpy(), expect_log_survival, tol=2e-5)


class EntropyH(nn.Cell):
    """
    Test class: entropy of Logistic distribution.
    """
    def __init__(self):
        super(EntropyH, self).__init__()
        self.l = _logistic()

    def construct(self):
        return self.l.entropy()


def test_entropy():
    """
    Test entropy against scipy.stats.logistic.entropy.
    """
    expect_entropy = _benchmark().entropy().astype(np.float32)
    entropy = EntropyH()
    output = entropy()
    _assert_close(output.asnumpy(), expect_entropy, tol=1e-6)