# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss

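# grad_all(net) yields a function that returns the gradients of net with
# respect to all of its inputs (get_all=True).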
grad_all = C.GradOperation(get_all=True)


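# Attaches a VirtualLoss head so the compiled graph ends in a single loss
# output, which the gradient wrapper below differentiates.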
class NetWithLoss(nn.Cell):
    def __init__(self, network):
        super(NetWithLoss, self).__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x, y, b):
        predict = self.network(x, y, b)
        return self.loss(predict)


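# Wrapping a network in GradWrap makes compilation build its backward graph.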
class GradWrap(nn.Cell):
    def __init__(self, network):
        super(GradWrap, self).__init__()
        self.network = network

    def construct(self, x, y, b):
        return grad_all(self.network)(x, y, b)


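# Marks the cell for auto-parallel, enables training mode, and compiles the
# graph; the tests below only verify that compilation succeeds.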
def compile_net(net, x, y, b):
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, y, b)


def test_matmul_tanh():
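    """Compile MatMul -> Tanh -> MatMul with explicit shard strategies on 16 devices."""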
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.tanh = P.Tanh().shard(strategy3)

        def construct(self, x, y, b):
            out = self.tanh(self.matmul1(x, y))
            out = self.matmul2(out, b)
            return out

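    # Each shard strategy gives, per input of the primitive, the number of
    # slices each dimension is cut into: ((16, 1), (1, 1)) splits matmul1's
    # left input 16-way along its rows and leaves the right input unsplit.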
    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((1, 1), (1, 16))
    strategy3 = ((4, 4),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=16, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_activation():
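    """Same pipeline as test_matmul_tanh, with ReLU as the activation."""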
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.activation = P.ReLU().shard(strategy3)

        def construct(self, x, y, b):
            out = self.activation(self.matmul1(x, y))
            out = self.matmul2(out, b)
            return out

    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((1, 1), (1, 16))
    strategy3 = ((4, 4),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=16, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_softmax():
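    """Softmax between the MatMuls; its strategy ((16, 1),) splits only the
    first axis, leaving the (default, last) softmax axis whole.
    """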
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.softmax = P.Softmax().shard(strategy3)

        def construct(self, x, y, b):
            out = self.softmax(self.matmul1(x, y))
            out = self.matmul2(out, b)
            return out

    strategy1 = ((16, 1), (1, 1))
    strategy2 = ((1, 1), (1, 16))
    strategy3 = ((16, 1),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=16, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_matmul_logsoftmax():
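    """LogSoftmax variant; here both MatMul strategies also split the
    contracted dimension of the product.
    """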
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.logsoftmax = P.LogSoftmax().shard(strategy3)

        def construct(self, x, y, b):
            out = self.logsoftmax(self.matmul1(x, y))
            out = self.matmul2(out, b)
            return out

    strategy1 = ((4, 2), (2, 2))
    strategy2 = ((2, 4), (4, 2))
    strategy3 = ((16, 1),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=16, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_activations():
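    """Chain GeLU, Tanh, Softmax and LogSoftmax, all sharing one activation strategy on 4 devices."""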
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.gelu = P.GeLU().shard(strategy3)
            self.tanh = P.Tanh().shard(strategy3)
            self.softmax = P.Softmax().shard(strategy3)
            self.logsoftmax = P.LogSoftmax().shard(strategy3)

        def construct(self, x, y, b):
            out = self.gelu(self.tanh(self.matmul1(x, y)))
            out = self.logsoftmax(self.softmax(self.matmul2(out, b)))
            return out

    strategy1 = ((1, 2), (2, 2))
    strategy2 = ((2, 2), (2, 1))
    strategy3 = ((4, 1),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=4, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_activations_repeated_calculation():
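    """Each activation strategy covers fewer than the 64 available devices,
    exercising repeated (replicated) calculation across device groups.
    """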
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.gelu = P.GeLU().shard(strategy3)
            self.tanh = P.Tanh().shard(strategy4)
            self.softmax = P.Softmax().shard(strategy5)
            self.logsoftmax = P.LogSoftmax().shard(strategy6)

        def construct(self, x, y, b):
            out = self.gelu(self.tanh(self.matmul1(x, y)))
            out = self.logsoftmax(self.softmax(self.matmul2(out, b)))
            return out

    strategy1 = ((2, 4), (4, 8))
    strategy2 = ((2, 2), (2, 1))
    strategy3 = ((2, 1),)
    strategy4 = ((2, 2),)
    strategy5 = ((4, 1),)
    strategy6 = ((8, 1),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=64, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)


def test_activations_axis_tuple():
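    """Softmax normalizes over axis=(0, 1) here, so its strategy keeps both
    axes whole (strategy5 = ((1, 1),)).
    """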
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, strategy3, strategy4, strategy5, strategy6):
            super().__init__()
            self.matmul1 = P.MatMul().shard(strategy1)
            self.matmul2 = P.MatMul().shard(strategy2)
            self.gelu = P.GeLU().shard(strategy3)
            self.tanh = P.Tanh().shard(strategy4)
            self.softmax = P.Softmax(axis=(0, 1)).shard(strategy5)
            self.logsoftmax = P.LogSoftmax().shard(strategy6)

        def construct(self, x, y, b):
            out = self.gelu(self.tanh(self.matmul1(x, y)))
            out = self.logsoftmax(self.softmax(self.matmul2(out, b)))
            return out

    strategy1 = ((2, 4), (4, 8))
    strategy2 = ((2, 2), (2, 1))
    strategy3 = ((2, 1),)
    strategy4 = ((2, 2),)
    strategy5 = ((1, 1),)
    strategy6 = ((8, 1),)
    net = GradWrap(NetWithLoss(Net(strategy1, strategy2, strategy3, strategy4, strategy5, strategy6)))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    context.set_auto_parallel_context(device_num=64, global_rank=0)

    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
    y = Tensor(np.ones([32, 64]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)
    compile_net(net, x, y, b)