• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""
16Test util functions used in distribution classes.
17"""
18import numpy as np
19import pytest
20
21from mindspore.nn.cell import Cell
22from mindspore import context
23from mindspore import dtype
24from mindspore import Tensor
25from mindspore.common.parameter import Parameter
26from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
27    cast_to_tensor, CheckTuple, CheckTensor
28
def test_set_param_type():
    """
    Test set_param_type.

    Every case is a triple of (parameter dict, dtype passed as the second
    argument, expected outcome). The expected outcome is either the dtype that
    set_param_type must return, or TypeError when the call must raise because
    the parameters' dtypes are incompatible. Cases are checked in list order.
    """
    fp16 = Tensor(0.1, dtype=dtype.float16)
    fp32 = Tensor(0.1, dtype=dtype.float32)
    fp64 = Tensor(0.1, dtype=dtype.float64)
    i32 = Tensor(0.1, dtype=dtype.int32)
    np_fp32 = np.array(1.0).astype(np.float32)
    np_fp64 = np.array(1.0).astype(np.float64)
    np_i32 = np.array(1.0).astype(np.int32)

    # (params, dtype argument, expected dtype or TypeError)
    cases = [
        ({'a': fp32, 'b': 1.0, 'c': fp32}, dtype.float16, dtype.float32),
        ({'a': fp32, 'b': 1.0, 'c': fp64}, dtype.float32, TypeError),
        ({'a': i32, 'b': 1.0, 'c': i32}, dtype.float16, dtype.float32),
        ({'a': np_fp32, 'b': 1.0, 'c': fp32}, dtype.float16, dtype.float32),
        ({'a': np_fp32, 'b': 1.0, 'c': np_fp64}, dtype.float32, TypeError),
        ({'a': np_fp32, 'b': 1.0, 'c': np_i32}, dtype.float32, TypeError),
        ({'a': 1.0}, dtype.float32, dtype.float32),
        ({'a': 1.0, 'b': 1.0, 'c': 1.0}, dtype.float32, dtype.float32),
        ({'a': fp16, 'b': fp16, 'c': fp16}, dtype.float32, dtype.float16),
        ({'a': fp64, 'b': fp64, 'c': fp64}, dtype.float32, dtype.float32),
        ({'a': np_fp64, 'b': np_fp64, 'c': fp64}, dtype.float32, dtype.float32),
    ]

    for idx, (params, default, expected) in enumerate(cases, start=1):
        if expected is TypeError:
            with pytest.raises(TypeError):
                set_param_type(params, default)
        else:
            assert set_param_type(params, default) == expected, 'case {}'.format(idx)
79
def test_cast_to_tensor():
    """
    Test cast_to_tensor.

    None must raise ValueError; a bool, a dict and a str must each raise
    TypeError. A Parameter must come back as a Parameter, while a numpy
    array, a list, a Tensor, a float and an int must each come back as a
    Tensor.
    """
    with pytest.raises(ValueError):
        cast_to_tensor(None, dtype.float32)
    for bad_input in (True, {'a': 1, 'b': 2}, 'tensor'):
        with pytest.raises(TypeError):
            cast_to_tensor(bad_input, dtype.float32)

    param = Parameter(Tensor(0.1, dtype=dtype.float32), 'param')
    assert isinstance(cast_to_tensor(param), Parameter)

    # Positional-argument tuples, one per call; the single-element tuples
    # exercise cast_to_tensor without an explicit dtype argument.
    tensor_calls = (
        (np.array(1.0).astype(np.float32),),
        ([1.0, 2.0],),
        (Tensor(0.1, dtype=dtype.float32), dtype.float32),
        (0.1, dtype.float32),
        (1, dtype.float32),
    )
    for call_args in tensor_calls:
        assert isinstance(cast_to_tensor(*call_args), Tensor)
105
class Net(Cell):
    """
    Test class: CheckTuple.

    Wraps a CheckTuple instance in a Cell so it can be invoked like a network.
    The value to check is given at construction time and may be overridden by
    passing a value to the call.
    """
    def __init__(self, value):
        super(Net, self).__init__()
        self.checktuple = CheckTuple()
        # Input checked when the net is called without an argument.
        self.value = value

    def construct(self, value=None):
        # Run CheckTuple with the argument name 'input' on the per-call value,
        # falling back to the stored value, and return CheckTuple's result.
        # NOTE: comments (not a docstring) are used here because construct()
        # is parsed by MindSpore's graph compiler in GRAPH_MODE.
        if value is None:
            return self.checktuple(self.value, 'input')
        return self.checktuple(value, 'input')
119
def test_check_tuple():
    """
    Test CheckTuple through ``Net``, first in the execution mode that is active
    when the test starts and then in GRAPH_MODE.

    A tuple input must be returned as a tuple; a non-tuple input must raise
    TypeError when the net is called. The execution mode is process-global
    state, so it is restored in a ``finally`` clause; otherwise the switch to
    GRAPH_MODE would leak into every test that runs afterwards.
    """
    original_mode = context.get_context('mode')
    try:
        net1 = Net((1, 2, 3))
        ans1 = net1()
        assert isinstance(ans1, tuple)

        # Build the net outside pytest.raises so that only the call, not the
        # construction, can satisfy the expected TypeError.
        net2 = Net('tuple')
        with pytest.raises(TypeError):
            net2()

        context.set_context(mode=context.GRAPH_MODE)
        net3 = Net((1, 2, 3))
        ans3 = net3()
        assert isinstance(ans3, tuple)

        net4 = Net('tuple')
        with pytest.raises(TypeError):
            net4()
    finally:
        context.set_context(mode=original_mode)
140
class Net1(Cell):
    """
    Test class: CheckTensor.

    Wraps a CheckTensor instance in a Cell. The execution mode is read once in
    __init__ and determines what construct() returns; the value to check is
    given at construction time and may be overridden per call.
    """
    def __init__(self, value):
        super(Net1, self).__init__()
        self.checktensor = CheckTensor()
        # Input checked when the net is called without an argument.
        self.value = value
        # Execution mode captured at construction time; a later
        # context.set_context() call does not change it for this instance.
        self.context = context.get_context('mode')

    def construct(self, value=None):
        # Fall back to the stored value when no per-call value is given.
        value = self.value if value is None else value
        # NOTE(review): 0 is presumably context.GRAPH_MODE -- confirm against
        # mindspore.context. Kept as a literal because construct() is parsed
        # by MindSpore's graph compiler.
        if self.context == 0:
            # In this mode CheckTensor is called only for its check, and the
            # input itself is returned; presumably CheckTensor's result is not
            # the input tensor in graph mode -- TODO confirm.
            self.checktensor(value, 'input')
            return value
        return self.checktensor(value, 'input')
157
def test_check_tensor():
    """
    Test CheckTensor through ``Net1``, first in the execution mode that is
    active when the test starts and then in GRAPH_MODE.

    A Tensor input, whether stored on the net or passed per call, must come
    back as a Tensor; a non-Tensor input must raise TypeError when the net is
    called. The execution mode is process-global state, so it is restored in a
    ``finally`` clause; otherwise the switch to GRAPH_MODE would leak into
    every test that runs afterwards.
    """
    original_mode = context.get_context('mode')
    try:
        value = Tensor(0.1, dtype=dtype.float32)
        net1 = Net1(value)
        ans1 = net1()
        assert isinstance(ans1, Tensor)
        ans1 = net1(value)
        assert isinstance(ans1, Tensor)

        # Build the net outside pytest.raises so that only the call, not the
        # construction, can satisfy the expected TypeError.
        net2 = Net1('tuple')
        with pytest.raises(TypeError):
            net2()

        # Net1 reads the execution mode in __init__, so nets for the
        # GRAPH_MODE half must be created after switching modes.
        context.set_context(mode=context.GRAPH_MODE)
        net3 = Net1(value)
        ans3 = net3()
        assert isinstance(ans3, Tensor)
        ans3 = net3(value)
        assert isinstance(ans3, Tensor)

        net4 = Net1('tuple')
        with pytest.raises(TypeError):
            net4()
    finally:
        context.set_context(mode=original_mode)
183