• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor, Parameter
20from mindspore import context
21from mindspore.common.api import _cell_graph_executor
22from mindspore.context import set_auto_parallel_context, reset_auto_parallel_context
23from mindspore.ops import composite as C
24from mindspore.ops import operations as P
25from tests.ut.python.ops.test_math_ops import VirtualLoss
26
27
# Shared gradient transform used by every GradWrap cell in this module.
# ``get_all=True`` requests gradients with respect to all network inputs;
# the other two flags are written out at their documented default values.
grad_all = C.GradOperation(get_all=True, get_by_list=False, sens_param=False)
29
30
# Model-parallel test (semi_auto_parallel): save stage of the strategy checkpoint.
def test_six_matmul_save():
    """Compile a six-MatMul training graph in semi_auto_parallel mode.

    Every MatMul is given an explicit shard strategy. The auto-parallel
    context is configured with ``strategy_ckpt_save_file="./strategy_stage1.ckpt"``
    and ``group_ckpt_save_file="./group_stage1.ckpt"``; ``test_six_matmul_load``
    reads the same strategy file path.
    """

    class NetWithLoss(nn.Cell):
        """Apply ``VirtualLoss`` to the output of the wrapped network."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6):
            return self.loss(self.network(x1, x6))

    class GradWrap(nn.Cell):
        """Return ``grad_all`` of the wrapped network evaluated at the inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6):
            return grad_all(self.network)(x1, x6)

    class Net(nn.Cell):
        """Six chained MatMuls with an element-wise add of ``weight6`` before the last."""

        def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
            super().__init__()
            # Attribute names (matmulN / weightN) match those used by the load
            # test's network; presumably strategies are matched by these names.
            shard_specs = (strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)
            for index, spec in enumerate(shard_specs, start=1):
                setattr(self, "matmul%d" % index, P.MatMul().shard(spec))
            weight_shapes = (
                ("weight1", [32, 64]),
                ("weight2", [64, 64]),
                ("weight3", [64, 128]),
                ("weight4", [128, 64]),
                ("weight5", [64, 128]),
                ("weight6", [32, 128]),
            )
            for weight_name, shape in weight_shapes:
                setattr(self, weight_name,
                        Parameter(Tensor(np.ones(shape), dtype=ms.float32), name=weight_name))

        def construct(self, x1, x6):
            hidden = self.matmul1(x1, self.weight1)
            hidden = self.matmul2(hidden, self.weight2)
            hidden = self.matmul3(hidden, self.weight3)
            hidden = self.matmul4(hidden, self.weight4)
            hidden = self.matmul5(hidden, self.weight5)
            hidden = hidden + self.weight6
            return self.matmul6(hidden, x6)

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
                              group_ckpt_save_file="./group_stage1.ckpt")
    shard_strategies = (
        ((8, 1), (1, 1)),  # matmul1
        ((1, 8), (8, 1)),  # matmul2
        ((2, 2), (2, 2)),  # matmul3
        ((1, 1), (1, 8)),  # matmul4
        ((4, 2), (2, 1)),  # matmul5
        ((4, 1), (1, 2)),  # matmul6
    )
    net = GradWrap(NetWithLoss(Net(*shard_strategies)))
    set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6)
