• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15"""multitype_ops directory test case"""
16import numpy as np
17import pytest
18
19import mindspore.nn as nn
20from mindspore import Tensor
21from mindspore import dtype as mstype
22from mindspore.ops import functional as F
23import mindspore.context as context
24
25
class TensorIntAutoCast(nn.Cell):
    """Cell that multiplies its input tensor by a Python int scalar (2)."""

    def __init__(self):
        super().__init__()
        # Python int operand passed to F.tensor_mul alongside the tensor.
        self.i = 2

    def construct(self, t):
        """Return F.tensor_mul(t, self.i)."""
        return F.tensor_mul(t, self.i)
34
35
class TensorFPAutoCast(nn.Cell):
    """Cell that multiplies its input tensor by a Python float scalar (1.2)."""

    def __init__(self):
        super().__init__()
        # Python float operand passed to F.tensor_mul alongside the tensor.
        self.f = 1.2

    def construct(self, t):
        """Return F.tensor_mul(t, self.f)."""
        return F.tensor_mul(t, self.f)
44
45
class TensorBoolAutoCast(nn.Cell):
    """Cell that multiplies its input tensor by a Python bool scalar (True)."""

    def __init__(self):
        super().__init__()
        # Python bool operand passed to F.tensor_mul alongside the tensor.
        self.f = True

    def construct(self, t):
        """Return F.tensor_mul(t, self.f)."""
        return F.tensor_mul(t, self.f)
54
55
class TensorAutoCast(nn.Cell):
    """Cell that multiplies two input tensors with F.tensor_mul."""

    def __init__(self):
        super().__init__()

    def construct(self, t1, t2):
        """Return F.tensor_mul(t1, t2)."""
        return F.tensor_mul(t1, t2)
