• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Test nn.probability.distribution.Gamma.
17"""
18import numpy as np
19import pytest
20
21import mindspore.nn as nn
22import mindspore.nn.probability.distribution as msd
23from mindspore import dtype
24from mindspore import Tensor
25
def test_gamma_shape_errpr():
    """
    Invalid shapes: parameters that cannot be broadcast together.

    The first parameter has shape (2, 1) and the second has shape (3, 1);
    constructing a Gamma from them is expected to raise ValueError.
    """
    # NOTE(review): "errpr" looks like a typo for "error"; the name is kept so
    # the collected test id does not change.
    shape_2x1 = [[2.], [1.]]
    shape_3x1 = [[2.], [3.], [4.]]
    with pytest.raises(ValueError):
        msd.Gamma(shape_2x1, shape_3x1, dtype=dtype.float32)
32
def test_type():
    """
    An integer dtype is rejected.

    Constructing with ``dtype=dtype.int32`` is expected to raise TypeError.
    Both parameters are valid (positive), so the dtype is the only invalid
    argument. A first parameter of ``[0.]`` is itself rejected with ValueError.
    With it, this test would only pass because the dtype happens to be checked
    before the parameter values.
    """
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], dtype=dtype.int32)
36
def test_name():
    """
    A float ``name`` (1.0) is expected to raise TypeError.

    Valid (positive) parameters are used, so ``name`` is the only invalid
    argument. An invalid first parameter such as ``[0.]`` would make the
    outcome depend on the order of the constructor's checks.
    """
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], name=1.0)
40
def test_seed():
    """
    A string ``seed`` ('seed') is expected to raise TypeError.

    Valid (positive) parameters are used, so ``seed`` is the only invalid
    argument. An invalid first parameter such as ``[0.]`` would make the
    outcome depend on the order of the constructor's checks.
    """
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], seed='seed')
44
def test_concentration1():
    """
    The first Gamma parameter must be strictly positive.

    Both zero and a negative value are expected to raise ValueError at
    construction time (second parameter held at a valid 1.0).
    """
    for invalid in ([0.], [-1.]):
        with pytest.raises(ValueError):
            msd.Gamma(invalid, [1.])
50
def test_concentration0():
    """
    The second Gamma parameter must be strictly positive.

    Both zero and a negative value are expected to raise ValueError at
    construction time (first parameter held at a valid 1.0).
    """
    for invalid in ([0.], [-1.]):
        with pytest.raises(ValueError):
            msd.Gamma([1.], invalid)
56
def test_scalar():
    """
    Plain Python float scalars are not accepted as parameters.

    A scalar in either position (with a list in the other) is expected to
    raise TypeError.
    """
    cases = (
        (3., [4.]),   # scalar first parameter
        ([3.], -4.),  # scalar second parameter
    )
    for first, second in cases:
        with pytest.raises(TypeError):
            msd.Gamma(first, second)
62
def test_arguments():
    """
    Argument passing during initialization.

    A Gamma can be built with no arguments, or with explicit parameters and a
    float32 dtype; either way the result is an ``msd.Distribution``.
    """
    cases = (
        ((), {}),
        (([3.0], [4.0]), {'dtype': dtype.float32}),
    )
    for args, kwargs in cases:
        dist = msd.Gamma(*args, **kwargs)
        assert isinstance(dist, msd.Distribution)
71
72
class GammaProb(nn.Cell):
    """
    Gamma distribution with parameters fixed at initialization.

    The parameters are [3.0, 4.0] and [1.0, 1.0] with dtype float32.
    ``construct`` returns ``prob(value) + log_prob(value)``.
    """
    def __init__(self):
        super().__init__()
        first_param = [3.0, 4.0]
        second_param = [1.0, 1.0]
        self.gamma = msd.Gamma(first_param, second_param, dtype=dtype.float32)

    def construct(self, value):
        # Stored parameters are used; only the evaluation point is passed in.
        return self.gamma.prob(value) + self.gamma.log_prob(value)
85
def test_gamma_prob():
    """
    Test probability functions with the distribution's stored parameters.

    Only ``value`` is passed through construct; the result must be a Tensor.
    """
    ans = GammaProb()(Tensor([0.5, 1.0], dtype=dtype.float32))
    assert isinstance(ans, Tensor)
94
95
class GammaProb1(nn.Cell):
    """
    Gamma distribution created without parameters.

    Both parameters are supplied on every call; ``construct`` returns
    ``prob + log_prob`` evaluated at ``value`` with those parameters.
    """
    def __init__(self):
        super().__init__()
        self.gamma = msd.Gamma()

    def construct(self, value, concentration1, concentration0):
        # No stored parameters, so they must be forwarded to both methods.
        return (self.gamma.prob(value, concentration1, concentration0)
                + self.gamma.log_prob(value, concentration1, concentration0))
108
def test_gamma_prob1():
    """
    Test probability functions with parameters and value passed through construct.

    The first parameter has two elements and the second has one, so the
    call combines parameters of different lengths.
    """
    net = GammaProb1()
    call_args = (
        Tensor([0.5, 1.0], dtype=dtype.float32),  # value
        Tensor([2.0, 3.0], dtype=dtype.float32),  # concentration1
        Tensor([1.0], dtype=dtype.float32),       # concentration0
    )
    ans = net(*call_args)
    assert isinstance(ans, Tensor)
119
class GammaKl(nn.Cell):
    """
    Test class: kl_loss of Gamma distributions, computed two ways.

    ``g1`` has parameters fixed at construction ([3.0] and [4.0]), so only
    distribution b's parameters are passed. ``g2`` has none, so distribution
    a's parameters are passed explicitly as well. ``construct`` returns the
    sum of both results.
    """
    def __init__(self):
        super().__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration1_b, concentration0_b, concentration1_a, concentration0_a):
        kl_fixed = self.g1.kl_loss('Gamma', concentration1_b, concentration0_b)
        kl_passed = self.g2.kl_loss('Gamma', concentration1_b, concentration0_b,
                                    concentration1_a, concentration0_a)
        return kl_fixed + kl_passed
133
def test_kl():
    """
    Test kl_loss between Gamma distributions through GammaKl.

    All four inputs are single-element float32 tensors; the result must be a Tensor.
    """
    def as_float32(val):
        # One-element float32 Tensor built from a numpy array.
        return Tensor(np.array([val]).astype(np.float32), dtype=dtype.float32)

    net = GammaKl()
    ans = net(as_float32(1.0), as_float32(1.0), as_float32(2.0), as_float32(3.0))
    assert isinstance(ans, Tensor)
145
class GammaCrossEntropy(nn.Cell):
    """
    Test class: cross_entropy of Gamma distributions, computed two ways.

    ``g1`` has parameters fixed at construction ([3.0] and [4.0]), so only
    distribution b's parameters are passed. ``g2`` has none, so distribution
    a's parameters are passed explicitly as well. ``construct`` returns the
    sum of both results.
    """
    def __init__(self):
        super().__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration1_b, concentration0_b, concentration1_a, concentration0_a):
        ce_fixed = self.g1.cross_entropy('Gamma', concentration1_b, concentration0_b)
        ce_passed = self.g2.cross_entropy('Gamma', concentration1_b, concentration0_b,
                                          concentration1_a, concentration0_a)
        return ce_fixed + ce_passed
159
def test_cross_entropy():
    """
    Test cross entropy between Gamma distributions through GammaCrossEntropy.

    All four inputs are single-element float32 tensors; the result must be a Tensor.
    """
    net = GammaCrossEntropy()
    c1_b, c0_b, c1_a, c0_a = (
        Tensor(np.array([val]).astype(np.float32), dtype=dtype.float32)
        for val in (1.0, 1.0, 2.0, 3.0)
    )
    ans = net(c1_b, c0_b, c1_a, c0_a)
    assert isinstance(ans, Tensor)
171
class GammaBasics(nn.Cell):
    """
    Test class: mean, sd and mode of a Gamma distribution.

    The distribution is fixed at construction with parameters [3.0, 4.0] and
    [4.0, 6.0] (float32). ``construct`` returns ``mean + sd + mode``.
    """
    def __init__(self):
        super().__init__()
        first_param = np.array([3.0, 4.0])
        second_param = np.array([4.0, 6.0])
        self.g = msd.Gamma(first_param, second_param, dtype=dtype.float32)

    def construct(self):
        return self.g.mean() + self.g.sd() + self.g.mode()
185
def test_bascis():
    """
    Test mean/sd/mode functionality of Gamma.

    GammaBasics.construct calls only mean, sd and mode, so entropy is not
    exercised by this test. The result must be a Tensor.
    """
    # NOTE(review): "bascis" looks like a typo for "basics"; the name is kept so
    # the collected test id does not change.
    net = GammaBasics()
    ans = net()
    assert isinstance(ans, Tensor)
193
class GammaConstruct(nn.Cell):
    """
    Gamma distribution: 'prob' invoked by calling the distribution object itself.

    ``gamma`` has parameters fixed at construction ([3.0] and [4.0]).
    ``gamma1`` has none. ``construct`` sums three 'prob' results: ``gamma``
    with no extra parameters, ``gamma`` with explicit parameters, and
    ``gamma1`` with explicit parameters.
    """
    def __init__(self):
        super().__init__()
        self.gamma = msd.Gamma([3.0], [4.0])
        self.gamma1 = msd.Gamma()

    def construct(self, value, concentration1, concentration0):
        from_stored = self.gamma('prob', value)
        with_args = self.gamma('prob', value, concentration1, concentration0)
        unparameterized = self.gamma1('prob', value, concentration1, concentration0)
        return from_stored + with_args + unparameterized
208
def test_gamma_construct():
    """
    Test the probability function invoked through the distribution's construct.

    ``dist('prob', ...)`` is called both with and without call-time
    parameters. Both parameters are strictly positive: ``test_concentration1``
    shows that a first parameter of 0 is rejected at construction. Passing
    0.0 here would only feed an out-of-domain value into ``prob``.
    """
    net = GammaConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    # Must be > 0; 0.0 is outside the parameter domain this file validates.
    concentration1 = Tensor([2.0], dtype=dtype.float32)
    concentration0 = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration1, concentration0)
    assert isinstance(ans, Tensor)
219