93
94
# Load stage: the save-stage network with matmul2 removed and matmul7 added.
def test_six_matmul_load():
    """Compile a modified network that loads the strategy file of ``test_six_matmul_save``.

    Relative to the save-stage network, ``matmul2``/``weight2`` are removed and
    ``matmul7`` (fed by a new ``[32, 32]`` input ``x7``) is appended. Attribute
    names shared with the save stage are unchanged.

    If ``./strategy_stage1.ckpt`` does not exist yet, it is regenerated by
    calling ``test_six_matmul_save`` first, so this test no longer depends on
    test execution order.
    """
    import os

    class NetWithLoss(nn.Cell):
        """Apply ``VirtualLoss`` to the output of the wrapped network."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6, x7):
            predict = self.network(x1, x6, x7)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Return ``grad_all`` of the wrapped network evaluated at the inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6, x7):
            return grad_all(self.network)(x1, x6, x7)

    class Net(nn.Cell):
        """Save-stage chain without matmul2, plus a trailing matmul7."""

        def __init__(self, strategy1, strategy3, strategy4, strategy5, strategy6, strategy7):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.matmul4 = P.MatMul().shard(strategy4)
            self.matmul5 = P.MatMul().shard(strategy5)
            self.matmul6 = P.MatMul().shard(strategy6)
            self.matmul7 = P.MatMul().shard(strategy7)
            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
            self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")

        def construct(self, x1, x6, x7):
            out = self.matmul1(x1, self.weight1)
            out = self.matmul3(out, self.weight3)
            out = self.matmul4(out, self.weight4)
            out = self.matmul5(out, self.weight5)
            out = out + self.weight6
            out = self.matmul6(out, x6)
            out = self.matmul7(out, x7)
            return out

    strategy_ckpt = "./strategy_stage1.ckpt"
    if not os.path.isfile(strategy_ckpt):
        # The checkpoint is only produced as a side effect of the save test;
        # regenerate it so this test also passes when run alone or reordered.
        test_six_matmul_save()

    reset_auto_parallel_context()
    # NOTE(review): group_ckpt_save_file reuses the save-stage path, so this
    # test overwrites ./group_stage1.ckpt — confirm that is intended.
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file=strategy_ckpt,
                              group_ckpt_save_file="./group_stage1.ckpt")
    # Explicit strategies; presumably the checkpointed strategies take
    # precedence for operators found in the file — confirm against the docs.
    strategy1 = ((8, 1), (1, 1))
    strategy3 = ((8, 1), (1, 1))
    strategy4 = ((8, 1), (1, 1))
    strategy5 = ((8, 1), (1, 1))
    strategy6 = ((8, 1), (1, 1))
    strategy7 = ((8, 1), (1, 1))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5, strategy6, strategy7)))
    set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6, x7)
157
158
# Model-parallel test (auto_parallel): save stage of the strategy checkpoint.
def test_six_matmul_save_auto():
    """Compile a six-MatMul training graph in auto_parallel mode.

    No MatMul carries a shard strategy; the parallel mode is ``auto_parallel``,
    so the strategies are presumably chosen by the auto-parallel search. The
    context saves the strategy checkpoint to ``./strategy_stage1_auto.ckpt``,
    the same path ``test_six_matmul_load_auto`` loads from.
    """

    class NetWithLoss(nn.Cell):
        """Apply ``VirtualLoss`` to the output of the wrapped network."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6):
            return self.loss(self.network(x1, x6))

    class GradWrap(nn.Cell):
        """Return ``grad_all`` of the wrapped network evaluated at the inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6):
            return grad_all(self.network)(x1, x6)

    class Net(nn.Cell):
        """Six chained, unsharded MatMuls with an element-wise add of ``weight6``."""

        def __init__(self):
            super().__init__()
            # Names match the load-stage network (matmulN / weightN).
            for index in range(1, 7):
                setattr(self, "matmul%d" % index, P.MatMul())
            weight_shapes = (
                ("weight1", [32, 64]),
                ("weight2", [64, 64]),
                ("weight3", [64, 128]),
                ("weight4", [128, 64]),
                ("weight5", [64, 128]),
                ("weight6", [32, 128]),
            )
            for weight_name, shape in weight_shapes:
                setattr(self, weight_name,
                        Parameter(Tensor(np.ones(shape), dtype=ms.float32), name=weight_name))

        def construct(self, x1, x6):
            hidden = self.matmul1(x1, self.weight1)
            hidden = self.matmul2(hidden, self.weight2)
            hidden = self.matmul3(hidden, self.weight3)
            hidden = self.matmul4(hidden, self.weight4)
            hidden = self.matmul5(hidden, self.weight5)
            hidden = hidden + self.weight6
            return self.matmul6(hidden, x6)

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1_auto.ckpt")
    net = GradWrap(NetWithLoss(Net()))
    set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6)
214
215
# Load stage (auto_parallel): the save-stage network with matmul2 removed and matmul7 added.
def test_six_matmul_load_auto():
    """Compile a modified auto_parallel network that loads ``./strategy_stage1_auto.ckpt``.

    Relative to ``test_six_matmul_save_auto``, ``matmul2``/``weight2`` are
    removed and ``matmul7`` (fed by a new ``[32, 32]`` input ``x7``) is appended.
    ``matmul1``/``3``/``4``/``5`` get explicit ``((2, 2), (2, 2))`` strategies;
    ``matmul6``/``7`` are left unsharded.

    If the checkpoint file does not exist yet, it is regenerated by calling
    ``test_six_matmul_save_auto`` first, so this test no longer depends on test
    execution order.
    """
    import os

    class NetWithLoss(nn.Cell):
        """Apply ``VirtualLoss`` to the output of the wrapped network."""

        def __init__(self, network):
            super().__init__()
            self.loss = VirtualLoss()
            self.network = network

        def construct(self, x1, x6, x7):
            predict = self.network(x1, x6, x7)
            return self.loss(predict)

    class GradWrap(nn.Cell):
        """Return ``grad_all`` of the wrapped network evaluated at the inputs."""

        def __init__(self, network):
            super().__init__()
            self.network = network

        def construct(self, x1, x6, x7):
            return grad_all(self.network)(x1, x6, x7)

    class Net(nn.Cell):
        """Save-stage chain without matmul2, plus a trailing unsharded matmul7."""

        def __init__(self, strategy1, strategy3, strategy4, strategy5):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul3 = P.MatMul().shard(strategy3)
            self.matmul4 = P.MatMul().shard(strategy4)
            self.matmul5 = P.MatMul().shard(strategy5)
            self.matmul6 = P.MatMul()
            self.matmul7 = P.MatMul()
            self.weight1 = Parameter(Tensor(np.ones([32, 64]), dtype=ms.float32), name="weight1")
            self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
            self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
            self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
            self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")

        def construct(self, x1, x6, x7):
            out = self.matmul1(x1, self.weight1)
            out = self.matmul3(out, self.weight3)
            out = self.matmul4(out, self.weight4)
            out = self.matmul5(out, self.weight5)
            out = out + self.weight6
            out = self.matmul6(out, x6)
            out = self.matmul7(out, x7)
            return out

    strategy_ckpt = "./strategy_stage1_auto.ckpt"
    if not os.path.isfile(strategy_ckpt):
        # The checkpoint is only produced as a side effect of the auto save
        # test; regenerate it so this test also passes when run alone or reordered.
        test_six_matmul_save_auto()

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file=strategy_ckpt)
    strategy1 = ((2, 2), (2, 2))
    strategy3 = ((2, 2), (2, 2))
    strategy4 = ((2, 2), (2, 2))
    strategy5 = ((2, 2), (2, 2))
    net = GradWrap(NetWithLoss(Net(strategy1, strategy3, strategy4, strategy5)))
    set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    x1 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    x6 = Tensor(np.ones([128, 32]), dtype=ms.float32)
    x7 = Tensor(np.ones([32, 32]), dtype=ms.float32)
    net.set_train()
    _cell_graph_executor.compile(net, x1, x6, x7)
275