# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""test_tuple_slice: slicing a tuple of Tensors inside ``Cell.construct``.

``test_compile`` runs valid slices (int and Tensor-valued start/stop/step,
positive and negative) through the forward-compile pipeline;
``test_check_exception`` checks that invalid slices/indexes raise the
exception type listed next to each network.
"""
import numpy as np

import mindspore.ops.operations as P
from mindspore import Tensor
from mindspore.nn import Cell
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \
    import pipeline_for_compile_forward_ge_graph_for_case_by_case_config
from ....mindspore_test_framework.pipeline.forward.verify_exception \
    import pipeline_for_verify_exception_for_case_by_case_config


class NetWork_1(Cell):
    """Sum (via AddN) several tuple slices that use non-negative bounds.

    Bounds mix Python ints with Tensor values: a scalar int Tensor, a
    one-element int Tensor, a one-element bool Tensor, and a ``True`` step.
    """

    def __init__(self):
        super(NetWork_1, self).__init__()
        self.addN = P.AddN()
        self.index_0 = Tensor(3)
        self.index_1 = Tensor([5])
        # Bool Tensor used as a slice start; presumably treated as 1 -- TODO confirm.
        self.index_3 = Tensor([True])

    def construct(self, tensor_tuple):
        # Assuming Python slice semantics, for the six-element test input these
        # select: all 6, the first 3, elements 1..5, and elements 2..4.
        tensor_tuple_slice0 = tensor_tuple[:]
        tensor_tuple_slice1 = tensor_tuple[:self.index_0]
        tensor_tuple_slice2 = tensor_tuple[self.index_3:]
        tensor_tuple_slice3 = tensor_tuple[2:self.index_1:True]
        sum0 = self.addN(tensor_tuple_slice0)
        sum1 = self.addN(tensor_tuple_slice1)
        sum2 = self.addN(tensor_tuple_slice2)
        sum3 = self.addN(tensor_tuple_slice3)
        ret = sum0 + sum1 + sum2 + sum3
        return ret


class NetWork_2(Cell):
    """Sum (via AddN) several tuple slices that use negative bounds/steps.

    Covers a Tensor-valued negative step, a Tensor-valued negative start,
    and constant negative start/stop/step combinations.
    """

    def __init__(self):
        super(NetWork_2, self).__init__()
        self.addN = P.AddN()
        self.step = Tensor([-1])
        self.index_0 = Tensor(-6)

    def construct(self, tensor_tuple):
        # Assuming Python slice semantics, for the six-element test input these
        # select indices: 5..0, 5..0, (5, 4, 3), (0, 1, 2) and (5, 3, 1).
        tensor_tuple_slice0 = tensor_tuple[::self.step]
        tensor_tuple_slice1 = tensor_tuple[-1::-1]
        tensor_tuple_slice2 = tensor_tuple[:-4:-1]
        tensor_tuple_slice3 = tensor_tuple[self.index_0:3]
        tensor_tuple_slice4 = tensor_tuple[-1:-6:-2]
        sum0 = self.addN(tensor_tuple_slice0)
        sum1 = self.addN(tensor_tuple_slice1)
        sum2 = self.addN(tensor_tuple_slice2)
        sum3 = self.addN(tensor_tuple_slice3)
        sum4 = self.addN(tensor_tuple_slice4)
        ret = sum0 + sum1 + sum2 + sum3 + sum4
        return ret


class NetWorkSliceStepZero(Cell):
    """Slice a tuple with a step of 0; the test case expects ValueError."""

    def __init__(self):
        super(NetWorkSliceStepZero, self).__init__()

    def construct(self, tensor_tuple):
        tensor_tuple_slice = tensor_tuple[0:3:0]
        return tensor_tuple_slice


class NetWorkOutOfBounds(Cell):
    """Index a tuple past its end (index 100); the test case expects IndexError."""

    def __init__(self):
        super(NetWorkOutOfBounds, self).__init__()

    def construct(self, tensor_tuple):
        return tensor_tuple[100]


class NetWorkTensorSizeGreaterThanTwo(Cell):
    """Use a multi-element Tensor as a slice stop; the test case expects TypeError."""

    def __init__(self):
        super(NetWorkTensorSizeGreaterThanTwo, self).__init__()
        # NOTE(review): this Tensor has exactly two elements, so the case really
        # exercises "more than one element"; the class name is kept for compatibility.
        self.index_0 = Tensor([2, 3])

    def construct(self, tensor_tuple):
        return tensor_tuple[1:self.index_0]


class NetWorkTensorDtypeFloat(Cell):
    """Use a float Tensor as a slice stop; the test case expects TypeError."""

    def __init__(self):
        super(NetWorkTensorDtypeFloat, self).__init__()
        self.index_0 = Tensor([2.1])

    def construct(self, tensor_tuple):
        return tensor_tuple[1:self.index_0]


# Every test input element is an int32 Tensor of this shape.
_INPUT_SHAPE = (2, 3, 4)


def _tensor_tuple(*fill_values):
    """Return a tuple of new int32 Tensors of ``_INPUT_SHAPE``, one per fill value."""
    return tuple(Tensor(np.full(_INPUT_SHAPE, value, dtype=np.int32))
                 for value in fill_values)


# Each case builds its own fresh input tuple so no Tensor object is shared
# between cases.
test_cases = [
    ('SlicePositive', {
        'block': NetWork_1(),
        'desc_inputs': [_tensor_tuple(1, 0, 1, 1, 0, 1)],
    }),
    ('SliceNegative', {
        'block': NetWork_2(),
        'desc_inputs': [_tensor_tuple(1, 0, 1, 1, 0, 1)],
    }),
]

test_cases_for_verify_exception = [
    ('SliceStepZero', {
        'block': (NetWorkSliceStepZero(), {'exception': ValueError}),
        'desc_inputs': [_tensor_tuple(1, 0, 1)],
    }),
    ('SliceOutOfBounds', {
        'block': (NetWorkOutOfBounds(), {'exception': IndexError}),
        'desc_inputs': [_tensor_tuple(1, 0, 1)],
    }),
    ('SliceTensorSizeGreaterThanTwo', {
        'block': (NetWorkTensorSizeGreaterThanTwo(), {'exception': TypeError}),
        'desc_inputs': [_tensor_tuple(1, 0, 1)],
    }),
    ('SliceTensorDtypeFloat', {
        'block': (NetWorkTensorDtypeFloat(), {'exception': TypeError}),
        'desc_inputs': [_tensor_tuple(1, 0, 1)],
    }),
]


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
def test_compile():
    """Run the valid tuple-slice cases through the forward-compile pipeline."""
    return test_cases


@mindspore_test(pipeline_for_verify_exception_for_case_by_case_config)
def test_check_exception():
    """Check each invalid tuple-slice case raises its configured exception."""
    return test_cases_for_verify_exception