# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Sharding-strategy tests for the GatherNd operator under auto-parallel modes."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset that yields the same ``(predict, label)`` pair ``length`` times."""

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once `length` items have been produced; `reset` re-arms the iterator.
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the iterator so the dataset can be traversed again."""
        self.index = 0


class Net(Cell):
    """Network computing ``ReLU(GatherNd(Mul(x, w1), indices))``.

    Args:
        w1_shape: Shape of the all-ones float32 weight multiplied with ``x``.
        indices_shape: Shape of the all-ones int32 indices fed to GatherNd.
        strategy1: Shard strategy for Mul (``None`` leaves it unset).
        strategy2: Shard strategy for GatherNd (``None`` leaves it unset).
        strategy3: Shard strategy for ReLU (``None`` leaves it unset).
    """

    def __init__(self, w1_shape, indices_shape, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(Tensor(np.ones(w1_shape), dtype=ms.float32), "w1")
        self.indices = Tensor(np.ones(indices_shape), dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        # `b` is unused; presumably it is accepted only because the dataset
        # yields two tensors per step - confirm against Model.train.
        out = self.mul(x, self.w1)
        out = self.gathernd(out, self.indices)
        out = self.relu(out)
        return out


class Net2(Cell):
    """Network computing ``GatherNd(Mul(x, w1), indices)``.

    GatherNd is the last operator of the graph. A ReLU primitive is still
    created with ``strategy3`` so the constructor matches ``Net``, but
    ``construct`` never calls it.

    Args:
        w1_shape: Shape of the all-ones float32 weight multiplied with ``x``.
        indices_shape: Shape of the all-ones int32 indices fed to GatherNd.
        strategy1: Shard strategy for Mul (``None`` leaves it unset).
        strategy2: Shard strategy for GatherNd (``None`` leaves it unset).
        strategy3: Shard strategy for the unused ReLU.
    """

    def __init__(self, w1_shape, indices_shape, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(Tensor(np.ones(w1_shape), dtype=ms.float32), "w1")
        self.indices = Tensor(np.ones(indices_shape), dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        # `b` is unused; see the note in Net.construct.
        out = self.mul(x, self.w1)
        out = self.gathernd(out, self.indices)
        return out


class Net3(Cell):
    """Network computing ``Mul(ReLU(GatherNd(x, indices)), w1)``.

    GatherNd is the first operator of the graph and consumes the network
    input directly.

    Args:
        w1_shape: Shape of the all-ones float32 weight multiplied after ReLU.
        indices_shape: Shape of the all-ones int32 indices fed to GatherNd.
        strategy1: Shard strategy for Mul (``None`` leaves it unset).
        strategy2: Shard strategy for GatherNd (``None`` leaves it unset).
        strategy3: Shard strategy for ReLU (``None`` leaves it unset).
    """

    def __init__(self, w1_shape, indices_shape, strategy1=None, strategy2=None, strategy3=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.w1 = Parameter(Tensor(np.ones(w1_shape), dtype=ms.float32), "w1")
        self.indices = Tensor(np.ones(indices_shape), dtype=ms.int32)
        self.gathernd = P.GatherNd().shard(strategy2)
        self.relu = P.ReLU().shard(strategy3)

    def construct(self, x, b):
        # `b` is unused; see the note in Net.construct.
        out = self.gathernd(x, self.indices)
        out = self.relu(out)
        out = self.mul(out, self.w1)
        return out


# full_batch = false
_x = Tensor(np.ones([1, 16, 32]), dtype=ms.float32)
_b = Tensor(np.ones([1, 16, 32]), dtype=ms.float32)


def compile_net(net):
    """Train ``net`` for two epochs on the mock dataset to trigger compilation.

    The auto-parallel context is reset afterwards even when compilation or
    training raises, so a test that expects a failure (for example the
    ``pytest.raises(RuntimeError)`` strategy-error tests) does not leave its
    parallel configuration behind for the tests that follow.

    Args:
        net: The Cell to optimize with Momentum and train.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    dataset = Dataset(_x, _b)
    try:
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt)
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


