• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test nn.probability.distribution.Gamma.
"""
18import numpy as np
19import pytest
20
21import mindspore.nn as nn
22import mindspore.nn.probability.distribution as msd
23from mindspore import dtype
24from mindspore import Tensor
25
def test_gamma_shape_errpr():
    """
    Concentration of shape (2, 1) and rate of shape (3, 1) cannot be
    broadcast together, so constructing the Gamma must raise ValueError.
    """
    # NOTE(review): "errpr" is a typo for "error"; the name is kept so the
    # pytest node ID stays stable.
    mismatched_concentration = [[2.], [1.]]
    mismatched_rate = [[2.], [3.], [4.]]
    with pytest.raises(ValueError):
        msd.Gamma(mismatched_concentration, mismatched_rate, dtype=dtype.float32)
32
def test_type():
    """
    An integer dtype (int32) must be rejected with TypeError.

    Concentration and rate are both positive, so the dtype is the only invalid
    argument. A zero concentration is outside the Gamma's positive domain and
    would make the outcome depend on which argument is validated first.
    """
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], dtype=dtype.int32)
36
def test_name():
    """
    A non-string ``name`` (here a float) must be rejected with TypeError.

    Concentration and rate are kept valid (positive) so that ``name`` is the
    only invalid argument.
    """
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], name=1.0)
40
def test_seed():
    """
    A string ``seed`` must be rejected with TypeError.

    Concentration and rate are kept valid (positive) so that ``seed`` is the
    only invalid argument.
    """
    with pytest.raises(TypeError):
        msd.Gamma([1.], [1.], seed='seed')
44
def test_rate():
    """
    A non-positive rate (zero or negative) must be rejected with ValueError.

    Concentration is kept at a valid positive value so the ValueError can only
    come from the rate check. A zero concentration is itself outside the
    Gamma's positive domain; it could raise first and mask a missing rate
    check.
    """
    for bad_rate in ([0.], [-1.]):
        with pytest.raises(ValueError):
            msd.Gamma([1.], bad_rate)

def test_concentration():
    """
    A non-positive concentration (zero or negative) must be rejected with
    ValueError; rate is kept valid so only the concentration is at fault.
    """
    for bad_concentration in ([0.], [-1.]):
        with pytest.raises(ValueError):
            msd.Gamma(bad_concentration, [1.])
50
def test_scalar():
    """
    Passing a bare Python float for either concentration or rate must raise
    TypeError.
    """
    scalar_cases = (
        (3., [4.]),   # scalar concentration
        ([3.], -4.),  # scalar rate
    )
    for args in scalar_cases:
        with pytest.raises(TypeError):
            msd.Gamma(*args)
56
def test_arguments():
    """
    Gamma can be built without parameters or with explicit float32
    concentration/rate lists; both results are Distribution instances.
    """
    unparameterized = msd.Gamma()
    assert isinstance(unparameterized, msd.Distribution)

    parameterized = msd.Gamma([3.0], [4.0], dtype=dtype.float32)
    assert isinstance(parameterized, msd.Distribution)
65
66
class GammaProb(nn.Cell):
    """
    Cell holding a Gamma distribution with fixed float32 parameters
    (concentration=[3.0, 4.0], rate=[1.0, 1.0]).

    ``construct`` evaluates prob, log_prob, cdf, log_cdf, survival_function and
    log_survival at ``value`` and returns their element-wise sum, so a single
    call exercises all six methods.
    """
    def __init__(self):
        super().__init__()
        self.gamma = msd.Gamma([3.0, 4.0], [1.0, 1.0], dtype=dtype.float32)

    def construct(self, value):
        density = self.gamma.prob(value)
        log_density = self.gamma.log_prob(value)
        cumulative = self.gamma.cdf(value)
        log_cumulative = self.gamma.log_cdf(value)
        survival = self.gamma.survival_function(value)
        log_survival = self.gamma.log_survival(value)
        return (density + log_density + cumulative + log_cumulative
                + survival + log_survival)
83
def test_gamma_prob():
    """
    Run GammaProb (parameters fixed at construction) on a two-element float32
    value and check the combined result is a Tensor.
    """
    points = Tensor([0.5, 1.0], dtype=dtype.float32)
    result = GammaProb()(points)
    assert isinstance(result, Tensor)
92
93
class GammaProb1(nn.Cell):
    """
    Cell holding a parameterless Gamma distribution.

    ``construct`` receives ``concentration`` and ``rate`` at call time and
    forwards them, together with ``value``, to prob, log_prob, cdf, log_cdf,
    survival_function and log_survival, returning the element-wise sum.
    """
    def __init__(self):
        super().__init__()
        self.gamma = msd.Gamma()

    def construct(self, value, concentration, rate):
        density = self.gamma.prob(value, concentration, rate)
        log_density = self.gamma.log_prob(value, concentration, rate)
        cumulative = self.gamma.cdf(value, concentration, rate)
        log_cumulative = self.gamma.log_cdf(value, concentration, rate)
        survival = self.gamma.survival_function(value, concentration, rate)
        log_survival = self.gamma.log_survival(value, concentration, rate)
        return (density + log_density + cumulative + log_cumulative
                + survival + log_survival)
110
def test_gamma_prob1():
    """
    Run GammaProb1 with concentration/rate supplied at call time. The rate has
    a single element while value and concentration have two, so the rate is
    presumably broadcast. Check the combined result is a Tensor.
    """
    net = GammaProb1()
    points = Tensor([0.5, 1.0], dtype=dtype.float32)
    alpha = Tensor([2.0, 3.0], dtype=dtype.float32)
    beta = Tensor([1.0], dtype=dtype.float32)
    result = net(points, alpha, beta)
    assert isinstance(result, Tensor)
121
class GammaKl(nn.Cell):
    """
    Cell computing a Gamma-to-Gamma ``kl_loss`` in two ways and summing them.

    ``g1`` has fixed parameters (concentration=[3.0], rate=[4.0]), so only the
    other Gamma's ``*_b`` parameters are passed. ``g2`` is parameterless, so the
    ``*_a`` parameters are passed explicitly as well.
    """
    def __init__(self):
        super().__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration_b, rate_b, concentration_a, rate_a):
        kl_fixed = self.g1.kl_loss('Gamma', concentration_b, rate_b)
        kl_passed = self.g2.kl_loss('Gamma', concentration_b, rate_b,
                                    concentration_a, rate_a)
        return kl_fixed + kl_passed
135
def test_kl():
    """
    Evaluate GammaKl on one-element float32 parameter tensors and check the
    summed KL loss comes back as a Tensor.
    """
    def as_f32(x):
        # one-element float32 tensor built from a numpy array
        return Tensor(np.array([x]).astype(np.float32), dtype=dtype.float32)

    result = GammaKl()(as_f32(1.0), as_f32(1.0), as_f32(2.0), as_f32(3.0))
    assert isinstance(result, Tensor)
147
class GammaCrossEntropy(nn.Cell):
    """
    Cell computing a Gamma-to-Gamma ``cross_entropy`` in two ways and summing
    them.

    ``g1`` has fixed parameters (concentration=[3.0], rate=[4.0]), so only the
    ``*_b`` parameters are passed. ``g2`` is parameterless, so the ``*_a``
    parameters are passed explicitly as well.
    """
    def __init__(self):
        super().__init__()
        self.g1 = msd.Gamma(np.array([3.0]), np.array([4.0]), dtype=dtype.float32)
        self.g2 = msd.Gamma(dtype=dtype.float32)

    def construct(self, concentration_b, rate_b, concentration_a, rate_a):
        ce_fixed = self.g1.cross_entropy('Gamma', concentration_b, rate_b)
        ce_passed = self.g2.cross_entropy('Gamma', concentration_b, rate_b,
                                          concentration_a, rate_a)
        return ce_fixed + ce_passed
161
def test_cross_entropy():
    """
    Evaluate GammaCrossEntropy on one-element float32 parameter tensors and
    check the summed cross entropy comes back as a Tensor.
    """
    def as_f32(x):
        # one-element float32 tensor built from a numpy array
        return Tensor(np.array([x]).astype(np.float32), dtype=dtype.float32)

    result = GammaCrossEntropy()(as_f32(1.0), as_f32(1.0), as_f32(2.0), as_f32(3.0))
    assert isinstance(result, Tensor)
173
class GammaBasics(nn.Cell):
    """
    Cell returning ``mean() + sd() + mode()`` of a Gamma with two-element
    float32 parameters (concentration=[3.0, 4.0], rate=[4.0, 6.0]).
    """
    def __init__(self):
        super().__init__()
        self.g = msd.Gamma(np.array([3.0, 4.0]), np.array([4.0, 6.0]), dtype=dtype.float32)

    def construct(self):
        avg = self.g.mean()
        spread = self.g.sd()
        peak = self.g.mode()
        return avg + spread + peak
187
def test_bascis():
    """
    Test mean/sd/mode functionality of Gamma: GammaBasics sums the three and
    the result must be a Tensor. (Entropy is not exercised here.)
    """
    # NOTE(review): "bascis" is a typo for "basics"; the name is kept so the
    # pytest node ID stays stable.
    net = GammaBasics()
    ans = net()
    assert isinstance(ans, Tensor)
195
class GammaConstruct(nn.Cell):
    """
    Cell that calls Gamma distributions through their own ``construct``
    (``dist('prob', ...)``) instead of the ``prob`` method.

    Three calls are summed:
      * ``gamma`` (concentration=[3.0], rate=[4.0]) using its stored parameters;
      * ``gamma`` with explicit ``concentration``/``rate`` (presumably these
        take precedence over the stored ones);
      * the parameterless ``gamma1`` with explicit ``concentration``/``rate``.
    """
    def __init__(self):
        super().__init__()
        self.gamma = msd.Gamma([3.0], [4.0])
        self.gamma1 = msd.Gamma()

    def construct(self, value, concentration, rate):
        stored = self.gamma('prob', value)
        overridden = self.gamma('prob', value, concentration, rate)
        passed = self.gamma1('prob', value, concentration, rate)
        return stored + overridden + passed
210
def test_gamma_construct():
    """
    Test probability function going through construct.

    The concentration passed at call time must be positive: a zero
    concentration is outside the Gamma's domain (constructor-time checks
    reject non-positive parameters). A degenerate value would only exercise a
    meaningless probability path.
    """
    net = GammaConstruct()
    value = Tensor([0.5, 1.0], dtype=dtype.float32)
    concentration = Tensor([2.0], dtype=dtype.float32)
    rate = Tensor([1.0], dtype=dtype.float32)
    ans = net(value, concentration, rate)
    assert isinstance(ans, Tensor)
221