• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.cauchy.
"""
18import pytest
19
20import mindspore.nn as nn
21import mindspore.nn.probability.distribution as msd
22from mindspore import dtype
23from mindspore import Tensor
24
def test_cauchy_shape_errpr():
    """
    Invalid shapes: loc and scale whose shapes cannot be broadcast together
    are expected to be rejected with a ValueError when the distribution is built.
    """
    loc = [[2.], [1.]]            # shape (2, 1)
    scale = [[2.], [3.], [4.]]    # shape (3, 1): axis 0 differs (2 vs 3) and neither is 1
    with pytest.raises(ValueError):
        msd.Cauchy(loc, scale, dtype=dtype.float32)
31
def test_type():
    """
    Invalid dtype: building a Cauchy distribution with an integer dtype
    is expected to raise TypeError.
    """
    int_dtype = dtype.int32
    with pytest.raises(TypeError):
        msd.Cauchy(0., 1., dtype=int_dtype)
35
def test_name():
    """
    Invalid name: passing a float as ``name`` is expected to raise TypeError.
    """
    bad_name = 1.0
    with pytest.raises(TypeError):
        msd.Cauchy(0., 1., name=bad_name)
39
def test_seed():
    """
    Invalid seed: passing a string as ``seed`` is expected to raise TypeError.
    """
    bad_seed = 'seed'
    with pytest.raises(TypeError):
        msd.Cauchy(0., 1., seed=bad_seed)
43
def test_scale():
    """
    Invalid scale: a zero scale and a negative scale are each expected
    to raise ValueError.
    """
    # Each bad value gets its own raises-context, so both must fail independently.
    for bad_scale in (0., -1.):
        with pytest.raises(ValueError):
            msd.Cauchy(0., bad_scale)
49
def test_arguments():
    """
    Args passing during initialization: a Cauchy distribution can be built
    with no loc/scale, or with explicit list-valued loc/scale and a dtype,
    and both are Distribution instances.
    """
    default_cauchy = msd.Cauchy()
    assert isinstance(default_cauchy, msd.Distribution)

    explicit_cauchy = msd.Cauchy([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(explicit_cauchy, msd.Distribution)
58
59
class CauchyProb(nn.Cell):
    """
    Cauchy distribution: initialize with loc/scale.

    loc=3.0 and scale=4.0 are fixed at construction, so ``construct`` only needs
    ``value``. It evaluates all six probability functions on it and returns
    their elementwise sum.
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self, value):
        # Density and its log.
        density = self.cauchy.prob(value)
        log_density = self.cauchy.log_prob(value)
        # Cumulative distribution function and its log.
        cumulative = self.cauchy.cdf(value)
        log_cumulative = self.cauchy.log_cdf(value)
        # Survival function and its log.
        survival = self.cauchy.survival_function(value)
        log_survival = self.cauchy.log_survival(value)
        return (density + log_density + cumulative + log_cumulative
                + survival + log_survival)
76
def test_cauchy_prob():
    """
    Test probability functions: passing value through construct.

    Only checks that the network runs and yields a Tensor.
    """
    values = Tensor([0.5, 1.0], dtype=dtype.float32)
    result = CauchyProb()(values)
    assert isinstance(result, Tensor)
85
86
class CauchyProb1(nn.Cell):
    """
    Cauchy distribution: initialize without loc/scale.

    The distribution is built with no loc/scale, so ``construct`` passes ``mu``
    and ``s`` to every probability function along with ``value``. It returns
    the elementwise sum of the six results.
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy()

    def construct(self, value, mu, s):
        # Density and its log.
        density = self.cauchy.prob(value, mu, s)
        log_density = self.cauchy.log_prob(value, mu, s)
        # Cumulative distribution function and its log.
        cumulative = self.cauchy.cdf(value, mu, s)
        log_cumulative = self.cauchy.log_cdf(value, mu, s)
        # Survival function and its log.
        survival = self.cauchy.survival_function(value, mu, s)
        log_survival = self.cauchy.log_survival(value, mu, s)
        return (density + log_density + cumulative + log_cumulative
                + survival + log_survival)
103
def test_cauchy_prob1():
    """
    Test probability functions: passing loc/scale, value through construct.

    Uses loc=0 and scale=1 given as one-element tensors, and only checks that
    the network runs and yields a Tensor.
    """
    values = Tensor([0.5, 1.0], dtype=dtype.float32)
    loc = Tensor([0.0], dtype=dtype.float32)
    scale = Tensor([1.0], dtype=dtype.float32)
    result = CauchyProb1()(values, loc, scale)
    assert isinstance(result, Tensor)
114
class KL(nn.Cell):
    """
    Test kl_loss and cross entropy.

    ``self.cauchy`` stores loc/scale 3.0/4.0. ``self.cauchy1`` is built without
    them, so its call below passes ``mu_a``/``s_a`` explicitly. ``construct``
    returns the sum of two KL losses and two cross entropies.
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0)
        self.cauchy1 = msd.Cauchy()

    def construct(self, mu, s, mu_a, s_a):
        # Uses the loc/scale stored on self.cauchy.
        kl_stored = self.cauchy.kl_loss('Cauchy', mu, s)
        # self.cauchy1 has nothing stored, so mu_a/s_a are supplied here.
        kl_supplied = self.cauchy1.kl_loss('Cauchy', mu, s, mu_a, s_a)
        ce_stored = self.cauchy.cross_entropy('Cauchy', mu, s)
        # NOTE(review): mu_a/s_a go to self.cauchy (not self.cauchy1), presumably to
        # exercise overriding the stored loc/scale -- confirm this is intended.
        ce_override = self.cauchy.cross_entropy('Cauchy', mu, s, mu_a, s_a)
        return kl_stored + kl_supplied + ce_stored + ce_override
