# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import numpy as np

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset that yields the same ``(predict, label)`` pair ``length`` times.

    Args:
        predict: First element of every sample.
        label: Second element of every sample.
        length: Number of samples per pass; also passed to ``MindData`` as ``size``.
    """

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the iterator so the next pass yields ``length`` samples again."""
        self.index = 0


class Net(Cell):
    """StridedSlice on a sharded parameter followed by two element-wise Muls.

    Args:
        weight: Initial value of parameter ``w1``; sliced in ``construct``.
        w2: Initial value of parameter ``w2``; multiplied into the output.
        begin: ``begin`` argument passed to ``StridedSlice``.
        end: ``end`` argument passed to ``StridedSlice``.
        strides: ``strides`` argument passed to ``StridedSlice``.
        strategy1: Shard strategy for the first ``Mul`` (``x * slice``).
        strategy2: Shard strategy for ``StridedSlice``.
        mask: ``begin_mask`` for ``StridedSlice``.
    """

    def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, mask=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
        self.weight = Parameter(weight, "w1")
        self.mul2 = P.Mul()
        self.weight2 = Parameter(w2, "w2")
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, x, b):
        # ``b`` is unused here; Dataset yields (predict, label) pairs and both
        # are presumably fed positionally to construct.
        out = self.strided_slice(self.weight, self.begin, self.end, self.strides)
        out = self.mul(x, out)
        out = self.mul2(out, self.weight2)
        return out


# Per-step inputs and initial parameter values.
# NOTE(review): _x has 16 rows while the sliced weight has 128; this presumably
# relies on semi_auto_parallel treating the dataset batch as per-device
# (16 * device_num=8 = 128) — confirm.
_x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)


def clean_all_ckpt_files(folder_path):
    """Delete every ``.ckpt`` and ``.meta`` file directly inside ``folder_path``.

    Does nothing if ``folder_path`` does not exist. Does not recurse, and the
    folder itself is kept.
    """
    if not os.path.exists(folder_path):
        return
    for file_name in os.listdir(folder_path):
        if file_name.endswith(('.ckpt', '.meta')):
            os.remove(os.path.join(folder_path, file_name))


def compile_net(net):
    """Train ``net`` briefly with checkpointing and check the parameter-merge dict.

    Trains for 2 epochs on ``Dataset(_x, _b)`` with Momentum, saving checkpoints
    under ``./parallel_ckpt``. Then asserts that the train network's
    ``parallel_parameter_merge_net_dict`` has 4 entries.

    Checkpoint files are removed and the auto-parallel context is reset even if
    training or the assertion fails, so later tests start from a clean state.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    ckpt_path = "./parallel_ckpt"
    try:
        dataset = Dataset(_x, _b)
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt)
        # Keep at most one checkpoint file on disk.
        ckpt_config = CheckpointConfig(keep_checkpoint_max=1)
        ckpt_cb = ModelCheckpoint(prefix="parallel", directory=ckpt_path, config=ckpt_config)
        model.train(epoch_size, dataset, dataset_sink_mode=False, callbacks=[ckpt_cb])
        # NOTE(review): the expected count of 4 is not derivable from this file
        # (presumably w1, w2 and their Momentum states) — confirm.
        assert len(model._train_network.parallel_parameter_merge_net_dict) == 4
    finally:
        # Always clean up: leftover checkpoints or a stale global parallel
        # context would otherwise leak into subsequent tests.
        clean_all_ckpt_files(ckpt_path)
        context.reset_auto_parallel_context()


def test_stridedslice_parameter():
    """StridedSlice on a sharded parameter under semi_auto_parallel on 8 devices.

    Slices w1 [256, 64, 32] with begin (0, 0, 0) / end (128, 64, 32) using
    strategy ((1, 4, 2),). The result is then multiplied into the input using
    strategy ((1, 4, 1), (1, 4, 2)).
    """
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2)
    compile_net(net)