# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Test list element assignment by index inside ``construct`` in graph mode.

The tests cover one-, two- and three-level indexing with positive and
negative indices, on lists built from constants and from network inputs,
plus two gradient runs through a network that assigns into a nested list.
"""

import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import composite as C

# Every network below is executed in graph mode.
context.set_context(mode=context.GRAPH_MODE)


def test_list_index_1D():
    """Replace the first element of a constant nested list using index 0."""
    class Net(nn.Cell):
        def construct(self):
            list_ = [[1], [2, 2], [3, 3, 3]]
            list_[0] = [100]
            return list_

    net = Net()
    out = net()
    # Only the assigned slot changes; the other elements are untouched.
    assert out[0] == [100]
    assert out[1] == [2, 2]
    assert out[2] == [3, 3, 3]


def test_list_neg_index_1D():
    """Replace the first element of a constant nested list using index -3."""
    class Net(nn.Cell):
        def construct(self):
            list_ = [[1], [2, 2], [3, 3, 3]]
            list_[-3] = [100]
            return list_

    net = Net()
    out = net()
    # Index -3 on a three-element list addresses the same slot as index 0.
    assert out[0] == [100]
    assert out[1] == [2, 2]
    assert out[2] == [3, 3, 3]


def test_list_index_2D():
    """Assign to both elements of an inner list using two positive indices."""
    class Net(nn.Cell):
        def construct(self):
            list_ = [[1], [2, 2], [3, 3, 3]]
            list_[1][0] = 200
            list_[1][1] = 201
            return list_

    net = Net()
    out = net()
    assert out[0] == [1]
    assert out[1] == [200, 201]
    assert out[2] == [3, 3, 3]


def test_list_neg_index_2D():
    """Assign to both elements of an inner list using negative inner indices."""
    class Net(nn.Cell):
        def construct(self):
            list_ = [[1], [2, 2], [3, 3, 3]]
            list_[1][-2] = 200
            list_[1][-1] = 201
            return list_

    net = Net()
    out = net()
    assert out[0] == [1]
    assert out[1] == [200, 201]
    assert out[2] == [3, 3, 3]


def test_list_index_3D():
    """Assign to every element of a list nested three levels deep."""
    class Net(nn.Cell):
        def construct(self):
            list_ = [[1], [2, 2], [[3, 3, 3]]]
            list_[2][0][0] = 300
            list_[2][0][1] = 301
            list_[2][0][2] = 302
            return list_

    net = Net()
    out = net()
    assert out[0] == [1]
    assert out[1] == [2, 2]
    assert out[2] == [[300, 301, 302]]


def test_list_neg_index_3D():
    """Assign to a three-level nested list using negative innermost indices."""
    class Net(nn.Cell):
        def construct(self):
            list_ = [[1], [2, 2], [[3, 3, 3]]]
            list_[2][0][-3] = 300
            list_[2][0][-2] = 301
            list_[2][0][-1] = 302
            return list_

    net = Net()
    out = net()
    assert out[0] == [1]
    assert out[1] == [2, 2]
    assert out[2] == [[300, 301, 302]]


def test_list_index_1D_parameter():
    """Overwrite the single element of a list built from the network input.

    The network is only run; its output is not checked.
    """
    class Net(nn.Cell):
        def construct(self, x):
            list_ = [x]
            list_[0] = 100
            return list_

    net = Net()
    net(Tensor(0))


def test_list_index_2D_parameter():
    """Overwrite one element of a two-level list built from the network input.

    The network is only run; its output is not checked.
    """
    class Net(nn.Cell):
        def construct(self, x):
            list_ = [[x, x]]
            list_[0][0] = 100
            return list_

    net = Net()
    net(Tensor(0))


def test_list_index_3D_parameter():
    """Overwrite one element of a three-level list built from the network input.

    The network is only run; its output is not checked.
    """
    class Net(nn.Cell):
        def construct(self, x):
            list_ = [[[x, x]]]
            list_[0][0][0] = 100
            return list_

    net = Net()
    net(Tensor(0))


def test_const_list_index_3D_bprop():
    """Run a gradient through a net that stores its input into a constant list.

    ``Net`` keeps a nested list of constants as an attribute, writes the input
    tensor into slot ``[2][0][1]`` and returns ReLU of that same slot.
    ``GradNet`` wraps it with ``GradOperation(get_all=True, sens_param=True)``.
    The gradient result is not checked.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.value = [[1], [2, 2], [[3, 3], [3, 3]]]
            self.relu = P.ReLU()

        def construct(self, input_x):
            list_x = self.value
            list_x[2][0][1] = input_x
            return self.relu(list_x[2][0][1])

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

        def construct(self, x, sens):
            return self.grad_all_with_sens(self.net)(x, sens)

    net = Net()
    grad_net = GradNet(net)
    x = Tensor(np.arange(2 * 3).reshape(2, 3))
    sens = Tensor(np.arange(2 * 3).reshape(2, 3))
    grad_net(x, sens)


def test_parameter_list_index_3D_bprop():
    """Run a gradient through a net that stores an input into a list of inputs.

    ``Net`` builds a nested list from ``x``, writes ``value`` into slot
    ``[2][0][1]`` and returns ReLU of that same slot. ``GradNet`` wraps it
    with ``GradOperation(get_all=True, sens_param=True)``. The gradient
    result is not checked.
    """
    class Net(nn.Cell):
        def __init__(self):
            super(Net, self).__init__()
            self.relu = P.ReLU()

        def construct(self, x, value):
            list_value = [[x], [x, x], [[x, x], [x, x]]]
            list_value[2][0][1] = value
            return self.relu(list_value[2][0][1])

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.grad_all_with_sens = C.GradOperation(get_all=True, sens_param=True)

        def construct(self, x, value, sens):
            return self.grad_all_with_sens(self.net)(x, value, sens)

    net = Net()
    grad_net = GradNet(net)
    x = Tensor(np.arange(2 * 3).reshape(2, 3))
    value = Tensor(np.ones((2, 3), np.int64))
    sens = Tensor(np.arange(2 * 3).reshape(2, 3))
    grad_net(x, value, sens)