130
def test_kl_cross_entropy():
    """
    Test kl_loss and cross_entropy.

    Feeds all-zero locs and unit scales, given as one-element float32 tensors,
    and only checks that the network yields a Tensor.
    """
    def as_tensor(number):
        return Tensor([number], dtype=dtype.float32)

    net = KL()
    result = net(as_tensor(0.0), as_tensor(1.0), as_tensor(0.0), as_tensor(1.0))
    assert isinstance(result, Tensor)
142
143
class CauchyBasics(nn.Cell):
    """
    Test class: mode and entropy of a Cauchy(3.0, 4.0), returned as their sum.
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        return self.cauchy.mode() + self.cauchy.entropy()
156
class CauchyMean(nn.Cell):
    """
    Test class: queries the mean of a Cauchy(3.0, 4.0).

    test_bascis expects calling this network to raise ValueError
    (presumably because the Cauchy mean is undefined).
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        mean = self.cauchy.mean()
        return mean
167
class CauchyVar(nn.Cell):
    """
    Test class: queries the variance of a Cauchy(3.0, 4.0).

    test_bascis expects calling this network to raise ValueError
    (presumably because the Cauchy variance is undefined).
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        variance = self.cauchy.var()
        return variance
178
class CauchySd(nn.Cell):
    """
    Test class: queries the standard deviation of a Cauchy(3.0, 4.0).

    test_bascis expects calling this network to raise ValueError
    (presumably because the Cauchy standard deviation is undefined).
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0, dtype=dtype.float32)

    def construct(self):
        std_dev = self.cauchy.sd()
        return std_dev
189
def test_bascis():
    """
    Test mode/entropy and the undefined moments (mean/var/sd) of Cauchy.

    mode() and entropy() must produce a Tensor, while calling mean(), var()
    and sd() must raise ValueError.
    """
    ans = CauchyBasics()()
    assert isinstance(ans, Tensor)

    # Build each network OUTSIDE pytest.raises so only the call is expected to
    # fail. A ValueError raised during construction would otherwise be mistaken
    # for the expected failure and hide a real bug.
    mean_net = CauchyMean()
    with pytest.raises(ValueError):
        mean_net()

    var_net = CauchyVar()
    with pytest.raises(ValueError):
        var_net()

    sd_net = CauchySd()
    with pytest.raises(ValueError):
        sd_net()
206
class CauchyConstruct(nn.Cell):
    """
    Cauchy distribution: going through construct.

    Calls each distribution object directly with the method name 'prob' as the
    first argument, in three ways: with the stored loc/scale, with explicit
    mu/s on the parameterized distribution, and with mu/s on the
    parameterless one. Returns the sum of the three results.
    """
    def __init__(self):
        super().__init__()
        self.cauchy = msd.Cauchy(3.0, 4.0)
        self.cauchy1 = msd.Cauchy()

    def construct(self, value, mu, s):
        # Stored loc/scale (3.0, 4.0).
        prob_stored = self.cauchy('prob', value)
        # Explicit mu/s, presumably overriding the stored loc/scale.
        prob_override = self.cauchy('prob', value, mu, s)
        # Nothing stored on cauchy1; mu/s are supplied at call time.
        prob_supplied = self.cauchy1('prob', value, mu, s)
        return prob_stored + prob_override + prob_supplied
221
def test_cauchy_construct():
    """
    Test probability function going through construct.

    Uses loc=0 and scale=1 given as one-element tensors, and only checks that
    the network yields a Tensor.
    """
    values = Tensor([0.5, 1.0], dtype=dtype.float32)
    loc = Tensor([0.0], dtype=dtype.float32)
    scale = Tensor([1.0], dtype=dtype.float32)
    result = CauchyConstruct()(values, loc, scale)
    assert isinstance(result, Tensor)
232