1# Copyright 2019 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16
17import mindspore as ms
18import mindspore.nn as nn
19from mindspore import Tensor
20from mindspore import context
21from mindspore.common.api import _cell_graph_executor
22from mindspore.common.parameter import Parameter
23from mindspore.ops import composite as C
24from mindspore.ops import operations as P
25from tests.ut.python.ops.test_math_ops import VirtualLoss
26
27
# Gradient operator that returns gradients w.r.t. all network inputs.
grad_all = C.GradOperation(get_all=True)
29
30
class NetWithLoss(nn.Cell):
    """Wraps a single-input network and applies a virtual loss to its output."""

    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.network = network
        self.loss = VirtualLoss()

    def construct(self, x):
        # Forward through the wrapped network, then through the virtual loss.
        return self.loss(self.network(x))
40
41
class GradWrap(nn.Cell):
    """Computes gradients of the wrapped single-input network w.r.t. its input."""

    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x):
        grad_fn = grad_all(self.network)
        return grad_fn(x)
49
50
def test_reshape_matmul():
    """Compile a Reshape -> MatMul net under auto-parallel on 8 devices."""
    class ReshapeMatMulNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            flat = self.reshape(x, (64, 28))
            return self.matmul(flat, self.matmul_weight)

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([8 * dev_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(ReshapeMatMulNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
73
def test_reshape_reshape():
    """Compile a net with two consecutive reshapes under auto-parallel."""
    class DoubleReshapeNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            self.relu = P.ReLU()

        def construct(self, x):
            activated = self.relu(x)
            flat = self.reshape(activated, (64, 28))
            return self.reshape(flat, (64, 28, 1))

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([8 * dev_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(DoubleReshapeNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
96
97
def test_reshape_auto_1():
    """Compile ReLU -> Reshape -> MatMul under auto-parallel strategy search."""
    class ReluReshapeMatMulNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            flat = self.reshape(activated, (64, 28))
            return self.matmul(flat, self.matmul_weight)

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([8 * dev_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(ReluReshapeMatMulNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
122
123
def test_reshape_auto_2():
    """Compile a net with a reshape both before and after a MatMul."""
    class ReshapeAroundMatMulNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.add_weight = Parameter(Tensor(np.ones([128, 32]), dtype=ms.float32), name="weight1")
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            flat = self.reshape(activated, (64, 28))
            product = self.matmul(flat, self.matmul_weight)
            reshaped = self.reshape(product, (128, 32))
            return reshaped + self.add_weight

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([8 * dev_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(ReshapeAroundMatMulNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
151
152
def test_reshape_auto_3():
    """Compile a net whose reshape follows (rather than precedes) the MatMul."""
    class MatMulThenReshapeNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            self.matmul_weight = Parameter(Tensor(np.ones([28, 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            product = self.matmul(activated, self.matmul_weight)
            return self.reshape(product, (8, 8, 8, 8))

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([8 * dev_num, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(MatMulThenReshapeNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
177
178
def test_reshape_auto_4():
    """Compile a net that reshapes a flat 1-D weight before the MatMul."""
    class FlatWeightNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.reshape = P.Reshape()
            self.matmul = P.MatMul()
            # Weight is stored flattened and reshaped to 2-D inside construct.
            self.matmul_weight = Parameter(Tensor(np.ones([28 * 64]), dtype=ms.float32), name="weight")

        def construct(self, x):
            activated = self.relu(x)
            flat = self.reshape(activated, (64, 28))
            weight_2d = self.reshape(self.matmul_weight, (28, 64))
            return self.matmul(flat, weight_2d)

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([8 * dev_num, 28, 1, 1]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(FlatWeightNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
204
205
def test_reshape_auto_5():
    """Compile a wide&deep-style two-input net mixing Mul, ReduceSum and reshapes."""
    class TwoInputLoss(nn.Cell):
        def __init__(self, network):
            super(TwoInputLoss, self).__init__()
            self.network = network
            self.loss = VirtualLoss()

        def construct(self, x, y):
            return self.loss(self.network(x, y))

    class TwoInputGrad(nn.Cell):
        def __init__(self, network):
            super(TwoInputGrad, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    class WideDeepNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_sum = P.ReduceSum()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024 * 8, 64]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            mask = self.reshape(y, (4, 1024 * 8, 1))
            # Wide branch: masked activations summed over axis 1.
            wide_in = self.mul(self.relu(x), mask)
            wide_out = self.reshape(self.reduce_sum(wide_in, 1), (-1, 1))
            # Deep branch: masked (input + weight), flattened.
            deep_masked = self.mul(x + self.wide_w, mask)
            deep_in = self.reshape(deep_masked, (-1, 1024 * 8 * 64))
            return wide_out + deep_in

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([4, 1024 * dev_num, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024 * dev_num]), dtype=ms.float32)

    net = TwoInputGrad(TwoInputLoss(WideDeepNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)
255
def test_reshape_auto_6():
    """Compile a two-input net where a reshaped weight feeds two branches."""
    class MeanLossWrap(nn.Cell):
        def __init__(self, network):
            super(MeanLossWrap, self).__init__()
            self.network = network
            self.loss = VirtualLoss()

        def construct(self, x, y):
            return self.loss(self.network(x, y))

    class GradWrapTwoInput(nn.Cell):
        def __init__(self, network):
            super(GradWrapTwoInput, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    class SharedWeightNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.relu = P.ReLU()
            self.mul = P.Mul()
            self.reshape = P.Reshape()
            self.reduce_mean = P.ReduceMean()
            self.wide_w = Parameter(Tensor(np.ones([4, 1024, 1]), dtype=ms.float32), name="weight")

        def construct(self, x, y):
            branch1 = x + self.wide_w
            # Same weight used both raw (above) and reshaped to 2-D (below).
            w_2d = self.reshape(self.wide_w, (4, 1024))
            branch1 = self.reduce_mean(branch1, 1)
            branch1 = branch1 - w_2d
            branch2 = self.mul(y, w_2d)
            return branch1 + branch2

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([4, 1024, 1]), dtype=ms.float32)
    y = Tensor(np.ones([4, 1024]), dtype=ms.float32)

    net = GradWrapTwoInput(MeanLossWrap(SharedWeightNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)
303
def test_reshape_auto_7():
    """Compile a net multiplying a reshaped weight by itself under semi-auto-parallel."""
    class SelfMulNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape = P.Reshape()
            # Explicit shard strategy for the broadcasting Mul.
            self.mul = P.Mul().shard(((1, 2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")

        def construct(self, x):
            expanded = self.reshape(self.mul_weight, (1, 128, 96))
            return self.mul(expanded, self.mul_weight)

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([128, 28]), dtype=ms.float32)

    net = GradWrap(NetWithLoss(SelfMulNet()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x)
326
def test_reshape_depend_reshape():
    """Compile a net with a Depend edge between two reshapes, in both semi-auto
    and auto parallel modes."""
    class DependNet(nn.Cell):
        def __init__(self):
            super().__init__()
            self.reshape1 = P.Reshape()
            self.reshape2 = P.Reshape()
            self.relu = P.ReLU()
            self.depend = P.Depend()
            self.mul = P.Mul().shard(((2, 4), (2, 4)))
            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
            self.add = P.Add().shard(((4, 2), (4, 2)))

        def construct(self, x, y):
            product = self.mul(x, self.mul_weight)
            activated = self.relu(y)
            reshaped = self.reshape1(activated, (96, 32, 4))
            # Depend forces the mul to execute before the second reshape.
            ordered = self.depend(reshaped, product)
            flat = self.reshape2(ordered, (128, 96))
            return product + flat

    class MeanLossWrap(nn.Cell):
        def __init__(self, network):
            super(MeanLossWrap, self).__init__()
            self.network = network
            self.mean = P.ReduceMean(keep_dims=False)

        def construct(self, x, y):
            return self.mean(self.network(x, y), ())

    class GradWrapTwoInput(nn.Cell):
        def __init__(self, network):
            super(GradWrapTwoInput, self).__init__()
            self.network = network

        def construct(self, x, y):
            return grad_all(self.network)(x, y)

    dev_num = 8
    context.set_auto_parallel_context(device_num=dev_num, global_rank=0)
    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
    y = Tensor(np.ones([256, 48]), dtype=ms.float32)
    # First compile with explicit strategies under semi-auto parallel ...
    net = GradWrapTwoInput(MeanLossWrap(DependNet()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y)
    # ... then again with strategy search under auto parallel.
    net_auto = GradWrapTwoInput(MeanLossWrap(DependNet()))
    context.set_auto_parallel_context(parallel_mode="auto_parallel")
    net_auto.set_auto_parallel()
    net_auto.set_train()
    _cell_graph_executor.compile(net_auto, x, y)
380