# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test tensor py"""
import numpy as np

import mindspore as ms
import mindspore.common.initializer as init
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell
from mindspore.ops import operations as P
from ..ut_filter import non_graph_engine


def _attribute(tensor, shape_, size_, dtype_):
    """Return True when `tensor` has exactly the given shape, size and dtype."""
    return (tensor.shape == shape_
            and tensor.size == size_
            and tensor.dtype == dtype_)


def test_tensor_init():
    """Construct tensors from a numpy array, with and without an explicit dtype."""
    nparray = np.ones([2, 2], np.float32)
    tensor = ms.Tensor(nparray)
    # dtype is inferred from the float32 source array.
    assert _attribute(tensor, (2, 2), 4, ms.float32)

    tensor = ms.Tensor(nparray, dtype=ms.float32)
    assert _attribute(tensor, (2, 2), 4, ms.float32)


@non_graph_engine
def test_tensor_add():
    """In-place `+=` between two tensors runs without error."""
    a = ms.Tensor(np.ones([3, 3], np.float32))
    b = ms.Tensor(np.ones([3, 3], np.float32))
    a += b


@non_graph_engine
def test_tensor_sub():
    """In-place `-=` between two tensors runs without error."""
    a = ms.Tensor(np.ones([2, 3]))
    b = ms.Tensor(np.ones([2, 3]))
    b -= a


@non_graph_engine
def test_tensor_mul():
    """In-place `*=` between two tensors runs without error."""
    a = ms.Tensor(np.ones([3, 3]))
    b = ms.Tensor(np.ones([3, 3]))
    a *= b


def test_tensor_dim():
    """`ndim` reports the rank of the source array."""
    arr = np.ones((1, 6))
    b = ms.Tensor(arr)
    assert b.ndim == 2


def test_tensor_size():
    """`size` matches numpy's element count."""
    arr = np.ones((1, 6))
    b = ms.Tensor(arr)
    assert arr.size == b.size


def test_tensor_itemsize():
    """`itemsize` matches numpy's per-element byte size."""
    arr = np.ones((1, 2, 3))
    b = ms.Tensor(arr)
    assert arr.itemsize == b.itemsize


def test_tensor_strides():
    """`strides` matches numpy's byte strides."""
    arr = np.ones((3, 4, 5, 6))
    b = ms.Tensor(arr)
    assert arr.strides == b.strides


def test_tensor_nbytes():
    """`nbytes` matches numpy's total byte count."""
    arr = np.ones((3, 4, 5, 6))
    b = ms.Tensor(arr)
    assert arr.nbytes == b.nbytes


def test_dtype():
    """An int32 numpy array produces an ms.int32 tensor."""
    a = ms.Tensor(np.ones((2, 3), dtype=np.int32))
    assert a.dtype == ms.int32


def test_asnumpy():
    """asnumpy() after set_dtype(int32) yields the same values as the source array."""
    npd = np.ones((2, 3))
    a = ms.Tensor(npd)
    a.set_dtype(ms.int32)
    # Compare element-wise; comparing the results of `.all()` would only
    # compare two scalar truth values and miss wrong element values.
    assert np.array_equal(a.asnumpy(), npd)


def test_initializer_asnumpy():
    """A 'one' initializer materializes to an all-ones array via asnumpy()."""
    npd = np.ones((2, 3))
    a = init.initializer('one', [2, 3], ms.int32)
    # Element-wise comparison, for the same reason as in test_asnumpy.
    assert np.array_equal(a.asnumpy(), npd)


def test_print():
    """Printing a tensor whose dtype was changed does not raise."""
    a = ms.Tensor(np.ones((2, 3)))
    a.set_dtype(ms.int32)
    print(a)


def test_float():
    """An explicit float16 dtype overrides the numpy source dtype."""
    a = ms.Tensor(np.ones((2, 3)), ms.float16)
    assert a.dtype == ms.float16


def test_tensor_method_sub():
    """test_tensor_method_sub: compile `x - y` (broadcast) followed by transpose()."""

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.sub = P.Sub()

        def construct(self, x, y):
            out = x - y
            return out.transpose()

    net = Net()

    x = ms.Tensor(np.ones([5, 3], np.float32))
    y = ms.Tensor(np.ones([8, 5, 3], np.float32))
    _cell_graph_executor.compile(net, x, y)


def test_tensor_method_mul():
    """test_tensor_method_mul: compile `x * (-y)` (broadcast) followed by transpose()."""

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            out = x * (-y)
            return out.transpose()

    net = Net()

    x = ms.Tensor(np.ones([5, 3], np.float32))
    y = ms.Tensor(np.ones([8, 5, 3], np.float32))
    _cell_graph_executor.compile(net, x, y)


def test_tensor_method_div():
    """test_tensor_method_div: compile `x / y` (broadcast) followed by transpose()."""

    class Net(Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            out = x / y
            return out.transpose()

    net = Net()

    x = ms.Tensor(np.ones([5, 3], np.float32))
    y = ms.Tensor(np.ones([8, 5, 3], np.float32))
    _cell_graph_executor.compile(net, x, y)