# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""multitype_ops directory test case.

Checks the result dtype of ``F.tensor_mul`` when its operands have different
dtypes (tensor * tensor and tensor * Python scalar), and that the
combinations expected to be unsupported raise ``TypeError``.
"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import functional as F
import mindspore.context as context

# Shape shared by every test tensor; values are all ones, only the dtype varies.
_SHAPE = (2, 1, 2, 2)


def _ones(dtype):
    """Return a ones-filled Tensor of shape ``_SHAPE`` with MindSpore dtype ``dtype``."""
    return Tensor(np.ones(_SHAPE), dtype)


class TensorIntAutoCast(nn.Cell):
    """Multiply the input tensor by the Python int ``2`` stored on the cell."""

    def __init__(self):
        super().__init__()
        self.i = 2

    def construct(self, t):
        return F.tensor_mul(t, self.i)


class TensorFPAutoCast(nn.Cell):
    """Multiply the input tensor by the Python float ``1.2`` stored on the cell."""

    def __init__(self):
        super().__init__()
        self.f = 1.2

    def construct(self, t):
        return F.tensor_mul(t, self.f)


class TensorBoolAutoCast(nn.Cell):
    """Multiply the input tensor by the Python bool ``True`` stored on the cell."""

    def __init__(self):
        super().__init__()
        self.f = True

    def construct(self, t):
        return F.tensor_mul(t, self.f)


class TensorAutoCast(nn.Cell):
    """Multiply two input tensors with ``F.tensor_mul``."""

    def construct(self, t1, t2):
        return F.tensor_mul(t1, t2)


def test_tensor_auto_cast():
    """Check tensor_mul result dtypes for tensor/tensor and tensor/scalar operands."""
    context.set_context(mode=context.GRAPH_MODE)
    # A bool tensor must at least be constructible from a Python list.
    assert Tensor([True, False], mstype.bool_).dtype == mstype.bool_

    t_uint8 = _ones(mstype.uint8)
    t_int8 = _ones(mstype.int8)
    t_int16 = _ones(mstype.int16)
    t_int32 = _ones(mstype.int32)
    t_int64 = _ones(mstype.int64)
    t_fp16 = _ones(mstype.float16)
    t_fp32 = _ones(mstype.float32)
    t_fp64 = _ones(mstype.float64)
    supported = (t_uint8, t_int8, t_int16, t_int32, t_int64, t_fp16, t_fp32, t_fp64)

    # (lhs, rhs, expected result dtype) for tensor * tensor.
    tensor_tensor_cases = [
        # integer * integer
        (t_uint8, t_int8, mstype.int16),
        (t_uint8, t_int16, mstype.int16),
        (t_uint8, t_int32, mstype.int32),
        (t_uint8, t_int64, mstype.int64),
        (t_int8, t_int16, mstype.int16),
        (t_int8, t_int32, mstype.int32),
        (t_int8, t_int64, mstype.int64),
        (t_int16, t_int32, mstype.int32),
        (t_int16, t_int64, mstype.int64),
        (t_int32, t_int64, mstype.int64),
        # float * float
        (t_fp16, t_fp32, mstype.float32),
        (t_fp16, t_fp64, mstype.float64),
        (t_fp32, t_fp64, mstype.float64),
        # integer * float
        (t_uint8, t_fp16, mstype.float16),
        (t_uint8, t_fp32, mstype.float32),
        (t_uint8, t_fp64, mstype.float64),
        (t_int8, t_fp64, mstype.float64),
        (t_int16, t_fp64, mstype.float64),
        (t_int32, t_fp64, mstype.float64),
        (t_int64, t_fp64, mstype.float64),
        # float16 * any integer stays float16
        (t_fp16, t_int8, mstype.float16),
        (t_fp16, t_uint8, mstype.float16),
        (t_fp16, t_int16, mstype.float16),
        (t_fp16, t_int32, mstype.float16),
        (t_fp16, t_int64, mstype.float16),
    ]
    net = TensorAutoCast()
    for lhs, rhs, expected in tensor_tensor_cases:
        assert net(lhs, rhs).dtype == expected, "{} * {}".format(lhs.dtype, rhs.dtype)

    # Tensor * Python int keeps the tensor's dtype.
    tint = TensorIntAutoCast()
    int_scalar_cases = [
        (t_uint8, mstype.uint8),
        (t_int8, mstype.int8),
        (t_int16, mstype.int16),
        (t_int32, mstype.int32),
        (t_int64, mstype.int64),
        (t_fp16, mstype.float16),
        (t_fp32, mstype.float32),
        (t_fp64, mstype.float64),
    ]
    for t, expected in int_scalar_cases:
        assert tint(t).dtype == expected, "{} * int".format(t.dtype)

    # Tensor * Python float yields float32, except float64 stays float64.
    tfp = TensorFPAutoCast()
    float_scalar_cases = [
        (t_uint8, mstype.float32),
        (t_int8, mstype.float32),
        (t_int16, mstype.float32),
        (t_int32, mstype.float32),
        (t_int64, mstype.float32),
        (t_fp16, mstype.float32),
        (t_fp32, mstype.float32),
        (t_fp64, mstype.float64),
    ]
    for t, expected in float_scalar_cases:
        assert tfp(t).dtype == expected, "{} * float".format(t.dtype)

    # Wide unsigned tensors are expected to be rejected with TypeError
    # against every other dtype and against Python scalars.
    t_uint16 = _ones(mstype.uint16)
    t_uint32 = _ones(mstype.uint32)
    t_uint64 = _ones(mstype.uint64)
    wide_unsigned = (t_uint16, t_uint32, t_uint64)
    partner_groups = (
        (t_uint8, t_int8, t_int16, t_int32, t_int64),
        (t_fp16, t_fp32, t_fp64),
    )
    for partners in partner_groups:
        for wide in wide_unsigned:
            for other in partners:
                with pytest.raises(TypeError):
                    net(wide, other)

    for scalar_net in (tfp, tint):
        for wide in wide_unsigned:
            with pytest.raises(TypeError):
                scalar_net(wide)

    # Multiplying a non-bool tensor by a Python bool is expected to fail.
    bnet = TensorBoolAutoCast()
    for t in supported:
        with pytest.raises(TypeError):
            bnet(t)


def test_bool_tensor_and_float():
    """Check that a bool tensor is promoted by Python float and int scalars."""
    context.set_context(mode=context.GRAPH_MODE)
    # ``np.bool`` was removed in NumPy 1.24; ``np.bool_`` is the NumPy bool type.
    t_bool = Tensor(np.ones(_SHAPE).astype(np.bool_), mstype.bool_)
    t_int32 = _ones(mstype.int32)
    t_fp16 = _ones(mstype.float16)
    t_fp32 = _ones(mstype.float32)

    net = TensorFPAutoCast()
    out = net(t_bool)
    assert out.dtype == mstype.float32

    net = TensorIntAutoCast()
    out = net(t_bool)
    assert out.dtype == mstype.int32
    out = net(t_fp16)
    assert out.dtype == mstype.float16
    out = net(t_fp32)
    assert out.dtype == mstype.float32
    out = net(t_int32)
    assert out.dtype == mstype.int32