• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Test nn.probability.distribution.logistic.
17"""
18import pytest
19
20import mindspore.nn as nn
21import mindspore.nn.probability.distribution as msd
22from mindspore import dtype
23from mindspore import Tensor
24
def test_logistic_shape_errpr():
    """
    Constructing Logistic from arguments of shape (2, 1) and (3, 1) is
    expected to raise ValueError.
    """
    first_arg = [[2.], [1.]]
    second_arg = [[2.], [3.], [4.]]
    with pytest.raises(ValueError):
        msd.Logistic(first_arg, second_arg, dtype=dtype.float32)
31
def test_type():
    """
    Constructing Logistic with dtype int32 is expected to raise TypeError.
    """
    requested_dtype = dtype.int32
    with pytest.raises(TypeError):
        msd.Logistic(0., 1., dtype=requested_dtype)
35
def test_name():
    """
    Constructing Logistic with a float as name is expected to raise TypeError.
    """
    bad_name = 1.0
    with pytest.raises(TypeError):
        msd.Logistic(0., 1., name=bad_name)
39
def test_seed():
    """
    Constructing Logistic with a string as seed is expected to raise TypeError.
    """
    bad_seed = 'seed'
    with pytest.raises(TypeError):
        msd.Logistic(0., 1., seed=bad_seed)
43
def test_scale():
    """
    A zero or negative second argument is expected to raise ValueError.
    """
    # Checked in the same order as before: zero first, then negative.
    for bad_scale in (0., -1.):
        with pytest.raises(ValueError):
            msd.Logistic(0., bad_scale)
49
def test_arguments():
    """
    Logistic can be built both without arguments and with explicit ones;
    either way the result must be a msd.Distribution.
    """
    default_dist = msd.Logistic()
    assert isinstance(default_dist, msd.Distribution)
    explicit_dist = msd.Logistic([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(explicit_dist, msd.Distribution)
58
59
class LogisticProb(nn.Cell):
    """
    Cell wrapping a Logistic distribution built with loc/scale given at
    initialization; construct sums all probability-type outputs.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32)

    def construct(self, value):
        # Evaluate every probability-type function on the same input.
        pdf = self.logistic.prob(value)
        log_pdf = self.logistic.log_prob(value)
        cumulative = self.logistic.cdf(value)
        log_cumulative = self.logistic.log_cdf(value)
        survival = self.logistic.survival_function(value)
        log_survival = self.logistic.log_survival(value)
        # Summation order is kept unchanged.
        return pdf + log_pdf + cumulative + log_cumulative + survival + log_survival
76
def test_logistic_prob():
    """
    Run LogisticProb on a float32 tensor and check that a Tensor comes back.
    """
    sample = Tensor([0.5, 1.0], dtype=dtype.float32)
    cell = LogisticProb()
    result = cell(sample)
    assert isinstance(result, Tensor)
85
86
class LogisticProb1(nn.Cell):
    """
    Cell wrapping a Logistic distribution built without loc/scale; they are
    supplied to every probability-type function inside construct.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic()

    def construct(self, value, mu, s):
        # Every call receives the same (value, mu, s) triple.
        pdf = self.logistic.prob(value, mu, s)
        log_pdf = self.logistic.log_prob(value, mu, s)
        cumulative = self.logistic.cdf(value, mu, s)
        log_cumulative = self.logistic.log_cdf(value, mu, s)
        survival = self.logistic.survival_function(value, mu, s)
        log_survival = self.logistic.log_survival(value, mu, s)
        # Summation order is kept unchanged.
        return pdf + log_pdf + cumulative + log_cumulative + survival + log_survival
103
def test_logistic_prob1():
    """
    Run LogisticProb1 with value, mu and s tensors and check that a Tensor
    comes back.
    """
    sample = Tensor([0.5, 1.0], dtype=dtype.float32)
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    cell = LogisticProb1()
    result = cell(sample, mu, s)
    assert isinstance(result, Tensor)
114
class KL(nn.Cell):
    """
    Cell calling kl_loss against 'Logistic'. Should raise NotImplementedError.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0)

    def construct(self, mu, s):
        return self.logistic.kl_loss('Logistic', mu, s)
126
class Crossentropy(nn.Cell):
    """
    Cell calling cross_entropy against 'Logistic'. Should raise
    NotImplementedError.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0)

    def construct(self, mu, s):
        return self.logistic.cross_entropy('Logistic', mu, s)
138
139
class LogisticBasics(nn.Cell):
    """
    Cell summing the mean, sd, mode and entropy of a Logistic distribution
    built with loc/scale given at initialization.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        dist_mean = self.logistic.mean()
        dist_sd = self.logistic.sd()
        dist_mode = self.logistic.mode()
        dist_entropy = self.logistic.entropy()
        # Summation order is kept unchanged.
        return dist_mean + dist_sd + dist_mode + dist_entropy
154
def test_bascis():
    """
    Test mean/sd/mode/entropy functionality of logistic.

    Also checks that calling the KL and Crossentropy cells raises
    NotImplementedError.
    """
    net = LogisticBasics()
    ans = net()
    assert isinstance(ans, Tensor)
    mu = Tensor(1.0, dtype=dtype.float32)
    s = Tensor(1.0, dtype=dtype.float32)
    # Build each cell outside pytest.raises so that only the forward call can
    # satisfy the expected NotImplementedError; an error raised while merely
    # constructing the cell must not make the test pass.
    kl = KL()
    with pytest.raises(NotImplementedError):
        kl(mu, s)
    crossentropy = Crossentropy()
    with pytest.raises(NotImplementedError):
        crossentropy(mu, s)
170
class LogisticConstruct(nn.Cell):
    """
    Cell that invokes 'prob' by calling the distribution objects directly,
    once on a distribution built with loc/scale and once on one built without.
    """
    def __init__(self):
        super().__init__()
        self.logistic = msd.Logistic(3.0, 4.0)
        self.logistic1 = msd.Logistic()

    def construct(self, value, mu, s):
        # Distribution built with arguments, called with value only.
        from_init_args = self.logistic('prob', value)
        # Same distribution, with mu and s passed at call time.
        from_call_args = self.logistic('prob', value, mu, s)
        # Distribution built without arguments, with mu and s passed at call time.
        from_bare_dist = self.logistic1('prob', value, mu, s)
        return from_init_args + from_call_args + from_bare_dist
185
def test_logistic_construct():
    """
    Run LogisticConstruct with value, mu and s tensors and check that a
    Tensor comes back.
    """
    sample = Tensor([0.5, 1.0], dtype=dtype.float32)
    mu = Tensor([0.0], dtype=dtype.float32)
    s = Tensor([1.0], dtype=dtype.float32)
    cell = LogisticConstruct()
    result = cell(sample, mu, s)
    assert isinstance(result, Tensor)
196