63
64
def test_tensor_auto_cast():
    """Check F.tensor_mul result dtypes and rejected operand dtypes in graph mode.

    Four cells are exercised: tensor * tensor (TensorAutoCast), tensor * int
    scalar (TensorIntAutoCast), tensor * float scalar (TensorFPAutoCast) and
    tensor * bool scalar (TensorBoolAutoCast). The cases are table-driven;
    each failure message names the cell and the operand dtypes involved.

    Raises:
        AssertionError: if a result dtype differs from the expected one.
    """
    context.set_context(mode=context.GRAPH_MODE)
    # NOTE(review): the result is discarded. Kept so that constructing a bool
    # tensor from a Python list is still executed -- confirm whether it can go.
    Tensor([True, False], mstype.bool_)

    def ones(dtype):
        """Build a [2, 1, 2, 2] tensor of ones with the given MindSpore dtype."""
        return Tensor(np.ones([2, 1, 2, 2]), dtype)

    def describe(cell, inputs):
        """Format the cell name and input dtypes for failure messages."""
        dtypes = ", ".join(str(t.dtype) for t in inputs)
        return "{}({})".format(type(cell).__name__, dtypes)

    def check_dtype(cell, expected, *inputs):
        """Run the cell on the inputs and assert the dtype of the result."""
        rs = cell(*inputs)
        assert rs.dtype == expected, "{}: expected {}, got {}".format(
            describe(cell, inputs), expected, rs.dtype)

    def check_type_error(cell, *inputs):
        """Assert that running the cell on the inputs raises TypeError."""
        with pytest.raises(TypeError):
            cell(*inputs)
            # Only reached when the call above did not raise. The exception
            # raised by pytest.fail is not a TypeError, so pytest.raises lets
            # it propagate and the report names the offending case.
            pytest.fail("{}: expected TypeError".format(describe(cell, inputs)))

    t_uint8 = ones(mstype.uint8)
    t_int8 = ones(mstype.int8)
    t_int16 = ones(mstype.int16)
    t_int32 = ones(mstype.int32)
    t_int64 = ones(mstype.int64)
    t_fp16 = ones(mstype.float16)
    t_fp32 = ones(mstype.float32)
    t_fp64 = ones(mstype.float64)

    integer_tensors = (t_uint8, t_int8, t_int16, t_int32, t_int64)
    float_tensors = (t_fp16, t_fp32, t_fp64)

    # tensor * tensor: (lhs, rhs, expected result dtype)
    net = TensorAutoCast()
    tensor_tensor_cases = (
        # integer * integer
        (t_uint8, t_int8, mstype.int16),
        (t_uint8, t_int16, mstype.int16),
        (t_uint8, t_int32, mstype.int32),
        (t_uint8, t_int64, mstype.int64),
        (t_int8, t_int16, mstype.int16),
        (t_int8, t_int32, mstype.int32),
        (t_int8, t_int64, mstype.int64),
        (t_int16, t_int32, mstype.int32),
        (t_int16, t_int64, mstype.int64),
        (t_int32, t_int64, mstype.int64),
        # float * float
        (t_fp16, t_fp32, mstype.float32),
        (t_fp16, t_fp64, mstype.float64),
        (t_fp32, t_fp64, mstype.float64),
        # integer * float
        (t_uint8, t_fp16, mstype.float16),
        (t_uint8, t_fp32, mstype.float32),
        (t_uint8, t_fp64, mstype.float64),
        (t_int8, t_fp64, mstype.float64),
        (t_int16, t_fp64, mstype.float64),
        (t_int32, t_fp64, mstype.float64),
        (t_int64, t_fp64, mstype.float64),
        # float16 * integer
        (t_fp16, t_int8, mstype.float16),
        (t_fp16, t_uint8, mstype.float16),
        (t_fp16, t_int16, mstype.float16),
        (t_fp16, t_int32, mstype.float16),
        (t_fp16, t_int64, mstype.float16),
    )
    for lhs, rhs, expected in tensor_tensor_cases:
        check_dtype(net, expected, lhs, rhs)

    # tensor * int scalar: (input, expected result dtype)
    tint = TensorIntAutoCast()
    int_scalar_cases = (
        (t_uint8, mstype.uint8),
        (t_int8, mstype.int8),
        (t_int16, mstype.int16),
        (t_int32, mstype.int32),
        (t_int64, mstype.int64),
        (t_fp16, mstype.float16),
        (t_fp32, mstype.float32),
        (t_fp64, mstype.float64),
    )
    for tensor, expected in int_scalar_cases:
        check_dtype(tint, expected, tensor)

    # tensor * float scalar: (input, expected result dtype)
    tfp = TensorFPAutoCast()
    float_scalar_cases = (
        (t_uint8, mstype.float32),
        (t_int8, mstype.float32),
        (t_int16, mstype.float32),
        (t_int32, mstype.float32),
        (t_int64, mstype.float32),
        (t_fp16, mstype.float32),
        (t_fp32, mstype.float32),
        (t_fp64, mstype.float64),
    )
    for tensor, expected in float_scalar_cases:
        check_dtype(tfp, expected, tensor)

    # uint16/uint32/uint64 operands are expected to be rejected with TypeError
    # by the tensor * tensor cell and by both scalar cells.
    t_uint16 = ones(mstype.uint16)
    t_uint32 = ones(mstype.uint32)
    t_uint64 = ones(mstype.uint64)
    wide_unsigned_tensors = (t_uint16, t_uint32, t_uint64)

    for partners in (integer_tensors, float_tensors):
        for unsigned in wide_unsigned_tensors:
            for partner in partners:
                check_type_error(net, unsigned, partner)

    for cell in (tfp, tint):
        for unsigned in wide_unsigned_tensors:
            check_type_error(cell, unsigned)

    # A Python bool scalar operand is expected to be rejected for every dtype.
    bnet = TensorBoolAutoCast()
    for tensor in integer_tensors + float_tensors:
        check_type_error(bnet, tensor)
def test_bool_tensor_and_float():
    """Check F.tensor_mul result dtypes for tensors multiplied by Python scalars.

    Runs in graph mode. A bool tensor multiplied by the float scalar of
    TensorFPAutoCast must yield float32, and multiplied by the int scalar of
    TensorIntAutoCast must yield int32. float16, float32 and int32 tensors
    multiplied by the int scalar must keep their own dtype.

    Raises:
        AssertionError: if a result dtype differs from the expected one.
    """
    context.set_context(mode=context.GRAPH_MODE)
    # np.bool_ rather than the np.bool alias: the alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where accessing it raises AttributeError.
    t_bool = Tensor(np.ones([2, 1, 2, 2]).astype(np.bool_), mstype.bool_)
    t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
    t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
    t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)

    # bool tensor * float scalar
    net = TensorFPAutoCast()
    out = net(t_bool)
    assert out.dtype == mstype.float32

    # bool / float / int tensors * int scalar
    net = TensorIntAutoCast()
    out = net(t_bool)
    assert out.dtype == mstype.int32
    out = net(t_fp16)
    assert out.dtype == mstype.float16
    out = net(t_fp32)
    assert out.dtype == mstype.float32
    out = net(t_int32)
    assert out.dtype == mstype.int32
267