• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""test cases for Logistic distribution"""
16import numpy as np
17from scipy import stats
18import mindspore.context as context
19import mindspore.nn as nn
20import mindspore.nn.probability.distribution as msd
21from mindspore import Tensor
22from mindspore import dtype
23
# Configure MindSpore once at import time: every test in this module runs in
# graph mode with "Ascend" as the device target.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
25
class Prob(nn.Cell):
    """
    Cell wrapper that evaluates the probability density of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        """Return `prob` of the wrapped distribution evaluated at `x_`."""
        density = self.l.prob(x_)
        return density
36
def test_pdf():
    """
    Compare `Prob` output with the scipy.stats.logistic pdf.
    """
    loc = np.array([3.0])
    scale = np.array([[2.0], [4.0]])
    reference = stats.logistic(loc, scale)
    expected = reference.pdf([1.0, 2.0]).astype(np.float32)
    net = Prob()
    actual = net(Tensor([1.0, 2.0], dtype=dtype.float32)).asnumpy()
    max_err = 1e-6
    assert np.all(np.abs(actual - expected) < max_err)
47
class LogProb(nn.Cell):
    """
    Cell wrapper that evaluates the log probability of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        """Return `log_prob` of the wrapped distribution evaluated at `x_`."""
        log_density = self.l.log_prob(x_)
        return log_density
58
def test_log_likelihood():
    """
    Compare `LogProb` output with the scipy.stats.logistic logpdf.
    """
    loc = np.array([3.0])
    scale = np.array([[2.0], [4.0]])
    reference = stats.logistic(loc, scale)
    expected = reference.logpdf([1.0, 2.0]).astype(np.float32)
    net = LogProb()
    actual = net(Tensor([1.0, 2.0], dtype=dtype.float32)).asnumpy()
    max_err = 1e-6
    assert np.all(np.abs(actual - expected) < max_err)
69
class Basics(nn.Cell):
    """
    Cell wrapper that returns mean, standard deviation and mode of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([2.0, 4.0])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self):
        """Return the tuple `(mean, sd, mode)` of the wrapped distribution."""
        mean = self.l.mean()
        sd = self.l.sd()
        mode = self.l.mode()
        return mean, sd, mode
80
def test_basics():
    """
    Check mean, mode and standard deviation returned by `Basics`.
    """
    net = Basics()
    mean, sd, mode = net()
    # Mean and mode are both compared with the same expected values.
    expected_center = [3.0, 3.0]
    scale = np.array([2.0, 4.0])
    expected_sd = np.pi * scale / np.sqrt(np.array([3.0]))
    max_err = 1e-6
    assert np.all(np.abs(mean.asnumpy() - expected_center) < max_err)
    assert np.all(np.abs(mode.asnumpy() - expected_center) < max_err)
    assert np.all(np.abs(sd.asnumpy() - expected_sd) < max_err)
93
class Sampling(nn.Cell):
    """
    Cell wrapper that draws samples from a Logistic distribution.
    """
    def __init__(self, shape, seed=0):
        """Store the sample shape and build the seeded Logistic distribution."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, seed=seed, dtype=dtype.float32)
        self.shape = shape

    def construct(self, mean=None, sd=None):
        """
        Sample with the stored shape; `mean` and `sd` are forwarded positionally
        as the second and third arguments of the distribution's `sample`.
        """
        draws = self.l.sample(self.shape, mean, sd)
        return draws
105
def test_sample():
    """
    Check the shape of the tensor returned by `Sampling`.
    """
    sample_shape = (2, 3)
    first_arg = Tensor([2.0], dtype=dtype.float32)
    second_arg = Tensor([2.0, 2.0, 2.0], dtype=dtype.float32)
    net = Sampling(sample_shape, seed=10)
    draws = net(first_arg, second_arg)
    # The expected shape is the requested (2, 3) followed by a trailing 3.
    assert draws.shape == (2, 3, 3)
117
class CDF(nn.Cell):
    """
    Cell wrapper that evaluates the cdf of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        """Return `cdf` of the wrapped distribution evaluated at `x_`."""
        cumulative = self.l.cdf(x_)
        return cumulative
