# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test Tensor properties (ndim, nbytes, size, strides, itemsize) in graph mode."""

import numpy as np

import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)

# Shape shared by every tensor under test: 2 * 3 * 4 * 5 = 120 elements.
_SHAPE = (2, 3, 4, 5)


def _random_tensor(dtype):
    """Return a Tensor of shape ``_SHAPE`` built from ``np.random.random`` with ``dtype``.

    Only shape/dtype-derived properties are checked by the tests below, so the
    actual sample values are irrelevant.
    """
    return Tensor(np.random.random(_SHAPE), dtype=dtype)


def test_ndim():
    """``Tensor.ndim`` read inside ``construct`` equals the rank of ``_SHAPE``."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = _random_tensor(mstype.float32)

        def construct(self):
            return self.value.ndim

    result = Net()()
    assert result == len(_SHAPE)


def test_nbytes():
    """``Tensor.nbytes`` is element count times the float32 item size."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = _random_tensor(mstype.float32)

        def construct(self):
            return self.value.nbytes

    result = Net()()
    # 120 elements * 4 bytes per float32 element.
    assert result == 120 * 4


def test_size():
    """``Tensor.size`` is the total number of elements."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = _random_tensor(mstype.float32)

        def construct(self):
            return self.value.size

    result = Net()()
    assert result == 2 * 3 * 4 * 5


def test_strides():
    """``Tensor.strides`` matches row-major byte strides for a float32 tensor."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            self.value = _random_tensor(mstype.float32)

        def construct(self):
            return self.value.strides

    result = Net()()
    # Each stride is the product of the trailing dims times the 4-byte item size.
    expected = (3 * 4 * 5 * 4, 4 * 5 * 4, 5 * 4, 4)
    assert result == expected


def test_itemsize():
    """``Tensor.itemsize`` reports 8/4/1 bytes for float64/int32/bool_."""
    class Net(nn.Cell):
        def __init__(self):
            super().__init__()
            # Order of construction kept stable: value1, value2, value3.
            self.value1 = _random_tensor(mstype.float64)
            self.value2 = _random_tensor(mstype.int32)
            self.value3 = _random_tensor(mstype.bool_)

        def construct(self):
            return (self.value1.itemsize, self.value2.itemsize, self.value3.itemsize)

    result = Net()()
    assert result == (8, 4, 1)