# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.common.parameter import Parameter
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss

28grad_all = C.GradOperation(get_all=True)
29
30
31class NetWithLoss(nn.Cell):
32    def __init__(self, network):
33        super(NetWithLoss, self).__init__()
34        self.loss = VirtualLoss()
35        self.network = network
36
37    def construct(self, x):
38        predict = self.network(x)
39        return self.loss(predict)
40
41
42class GradWrap(nn.Cell):
43    def __init__(self, network):
44        super(GradWrap, self).__init__()
45        self.network = network
46
47    def construct(self, x):
48        return grad_all(self.network)(x)
49
50def test_reshape_unexpand():
51    class Net(nn.Cell):
52        def __init__(self):
53            super().__init__()
54            self.reshape = P.Reshape()
55            self.mul = P.Mul().shard(((1, 8), (1, 1, 8)))
56            self.mul_weight = Parameter(Tensor(np.ones([96, 128]), dtype=ms.float32), name="weight")
57
58        def construct(self, x):
59            weight = self.reshape(self.mul_weight, (1, 128, 96))
60            out = self.mul(x, weight)
61            return out
62
63    size = 8
64    context.set_auto_parallel_context(device_num=size, global_rank=0)
65    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
66
67    net = GradWrap(NetWithLoss(Net()))
68    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
69    net.set_auto_parallel()
70    net.set_train()
71    _cell_graph_executor.compile(net, x)
72
73def test_reshape_unexpand_1():
74    class Net(nn.Cell):
75        def __init__(self):
76            super().__init__()
77            self.reshape = P.Reshape()
78            self.mul = P.Mul().shard(((1, 1, 8), (1, 8)))
79            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
80
81        def construct(self, data):
82            x = self.reshape(self.mul_weight, (1, 128, 96))
83            out = self.mul(x, self.mul_weight)
84            return out
85
86    size = 8
87    context.set_auto_parallel_context(device_num=size, global_rank=0)
88    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
89
90    net = GradWrap(NetWithLoss(Net()))
91    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
92    net.set_auto_parallel()
93    net.set_train()
94    _cell_graph_executor.compile(net, x)
95
96def test_reshape_unexpand_2():
97    class Net(nn.Cell):
98        def __init__(self):
99            super().__init__()
100            self.reshape = P.Reshape()
101            self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
102            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
103
104        def construct(self, data):
105            x = self.reshape(self.mul_weight, (1, 128, 96))
106            out = self.mul(x, self.mul_weight)
107            return out
108
109    size = 8
110    context.set_auto_parallel_context(device_num=size, global_rank=0)
111    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
112
113    net = GradWrap(NetWithLoss(Net()))
114    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
115    net.set_auto_parallel()
116    net.set_train()
117    _cell_graph_executor.compile(net, x)
118
119def test_reshape_unexpand_3():
120    class Net(nn.Cell):
121        def __init__(self):
122            super().__init__()
123            self.reshape = P.Reshape()
124            self.relu1 = P.ReLU().shard(((4, 1),))
125            self.relu2 = P.ReLU().shard(((1, 4),))
126
127        def construct(self, data):
128            x = self.relu1(data)
129            x = self.reshape(x, (3, 4))
130            x = self.relu2(x)
131            return x
132
133    size = 4
134    context.set_auto_parallel_context(device_num=size, global_rank=0)
135    x = Tensor(np.ones([4, 3]), dtype=ms.float32)
136
137    net = GradWrap(NetWithLoss(Net()))
138    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
139    net.set_auto_parallel()
140    net.set_train()
141    _cell_graph_executor.compile(net, x)
142
143def test_reshape_unexpand_4():
144    class Net(nn.Cell):
145        def __init__(self):
146            super().__init__()
147            self.reshape = P.Reshape()
148            self.relu1 = P.ReLU().shard(((4, 1),))
149            self.relu2 = P.ReLU().shard(((1, 2, 2),))
150
151        def construct(self, data):
152            x = self.relu1(data)
153            x = self.reshape(x, (3, 2, 2))
154            x = self.relu2(x)
155            return x
156
157    size = 4
158    context.set_auto_parallel_context(device_num=size, global_rank=0)
159    x = Tensor(np.ones([4, 3]), dtype=ms.float32)
160
161    net = GradWrap(NetWithLoss(Net()))
162    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
163    net.set_auto_parallel()
164    net.set_train()
165    _cell_graph_executor.compile(net, x)
166
167def test_reshape_unexpand_5():
168    class Net(nn.Cell):
169        def __init__(self):
170            super().__init__()
171            self.reshape = P.Reshape()
172            self.relu1 = P.ReLU().shard(((2, 2, 1),))
173            self.relu2 = P.ReLU().shard(((1, 4),))
174
175        def construct(self, data):
176            x = self.relu1(data)
177            x = self.reshape(x, (3, 4))
178            x = self.relu2(x)
179            return x
180
181    size = 4
182    context.set_auto_parallel_context(device_num=size, global_rank=0)
183    x = Tensor(np.ones([2, 2, 3]), dtype=ms.float32)
184
185    net = GradWrap(NetWithLoss(Net()))
186    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
187    net.set_auto_parallel()
188    net.set_train()
189    _cell_graph_executor.compile(net, x)
190
191def test_reshape_unexpand_6():
192    class Net(nn.Cell):
193        def __init__(self):
194            super().__init__()
195            self.reshape = P.Reshape()
196            self.relu1 = P.ReLU().shard(((2, 1),))
197            self.relu2 = P.ReLU().shard(((1, 1, 4),))
198
199        def construct(self, data):
200            x = self.relu1(data)
201            x = self.reshape(x, (1, 3, 4))
202            x = self.relu2(x)
203            return x
204
205    size = 4
206    context.set_auto_parallel_context(device_num=size, global_rank=0)
207    x = Tensor(np.ones([4, 3]), dtype=ms.float32)
208
209    net = GradWrap(NetWithLoss(Net()))
210    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
211    net.set_auto_parallel()
212    net.set_train()
213    _cell_graph_executor.compile(net, x)
214
215def test_reshape_unexpand_7():
216    class Net(nn.Cell):
217        def __init__(self, in_channel=3, out_channel=8, axis=1, input_shape=(32, 4, 110, -1),
218                     mul_size=(32, 1, 220, 220)):
219            super().__init__()
220            mul_np = np.full(mul_size, 0.5, dtype=np.float32)
221            self.mul_weight = Parameter(Tensor(mul_np), name="mul_weight")
222            self.mul = P.Mul()
223            self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
224                                  kernel_size=5, has_bias=True, weight_init='ones',
225                                  bias_init='ones', pad_mode='valid')
226            self.conv.conv2d.shard(((8, 1, 1, 1), (1, 1, 1, 1)))
227            self.softmax = nn.Softmax(axis=axis)
228            self.relu = nn.ReLU()
229            self.reshape = P.Reshape()
230            self.input_shape = input_shape
231
232        def construct(self, inputs):
233            x = self.conv(inputs)
234            x = self.softmax(x)
235            x = self.relu(x)
236            x = self.mul(x, self.mul_weight)
237            x = self.reshape(x, self.input_shape)
238            return x
239
240    size = 8
241    context.set_auto_parallel_context(device_num=size, global_rank=0)
242    context.set_auto_parallel_context(parallel_mode="auto_parallel")
243    x = Tensor(np.ones([32, 3, 224, 224]), dtype=ms.float32)
244    net = GradWrap(NetWithLoss(Net()))
245    net.set_auto_parallel()
246    net.set_train()
247    _cell_graph_executor.compile(net, x)
248
249def test_reshape_unexpand_8():
250    class Net(nn.Cell):
251        def __init__(self):
252            super().__init__()
253            self.reshape = P.Reshape()
254            self.mul = P.Mul().shard(((1, 4, 2), (4, 2)))
255            self.mul_weight = Parameter(Tensor(np.ones([128, 96]), dtype=ms.float32), name="weight")
256
257        def construct(self, data):
258            x = self.reshape(self.mul_weight, (1, 128, 96))
259            out = self.mul(x, self.mul_weight)
260            return out
261
262    size = 8
263    context.set_auto_parallel_context(device_num=size, global_rank=0)
264    x = Tensor(np.ones([128, 96]), dtype=ms.float32)
265
266    net = GradWrap(NetWithLoss(Net()))
267    context.set_auto_parallel_context(parallel_mode="auto_parallel")
268    net.set_auto_parallel()
269    net.set_train()
270    _cell_graph_executor.compile(net, x)
271