128
129
def test_cdf():
    """
    Compare `CDF` output with the scipy.stats.logistic cdf.
    """
    loc = np.array([3.0])
    scale = np.array([[2.0], [4.0]])
    reference = stats.logistic(loc, scale)
    expected = reference.cdf([1.0, 2.0]).astype(np.float32)
    net = CDF()
    actual = net(Tensor([1.0, 2.0], dtype=dtype.float32)).asnumpy()
    max_err = 2e-5
    assert np.all(np.abs(actual - expected) < max_err)
140
class LogCDF(nn.Cell):
    """
    Cell wrapper that evaluates the log cdf of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        """Return `log_cdf` of the wrapped distribution evaluated at `x_`."""
        log_cumulative = self.l.log_cdf(x_)
        return log_cumulative
151
def test_log_cdf():
    """
    Compare `LogCDF` output with the scipy.stats.logistic logcdf.
    """
    loc = np.array([3.0])
    scale = np.array([[2.0], [4.0]])
    reference = stats.logistic(loc, scale)
    expected = reference.logcdf([1.0, 2.0]).astype(np.float32)
    net = LogCDF()
    actual = net(Tensor([1.0, 2.0], dtype=dtype.float32)).asnumpy()
    max_err = 5e-5
    assert np.all(np.abs(actual - expected) < max_err)
162
class SF(nn.Cell):
    """
    Cell wrapper that evaluates the survival function of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        """Return `survival_function` of the wrapped distribution evaluated at `x_`."""
        survival = self.l.survival_function(x_)
        return survival
173
def test_survival():
    """
    Test survival function.

    Compares the output of `SF` (which calls `survival_function`) with the
    scipy.stats.logistic `sf` values, using an absolute tolerance of 2e-5.
    """
    logistic_benchmark = stats.logistic(np.array([3.0]), np.array([[2.0], [4.0]]))
    expect_survival = logistic_benchmark.sf([1.0, 2.0]).astype(np.float32)
    survival_function = SF()
    output = survival_function(Tensor([1.0, 2.0], dtype=dtype.float32))
    tol = 2e-5
    assert (np.abs(output.asnumpy() - expect_survival) < tol).all()
184
class LogSF(nn.Cell):
    """
    Cell wrapper that evaluates the log survival function of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self, x_):
        """Return `log_survival` of the wrapped distribution evaluated at `x_`."""
        log_surv = self.l.log_survival(x_)
        return log_surv
195
def test_log_survival():
    """
    Compare `LogSF` output with the scipy.stats.logistic logsf.
    """
    loc = np.array([3.0])
    scale = np.array([[2.0], [4.0]])
    reference = stats.logistic(loc, scale)
    expected = reference.logsf([1.0, 2.0]).astype(np.float32)
    net = LogSF()
    actual = net(Tensor([1.0, 2.0], dtype=dtype.float32)).asnumpy()
    max_err = 2e-5
    assert np.all(np.abs(actual - expected) < max_err)
206
class EntropyH(nn.Cell):
    """
    Cell wrapper that evaluates the entropy of a Logistic distribution.
    """
    def __init__(self):
        """Build the Logistic distribution used by `construct`."""
        super().__init__()
        loc = np.array([3.0])
        scale = np.array([[2.0], [4.0]])
        self.l = msd.Logistic(loc, scale, dtype=dtype.float32)

    def construct(self):
        """Return `entropy` of the wrapped distribution."""
        result = self.l.entropy()
        return result
217
def test_entropy():
    """
    Compare `EntropyH` output with the scipy.stats.logistic entropy.
    """
    loc = np.array([3.0])
    scale = np.array([[2.0], [4.0]])
    reference = stats.logistic(loc, scale)
    expected = reference.entropy().astype(np.float32)
    net = EntropyH()
    actual = net().asnumpy()
    max_err = 1e-6
    assert np.all(np.abs(actual - expected) < max_err)
228