• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import numpy as np
16import mindspore as ms
17import mindspore.nn as nn
18from mindspore import context
19from mindspore import Tensor
20from mindspore.ops import operations as P
21from mindspore.common.parameter import Parameter
22from mindspore.common.initializer import initializer
23from mindspore.train.model import Model
24from mindspore.nn.wrap.cell_wrapper import PipelineCell
25
26
class DatasetLenet:
    """Minimal in-memory stand-in for a dataset, used to drive ``Model.train``.

    Every step yields the same ``(data, label)`` pair. The metadata getters
    report fixed values (dataset size 32, repeat count 1, batch size 32)
    regardless of ``length``.

    NOTE(review): ``__init__`` starts ``index`` at 1 but ``reset`` rewinds it
    to 0. The first pass therefore yields ``length - 1`` items, and every pass
    after a ``reset`` yields ``length`` items. Kept as-is; confirm whether
    this is intended.
    """

    def __init__(self, data, label, length=3):
        self.data, self.label = data, label
        self.length = length
        self.index = 1

    def __iter__(self):
        # The object is its own iterator.
        return self

    def __next__(self):
        exhausted = self.index >= self.length
        if exhausted:
            raise StopIteration
        self.index += 1
        return self.data, self.label

    def reset(self):
        """Rewind the iteration counter to zero."""
        self.index = 0

    def get_dataset_size(self):
        """Return the fixed dataset size (32)."""
        return 32

    def get_repeat_count(self):
        """Return the fixed repeat count (1)."""
        return 1

    def get_batch_size(self):
        """Return the fixed batch size (32)."""
        return 32

    def create_tuple_iterator(self, num_epochs=1, do_copy=True):
        """Return ``self`` as the iterator; ``num_epochs`` and ``do_copy`` are ignored."""
        return self
57
58
class MatMulCell(nn.Cell):
    """Two chained sharded MatMuls against zero-initialized 64x64 parameters.

    ``construct`` returns the product together with ``self.param`` itself, so
    a caller can feed the same parameter into another cell.

    Args:
        strategy1: sharding strategy for the MatMul against ``param``.
        strategy2: sharding strategy for the MatMul against ``param1``.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()

        def zeros_64x64(param_name):
            # A fresh shape list per call, matching the original literals.
            return Parameter(initializer("zeros", [64, 64]), name=param_name)

        self.param = zeros_64x64("param")
        self.param1 = zeros_64x64("param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x):
        hidden = self.matmul(x, self.param)
        projected = self.matmul1(hidden, self.param1)
        # Hand the first weight back to the caller alongside the output.
        return projected, self.param
71
72
class MatMulCell2(nn.Cell):
    """Multiply by a caller-supplied weight, then by this cell's own ``param1``.

    Args:
        strategy1: sharding strategy for the MatMul against the incoming ``param``.
        strategy2: sharding strategy for the MatMul against ``self.param1``.
    """

    def __init__(self, strategy1, strategy2):
        super().__init__()
        # Only one weight is owned here; the other arrives through construct().
        self.param1 = Parameter(initializer("zeros", [64, 64]), name="param1")
        self.matmul = P.MatMul().shard(strategy1)
        self.matmul1 = P.MatMul().shard(strategy2)

    def construct(self, x, param):
        # In Net, ``param`` is the parameter returned by MatMulCell.construct.
        shared_product = self.matmul(x, param)
        return self.matmul1(shared_product, self.param1)
84
85
class Net(nn.Cell):
    """Two-stage network: ``MatMulCell`` on pipeline stage 0 feeds ``MatMulCell2`` on stage 1.

    The first cell's ``param`` is passed into the second cell, so one
    parameter is used across both pipeline stages.

    Args:
        strategy1: sharding strategy for the first MatMul of each sub-cell.
        strategy2: sharding strategy for the second MatMul of each sub-cell.
        param: unused; kept for signature compatibility.
    """

    def __init__(self, strategy1, strategy2, param=None):
        super().__init__()
        head = MatMulCell(strategy1, strategy2)
        self.cell1 = head
        head.pipeline_stage = 0
        tail = MatMulCell2(strategy1, strategy2)
        self.cell2 = tail
        tail.pipeline_stage = 1

    def construct(self, x, label):
        # ``label`` is accepted but not used by this network.
        stage0_out, shared_param = self.cell1(x)
        return self.cell2(stage0_out, shared_param)
98
99
def _run_pipeline_split(global_rank, stage_cell_name):
    """Train the two-stage pipeline ``Net`` as seen from one device rank.

    Args:
        global_rank: rank of the simulated device. With 8 devices and 2
            pipeline stages, rank 0 presumably falls in stage 0 and rank 4 in
            stage 1 (TODO confirm against the partitioning rule).
        stage_cell_name: attribute of ``Net`` (``"cell1"`` or ``"cell2"``) whose
            trainable parameters are handed to the optimizer.

    The auto-parallel context is process-global state. It is reset in a
    ``finally`` so that one test's configuration cannot leak into the next,
    even when training raises.
    """
    try:
        context.set_auto_parallel_context(device_num=8, global_rank=global_rank, pipeline_stages=2)
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        data = Tensor(np.ones([32, 64]), dtype=ms.float32)
        label = Tensor(np.ones([64, 64]), dtype=ms.float32)
        strategy1 = ((4, 1), (1, 1))
        strategy2 = ((2, 1), (1, 1))
        # The second PipelineCell argument (4) is presumably the micro-batch count — confirm.
        net = PipelineCell(Net(strategy1, strategy2), 4)
        params = getattr(net.network, stage_cell_name).trainable_params()
        dataset = DatasetLenet(data, label, 3)
        optimizer = nn.Lamb(params, learning_rate=0.01)
        model = Model(net, optimizer=optimizer)
        model.train(2, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


def test_pipeline_split_stage0():
    """Rank 0 optimizes the parameters of the stage-0 cell (``cell1``)."""
    _run_pipeline_split(global_rank=0, stage_cell_name="cell1")


def test_pipeline_split_stage1():
    """Rank 4 optimizes the parameters of the stage-1 cell (``cell2``)."""
    _run_pipeline_split(global_rank=4, stage_cell_name="cell2")
128