# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Test util functions used in distribution classes.
"""
import contextlib

import numpy as np
import pytest

from mindspore.nn.cell import Cell
from mindspore import context
from mindspore import dtype
from mindspore import Tensor
from mindspore.common.parameter import Parameter
from mindspore.nn.probability.distribution._utils.utils import set_param_type, \
    cast_to_tensor, CheckTuple, CheckTensor


@contextlib.contextmanager
def _graph_mode():
    """
    Temporarily switch MindSpore to GRAPH_MODE.

    The previously active mode is restored on exit (even if the body raises),
    so a test using this helper does not leak a global execution-mode change
    into tests that run after it.
    """
    saved_mode = context.get_context('mode')
    context.set_context(mode=context.GRAPH_MODE)
    try:
        yield
    finally:
        context.set_context(mode=saved_mode)


def test_set_param_type():
    """
    Test set_param_type function.

    Each case pairs a dict of inputs and a default dtype with either the
    expected resolved dtype or an expected TypeError.
    """
    tensor_fp16 = Tensor(0.1, dtype=dtype.float16)
    tensor_fp32 = Tensor(0.1, dtype=dtype.float32)
    tensor_fp64 = Tensor(0.1, dtype=dtype.float64)
    tensor_int32 = Tensor(0.1, dtype=dtype.int32)
    array_fp32 = np.array(1.0).astype(np.float32)
    array_fp64 = np.array(1.0).astype(np.float64)
    array_int32 = np.array(1.0).astype(np.int32)

    dict1 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp32}
    dict2 = {'a': tensor_fp32, 'b': 1.0, 'c': tensor_fp64}
    dict3 = {'a': tensor_int32, 'b': 1.0, 'c': tensor_int32}
    dict4 = {'a': array_fp32, 'b': 1.0, 'c': tensor_fp32}
    dict5 = {'a': array_fp32, 'b': 1.0, 'c': array_fp64}
    dict6 = {'a': array_fp32, 'b': 1.0, 'c': array_int32}
    dict7 = {'a': 1.0}
    dict8 = {'a': 1.0, 'b': 1.0, 'c': 1.0}
    dict9 = {'a': tensor_fp16, 'b': tensor_fp16, 'c': tensor_fp16}
    dict10 = {'a': tensor_fp64, 'b': tensor_fp64, 'c': tensor_fp64}
    dict11 = {'a': array_fp64, 'b': array_fp64, 'c': tensor_fp64}

    ans1 = set_param_type(dict1, dtype.float16)
    assert ans1 == dtype.float32

    # Mixing float32 and float64 tensors is rejected.
    with pytest.raises(TypeError):
        set_param_type(dict2, dtype.float32)

    ans3 = set_param_type(dict3, dtype.float16)
    assert ans3 == dtype.float32
    ans4 = set_param_type(dict4, dtype.float16)
    assert ans4 == dtype.float32

    # Mixing numpy arrays of different dtypes is rejected.
    with pytest.raises(TypeError):
        set_param_type(dict5, dtype.float32)
    with pytest.raises(TypeError):
        set_param_type(dict6, dtype.float32)

    # Plain Python floats fall back to the requested default dtype.
    ans7 = set_param_type(dict7, dtype.float32)
    assert ans7 == dtype.float32
    ans8 = set_param_type(dict8, dtype.float32)
    assert ans8 == dtype.float32
    ans9 = set_param_type(dict9, dtype.float32)
    assert ans9 == dtype.float16
    ans10 = set_param_type(dict10, dtype.float32)
    assert ans10 == dtype.float32
    ans11 = set_param_type(dict11, dtype.float32)
    assert ans11 == dtype.float32


def test_cast_to_tensor():
    """
    Test cast_to_tensor.

    Unsupported inputs (None, bool, dict, str) must raise; supported inputs
    (Parameter, numpy array, list, Tensor, float, int) must come back as a
    Tensor (a Parameter is returned as a Parameter).
    """
    with pytest.raises(ValueError):
        cast_to_tensor(None, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor(True, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor({'a': 1, 'b': 2}, dtype.float32)
    with pytest.raises(TypeError):
        cast_to_tensor('tensor', dtype.float32)

    ans1 = cast_to_tensor(Parameter(Tensor(0.1, dtype=dtype.float32), 'param'))
    assert isinstance(ans1, Parameter)
    ans2 = cast_to_tensor(np.array(1.0).astype(np.float32))
    assert isinstance(ans2, Tensor)
    ans3 = cast_to_tensor([1.0, 2.0])
    assert isinstance(ans3, Tensor)
    ans4 = cast_to_tensor(Tensor(0.1, dtype=dtype.float32), dtype.float32)
    assert isinstance(ans4, Tensor)
    ans5 = cast_to_tensor(0.1, dtype.float32)
    assert isinstance(ans5, Tensor)
    ans6 = cast_to_tensor(1, dtype.float32)
    assert isinstance(ans6, Tensor)


class Net(Cell):
    """
    Test class: CheckTuple.

    Runs CheckTuple on either the value passed to construct() or, when none
    is given, the value stored at construction time.
    """
    def __init__(self, value):
        super(Net, self).__init__()
        self.checktuple = CheckTuple()
        self.value = value

    def construct(self, value=None):
        if value is None:
            return self.checktuple(self.value, 'input')
        return self.checktuple(value, 'input')


def test_check_tuple():
    """
    Test CheckTuple in the currently active mode and in GRAPH_MODE.
    """
    net1 = Net((1, 2, 3))
    ans1 = net1()
    assert isinstance(ans1, tuple)

    # Only the call is expected to raise, not the construction.
    net2 = Net('tuple')
    with pytest.raises(TypeError):
        net2()

    # Restore the previous mode afterwards so later tests are unaffected.
    with _graph_mode():
        net3 = Net((1, 2, 3))
        ans3 = net3()
        assert isinstance(ans3, tuple)

        net4 = Net('tuple')
        with pytest.raises(TypeError):
            net4()


class Net1(Cell):
    """
    Test class: CheckTensor.

    The execution mode is captured at construction time and decides what
    construct() returns.
    """
    def __init__(self, value):
        super(Net1, self).__init__()
        self.checktensor = CheckTensor()
        self.value = value
        self.context = context.get_context('mode')
        # Compare against the named constant instead of the magic number 0.
        self.is_graph_mode = self.context == context.GRAPH_MODE

    def construct(self, value=None):
        value = self.value if value is None else value
        if self.is_graph_mode:
            # In GRAPH_MODE the check is run only for its validation effect and
            # the input is returned directly; presumably CheckTensor's graph
            # output is not the input tensor itself — TODO confirm.
            self.checktensor(value, 'input')
            return value
        return self.checktensor(value, 'input')


def test_check_tensor():
    """
    Test CheckTensor in the currently active mode and in GRAPH_MODE.
    """
    value = Tensor(0.1, dtype=dtype.float32)
    net1 = Net1(value)
    ans1 = net1()
    assert isinstance(ans1, Tensor)
    ans1 = net1(value)
    assert isinstance(ans1, Tensor)

    # Only the call is expected to raise, not the construction.
    net2 = Net1('tuple')
    with pytest.raises(TypeError):
        net2()

    # Restore the previous mode afterwards so later tests are unaffected.
    with _graph_mode():
        net3 = Net1(value)
        ans3 = net3()
        assert isinstance(ans3, Tensor)
        ans3 = net3(value)
        assert isinstance(ans3, Tensor)

        net4 = Net1('tuple')
        with pytest.raises(TypeError):
            net4()