• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Test nn.probability.distribution.
17"""
18import pytest
19
20import mindspore.nn as nn
21import mindspore.nn.probability.distribution as msd
22from mindspore import dtype as mstype
23from mindspore import Tensor
24from mindspore import context
25
# Names of the Distribution methods exercised by the tests in this module.
# MyExponential implements none of them, so calling any of them is expected
# to raise NotImplementedError.
func_name_list = [
    'prob',
    'log_prob',
    'cdf',
    'log_cdf',
    'survival_function',
    'log_survival',
    'sd',
    'var',
    'mode',
    'mean',
    'entropy',
    'kl_loss',
    'cross_entropy',
    'sample',
]
31
32
class MyExponential(msd.Distribution):
    """
    Minimal Distribution subclass used as the object under test.

    It only forwards its constructor arguments to the base class and defines
    none of the distribution's computational methods, so every method listed
    in ``func_name_list`` is expected to raise NotImplementedError.
    """

    def __init__(self, rate=None, seed=None, dtype=mstype.float32, name="MyExponential"):
        # locals() is read before any other local is bound, so the snapshot
        # holds exactly the constructor arguments (self, rate, seed, dtype,
        # name). The explicit two-argument super() below is deliberate: the
        # zero-argument form would add a ``__class__`` entry to locals().
        param = {**locals(), 'param_dict': dict(rate=rate)}
        super(MyExponential, self).__init__(seed, dtype, name, param)
42
43
class Net(nn.Cell):
    """
    Cell that reaches a distribution method through the distribution's call.

    Calling this cell invokes ``MyExponential()(func_name, *args, **kwargs)``:
    the method is selected by its name inside the distribution's own call
    path rather than being looked up as an attribute.
    """

    def __init__(self, func_name):
        super().__init__()
        self.dist = MyExponential()
        # Method name forwarded as the first argument on every call.
        self.name = func_name

    def construct(self, *args, **kwargs):
        return self.dist(self.name, *args, **kwargs)
56
57
def test_raise_not_implemented_error_construct():
    """
    Test that unimplemented methods raise NotImplementedError in PyNative mode
    when they are dispatched by name through the distribution's call (``Net``).

    The execution mode is set explicitly and restored afterwards, so the result
    depends neither on the default mode nor on tests that ran earlier.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        value = Tensor([0.2], dtype=mstype.float32)
        for func_name in func_name_list:
            # Build the cell outside pytest.raises: only the call is expected
            # to raise, and a construction failure must not pass silently.
            net = Net(func_name)
            with pytest.raises(NotImplementedError):
                net(value)
    finally:
        context.set_context(mode=previous_mode)
67
68
def test_raise_not_implemented_error_construct_graph_mode():
    """
    Test that unimplemented methods raise NotImplementedError in graph mode
    when they are dispatched by name through the distribution's call (``Net``).

    The previous execution mode is restored afterwards, so this test does not
    leak GRAPH_MODE into tests that run after it.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.GRAPH_MODE)
    try:
        value = Tensor([0.2], dtype=mstype.float32)
        for func_name in func_name_list:
            # Build the cell outside pytest.raises: only the call is expected
            # to raise, and a construction failure must not pass silently.
            net = Net(func_name)
            with pytest.raises(NotImplementedError):
                net(value)
    finally:
        context.set_context(mode=previous_mode)
79
80
class Net1(nn.Cell):
    """
    Cell that calls a distribution method directly as a bound attribute.

    The method named ``func_name`` is looked up on a ``MyExponential``
    instance once, at construction time; ``construct`` simply forwards its
    arguments to that method.
    """

    def __init__(self, func_name):
        super().__init__()
        self.dist = MyExponential()
        # Resolve the method eagerly; invoking it (not looking it up) is what
        # the tests expect to raise.
        self.func = getattr(self.dist, func_name)

    def construct(self, *args, **kwargs):
        return self.func(*args, **kwargs)
93
94
def test_raise_not_implemented_error():
    """
    Test that unimplemented methods raise NotImplementedError in PyNative mode
    when they are called directly as bound methods (``Net1``).

    The execution mode is set explicitly and restored afterwards, so the result
    depends neither on the default mode nor on tests that ran earlier.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.PYNATIVE_MODE)
    try:
        value = Tensor([0.2], dtype=mstype.float32)
        for func_name in func_name_list:
            # Build the cell outside pytest.raises: only the call is expected
            # to raise, and a construction failure must not pass silently.
            net = Net1(func_name)
            with pytest.raises(NotImplementedError):
                net(value)
    finally:
        context.set_context(mode=previous_mode)
104
105
def test_raise_not_implemented_error_graph_mode():
    """
    Test that unimplemented methods raise NotImplementedError in graph mode
    when they are called directly as bound methods (``Net1``).

    The previous execution mode is restored afterwards, so this test does not
    leak GRAPH_MODE into tests that run after it.
    """
    previous_mode = context.get_context("mode")
    context.set_context(mode=context.GRAPH_MODE)
    try:
        value = Tensor([0.2], dtype=mstype.float32)
        for func_name in func_name_list:
            # Build the cell outside pytest.raises: only the call is expected
            # to raise, and a construction failure must not pass silently.
            net = Net1(func_name)
            with pytest.raises(NotImplementedError):
                net(value)
    finally:
        context.set_context(mode=previous_mode)
116