# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Semi-auto-parallel pipeline tests where a parameter crosses a stage boundary.

``MatMulCell`` (pipeline stage 0) returns its own ``param`` alongside its
output, and ``MatMulCell2`` (pipeline stage 1) multiplies by that parameter,
so each test trains a two-stage network whose stage-1 computation reads a
parameter owned by the stage-0 cell.
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer
from mindspore.train.model import Model
from mindspore.nn.wrap.cell_wrapper import PipelineCell


class DatasetLenet:
    """Minimal in-memory stand-in for a dataset object passed to ``Model.train``.

    Iterating yields the same ``(data, label)`` pair ``length`` times per
    epoch. The size/batch queries return fixed values (32) that are not
    derived from ``data``.
    """

    def __init__(self, data, label, length=3):
        self.data = data
        self.label = label
        # Start from the same value ``reset`` uses so that every epoch,
        # including the first, yields exactly ``length`` batches.
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    def reset(self):
        """Rewind the iterator so the next epoch yields ``length`` batches again."""
        self.index = 0

    def get_dataset_size(self):
        return 32

    def get_repeat_count(self):
        return 1

    def get_batch_size(self):
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        """Return ``self``; ``num_epochs`` and ``do_copy`` are accepted but ignored."""
        return self


class MatMulCell(nn.Cell):
    """Two chained sharded MatMuls that also expose their first weight.

    Returns ``(out, self.param)`` so a downstream cell can consume ``param``
    directly.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param = Parameter(initializer("zeros", [64, 64]), name="param")
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x):
        out = self.matmul(x, self.param)
        out = self.matmul1(out, self.param1)
        # Hand the weight itself to the caller; ``Net`` feeds it to the next stage.
        return out, self.param


class MatMulCell2(nn.Cell):
    """Two chained sharded MatMuls whose first weight is supplied by the caller."""

    def __init__(self, strategy1, strategy2):
        super().__init__()
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x, param):
        out = self.matmul(x, param)
        out = self.matmul1(out, self.param1)
        return out


class Net(nn.Cell):
    """Two-stage network: ``cell1`` in pipeline stage 0, ``cell2`` in stage 1.

    ``cell1``'s ``param`` is passed into ``cell2``, so stage 1 uses a
    parameter that lives in stage 0. ``param`` (constructor) and ``label``
    (``construct``) are unused and kept only for signature compatibility.
    """

    def __init__(self, strategy1, strategy2, param=None):
        super().__init__()
        self.cell1 = MatMulCell(strategy1, strategy2)
        self.cell1.pipeline_stage = 0
        self.cell2 = MatMulCell2(strategy1, strategy2)
        self.cell2.pipeline_stage = 1

    def construct(self, x, label):
        out, param = self.cell1(x)
        out = self.cell2(out, param)
        return out


def _run_pipeline_split(global_rank, stage_cell_name):
    """Train the pipelined ``Net`` as ``global_rank``, optimizing one cell's params.

    Args:
        global_rank: Rank this process runs as, out of ``device_num=8``.
        stage_cell_name: Attribute of ``Net`` (``"cell1"`` or ``"cell2"``)
            whose trainable parameters are given to the optimizer.

    The auto-parallel context is reset afterwards, even if training raises,
    so the pipeline settings do not leak into later tests in the process.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=global_rank, pipeline_stages=2)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        data = Tensor(np.ones([32, 64]), dtype=ms.float32)
        label = Tensor(np.ones([64, 64]), dtype=ms.float32)
        strategy1 = ((4, 1), (1, 1))
        strategy2 = ((2, 1), (1, 1))
        # 4 is PipelineCell's micro-batch count for each training step.
        net = PipelineCell(Net(strategy1, strategy2), 4)
        params = getattr(net.network, stage_cell_name).trainable_params()
        dataset = DatasetLenet(data, label, 3)
        optimizer = nn.Lamb(params, learning_rate=0.01)
        model = Model(net, optimizer=optimizer)
        model.train(2, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


def test_pipeline_split_stage0():
    """Run as rank 0 and optimize the stage-0 cell (``cell1``)."""
    _run_pipeline_split(global_rank=0, stage_cell_name="cell1")


def test_pipeline_split_stage1():
    """Run as rank 4 and optimize the stage-1 cell (``cell2``)."""
    _run_pipeline_split(global_rank=4, stage_cell_name="cell2")