# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test enumerate

Exercises the builtin ``enumerate`` inside ``nn.Cell.construct`` compiled in
GRAPH_MODE, over lists, tuples and Tensors (both as constant attributes and as
network inputs), with and without an explicit ``start`` argument.
"""
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context

context.set_context(mode=context.GRAPH_MODE)


def _assert_tensors_equal(actual, expected):
    """Assert that a tuple of Tensors matches ``expected`` element-wise.

    Args:
        actual: tuple of ``Tensor`` objects returned by a network.
        expected: sequence of numpy arrays (iterating a 2-D ndarray yields
            its rows) holding the expected values, in order.
    """
    assert len(actual) == len(expected)
    for got, want in zip(actual, expected):
        assert np.array_equal(got.asnumpy(), want)


def test_enumerate_list_const():
    """Unpack (index, value) pairs from a constant list attribute."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = [11, 22, 33, 44]

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i, j in enumerate(self.value):
                index_sum += i
                value_sum += j
            return index_sum, value_sum

    net = Net()
    # indices 0+1+2+3 == 6, values 11+22+33+44 == 110
    assert net() == (6, 110)


def test_enumerate_tuple_const():
    """Unpack (index, value) pairs from a constant tuple attribute."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i, j in enumerate(self.value):
                index_sum += i
                value_sum += j
            return index_sum, value_sum

    net = Net()
    assert net() == (6, 110)


def test_enumerate_tensor_const():
    """Return the ``enumerate`` result of a constant 2x3 Tensor directly."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(np.arange(2 * 3).reshape(2, 3))

        def construct(self):
            return enumerate(self.value)

    net = Net()
    ret = net()
    # Expected: one (index, row) pair per row of the 2x3 tensor, mirroring
    # Python iteration over a numpy array.
    expected = np.arange(2 * 3).reshape(2, 3)
    assert len(ret) == 2
    assert tuple(pair[0] for pair in ret) == (0, 1)
    _assert_tensors_equal(tuple(pair[1] for pair in ret), expected)


def test_enumerate_list_parameter():
    """Enumerate a list built from two Tensor inputs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = [x, y]
            ret = ()
            for i, j in enumerate(value):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(4))
    net = Net()
    index_sum, ret = net(x, x)
    assert index_sum == 1  # indices 0 + 1
    _assert_tensors_equal(ret, (np.arange(4), np.arange(4)))


def test_enumerate_tuple_parameter():
    """Enumerate a tuple built from two Tensor inputs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = (x, y)
            ret = ()
            for i, j in enumerate(value):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(4))
    net = Net()
    index_sum, ret = net(x, x)
    assert index_sum == 1
    _assert_tensors_equal(ret, (np.arange(4), np.arange(4)))


def test_enumerate_tensor_parameter():
    """Enumerate a 2x3 Tensor input, unpacking (index, row) pairs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            index_sum = 0
            ret = ()
            for i, j in enumerate(x):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(2 * 3).reshape(2, 3))
    net = Net()
    index_sum, ret = net(x)
    assert index_sum == 1  # two rows: indices 0 + 1
    _assert_tensors_equal(ret, np.arange(2 * 3).reshape(2, 3))


def test_enumerate_tuple_const_1():
    """Index into each (index, value) pair instead of unpacking it."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i in enumerate(self.value):
                index_sum += i[0]
                value_sum += i[1]
            return index_sum, value_sum

    net = Net()
    assert net() == (6, 110)


def test_enumerate_tensor_const_1():
    """Index into (index, row) pairs of a constant 2x3 Tensor."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(np.arange(2*3).reshape(2, 3))

        def construct(self):
            index_sum = 0
            ret = ()
            for i in enumerate(self.value):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    net = Net()
    index_sum, ret = net()
    assert index_sum == 1
    _assert_tensors_equal(ret, np.arange(2 * 3).reshape(2, 3))


def test_enumerate_tuple_parameter_1():
    """Index into (index, value) pairs of a tuple of Tensor inputs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = (x, y)
            ret = ()
            for i in enumerate(value):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(4))
    net = Net()
    index_sum, ret = net(x, x)
    assert index_sum == 1
    _assert_tensors_equal(ret, (np.arange(4), np.arange(4)))


def test_enumerate_tensor_parameter_1():
    """Index into (index, row) pairs of a 2x3 Tensor input."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            index_sum = 0
            ret = ()
            for i in enumerate(x):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(2 * 3).reshape(2, 3))
    net = Net()
    index_sum, ret = net(x)
    assert index_sum == 1
    _assert_tensors_equal(ret, np.arange(2 * 3).reshape(2, 3))


def test_enumerate_tuple_const_2():
    """Positional ``start=1`` over a constant tuple shifts every index by 1."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = (11, 22, 33, 44)

        def construct(self):
            index_sum = 0
            value_sum = 0
            for i in enumerate(self.value, 1):
                index_sum += i[0]
                value_sum += i[1]
            return index_sum, value_sum

    net = Net()
    # indices 1+2+3+4 == 10; values are unaffected by start
    assert net() == (10, 110)


def test_enumerate_tensor_const_2():
    """Positional ``start=1`` over a constant 2x3 Tensor."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = Tensor(np.arange(2 * 3).reshape(2, 3))

        def construct(self):
            index_sum = 0
            ret = ()
            for i in enumerate(self.value, 1):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    net = Net()
    index_sum, ret = net()
    assert index_sum == 3  # indices 1 + 2
    _assert_tensors_equal(ret, np.arange(2 * 3).reshape(2, 3))


def test_enumerate_tuple_parameter_2():
    """Positional ``start=1`` over a tuple of Tensor inputs."""
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x, y):
            index_sum = 0
            value = (x, y)
            ret = ()
            for i in enumerate(value, 1):
                index_sum += i[0]
                ret += (i[1],)
            return index_sum, ret

    x = Tensor(np.arange(4))
    net = Net()
    index_sum, ret = net(x, x)
    assert index_sum == 3
    _assert_tensors_equal(ret, (np.arange(4), np.arange(4)))


def test_enumerate_tensor_parameter_2():
    """Positional ``start=1`` over a 2x3 Tensor input.

    Unlike the other ``_2`` cases, this one unpacks ``i, j`` directly, so it
    also covers tuple-unpacking combined with ``start``.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            index_sum = 0
            ret = ()
            for i, j in enumerate(x, 1):
                index_sum += i
                ret += (j,)
            return index_sum, ret

    x = Tensor(np.arange(2 * 3).reshape(2, 3))
    net = Net()
    index_sum, ret = net(x)
    assert index_sum == 3
    _assert_tensors_equal(ret, np.arange(2 * 3).reshape(2, 3))


def test_enumerate_start_type_error():
    """Reject a non-integer ``start`` argument to ``enumerate``.

    Running a network whose ``construct`` passes ``start=1.2`` must raise
    ``TypeError`` whose message contains "For 'enumerate', the 'start'".
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()

        def construct(self, x):
            return enumerate((x, x), start=1.2)

    network = Net()
    data = Tensor(np.arange(60).reshape((3, 4, 5)))
    # The pattern contains no regex metacharacters, so `match` is a plain
    # substring search over str(exception).
    with pytest.raises(TypeError, match="For 'enumerate', the 'start'"):
        network(data)