# ---------------------------------------------------------------------------
# Semi-auto parallel, indices sharded 8-way on the first dimension only.
# The rank of strategy3 follows the GatherNd output rank, which shrinks as the
# last dimension of `indices_shape` grows (1 -> 5-D, 2 -> 4-D, 3 -> 3-D here).
# ---------------------------------------------------------------------------
def test_gathernd_data_parallel():
    """Net, indices last dim 1, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 1]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel2():
    """Net, indices last dim 2, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 2]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel3():
    """Net, indices last dim 3, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel4():
    """Net2 (GatherNd last), indices last dim 1, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 1]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1, 1, 1),)
    net = Net2(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel5():
    """Net2 (GatherNd last), indices last dim 2, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 2]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1, 1),)
    net = Net2(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel6():
    """Net2 (GatherNd last), indices last dim 3, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1),)
    net = Net2(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel7():
    """Net3 (GatherNd first), indices last dim 1, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 4, 2, 16, 32]
    indices_shape = [8, 4, 2, 1]
    strategy1 = ((8, 1, 1, 1, 1), (8, 1, 1, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1, 1, 1),)
    net = Net3(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel8():
    """Net3 (GatherNd first), indices last dim 2, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 4, 2, 32]
    indices_shape = [8, 4, 2, 2]
    strategy1 = ((8, 1, 1, 1), (8, 1, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1, 1),)
    net = Net3(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_data_parallel9():
    """Net3 (GatherNd first), indices last dim 3, indices split 8-way on dim 0."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 4, 2]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (8, 1, 1, 1))
    strategy3 = ((8, 1, 1),)
    net = Net3(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


# ---------------------------------------------------------------------------
# Semi-auto parallel, indices sharded 2 x 2 x 2 over the first three dimensions.
# ---------------------------------------------------------------------------
def test_gathernd_model_parallel():
    """Net, indices last dim 1, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 1]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel2():
    """Net, indices last dim 2, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 2]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel3():
    """Net, indices last dim 3, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel4():
    """Net2 (GatherNd last), indices last dim 1, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 1]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1, 1, 1),)
    net = Net2(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel5():
    """Net2 (GatherNd last), indices last dim 2, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 2]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1, 1),)
    net = Net2(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel6():
    """Net2 (GatherNd last), indices last dim 3, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1),)
    net = Net2(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel7():
    """Net3 (GatherNd first), indices last dim 1, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 4, 2, 16, 32]
    indices_shape = [8, 4, 2, 1]
    strategy1 = ((8, 1, 1, 1, 1), (8, 1, 1, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1, 1, 1),)
    net = Net3(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel8():
    """Net3 (GatherNd first), indices last dim 2, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 4, 2, 32]
    indices_shape = [8, 4, 2, 2]
    strategy1 = ((8, 1, 1, 1), (8, 1, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1, 1),)
    net = Net3(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


def test_gathernd_model_parallel9():
    """Net3 (GatherNd first), indices last dim 3, indices split (2, 2, 2, 1)."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 4, 2]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (2, 2, 2, 1))
    strategy3 = ((8, 1, 1),)
    net = Net3(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    compile_net(net)


# ---------------------------------------------------------------------------
# Auto parallel: no explicit strategies are passed to the network.
# ---------------------------------------------------------------------------
def test_gathernd_auto_parallel():
    """Net under auto_parallel, indices last dim 1."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 1]
    net = Net(w1_shape, indices_shape)
    compile_net(net)


def test_gathernd_auto_parallel2():
    """Net under auto_parallel, indices last dim 2."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 2]
    net = Net(w1_shape, indices_shape)
    compile_net(net)


def test_gathernd_auto_parallel3():
    """Net under auto_parallel, indices last dim 3."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    net = Net(w1_shape, indices_shape)
    compile_net(net)


# ---------------------------------------------------------------------------
# Strategies that compilation is expected to reject with RuntimeError.
# ---------------------------------------------------------------------------
def test_gathernd_strategy_error():
    """A GatherNd strategy that splits the first dim of the input tensor must raise."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((2, 1, 1), (1, 2, 2, 1))
    strategy3 = ((8, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_gathernd_strategy_error2():
    """A GatherNd strategy that splits the last dim of the indices must raise."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    w1_shape = [8, 16, 32]
    indices_shape = [8, 4, 2, 3]
    strategy1 = ((8, 1, 1), (8, 1, 1))
    strategy2 = ((1, 1, 1), (1, 2, 2, 2))
    strategy3 = ((8, 1, 1),)
    net = Net(w1_shape, indices_shape, strategy1, strategy2, strategy3)
    with pytest.raises(RuntimeError):
        compile_net(net)