# NOTE(review): the following navigation text ("Home / Line# / Scopes# /
# Navigate# / Raw / Download") is code-viewer chrome left over from scraping
# this file; it is not part of the original source.
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import TrainOneStepCell
from mindspore.nn.optim import Momentum, LARS
from mindspore.ops import operations as P

26
class NetWithLoss(nn.Cell):
    """Wrap a backbone network with a sharded softmax cross-entropy loss.

    Args:
        network (nn.Cell): backbone producing logits from the input `x`.
        strategy3 (tuple): shard strategy applied to the loss operator.
    """

    def __init__(self, network, strategy3):
        super(NetWithLoss, self).__init__()
        self.loss = P.SoftmaxCrossEntropyWithLogits().shard(strategy3)
        self.network = network

    def construct(self, x, b):
        predict = self.network(x)
        # The op returns a tuple; element [0] is the loss value.
        return self.loss(predict, b)[0]
36
37
def compile_net(net, x, b):
    """Compile `net` under auto-parallel in training mode.

    Only graph compilation is exercised (via the cell graph executor);
    nothing is executed, so the test passes iff compilation succeeds.
    """
    net.set_auto_parallel()
    net.set_train()
    _cell_graph_executor.compile(net, x, b)
42
43
def test_momentum():
    """
    Feature: semi-auto parallel compilation.
    Description: MatMul+ReLU net trained with a plain Momentum optimizer
        (fixed learning rate, no loss scale) on 4 devices.
    Expectation: graph compiles without error.
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
            self.relu = P.ReLU().shard(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    # Shard strategies: matmul splits the batch dim 2-way; relu and the
    # loss split it 4-way.
    strategy1 = ((2, 1), (2, 1))
    strategy2 = ((4, 1),)
    strategy3 = ((4, 1), (4, 1))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    net_with_loss = NetWithLoss(net, strategy3)
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    compile_net(train_net, x, b)
76
77
def test_momentum_with_loss_scale():
    """
    Feature: semi-auto parallel compilation.
    Description: same MatMul+ReLU net as test_momentum, but the Momentum
        optimizer is constructed with loss_scale=0.5.
    Expectation: graph compiles without error.
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
            self.relu = P.ReLU().shard(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 1), (2, 1))
    strategy2 = ((4, 1),)
    strategy3 = ((4, 1), (4, 1))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)
    # Only difference from test_momentum: an explicit loss scale.
    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9, loss_scale=0.5)
    net_with_loss = NetWithLoss(net, strategy3)
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    compile_net(train_net, x, b)
110
111
def test_momentum_with_dynamic_lr():
    """
    Feature: semi-auto parallel compilation.
    Description: same MatMul+ReLU net as test_momentum, but the learning
        rate is a per-step Tensor schedule instead of a scalar.
    Expectation: graph compiles without error.
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
            self.relu = P.ReLU().shard(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 1), (2, 1))
    strategy2 = ((4, 1),)
    strategy3 = ((4, 1), (4, 1))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)
    # A 1-D Tensor learning rate is treated as a dynamic (per-step) schedule.
    lr = Tensor(np.ones([6]), dtype=ms.float32)
    optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)
    net_with_loss = NetWithLoss(net, strategy3)
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    compile_net(train_net, x, b)
145
146
def test_momentum_with_loss_scale_and_dynamic_lr():
    """
    Feature: semi-auto parallel compilation.
    Description: same MatMul+ReLU net as test_momentum, combining both
        variants: a per-step Tensor learning rate AND loss_scale=0.5.
    Expectation: graph compiles without error.
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
            self.relu = P.ReLU().shard(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    strategy1 = ((2, 1), (2, 1))
    strategy2 = ((4, 1),)
    strategy3 = ((4, 1), (4, 1))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)
    # Dynamic learning-rate schedule combined with an explicit loss scale.
    lr = Tensor(np.ones([6]), dtype=ms.float32)
    optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9, loss_scale=0.5)
    net_with_loss = NetWithLoss(net, strategy3)
    train_net = TrainOneStepCell(net_with_loss, optimizer)
    compile_net(train_net, x, b)
181
182
def test_lars():
    """
    Feature: semi-auto parallel compilation.
    Description: same MatMul+ReLU net as test_momentum, with a Momentum
        optimizer wrapped by LARS (layer-wise adaptive rate scaling);
        parameters whose name contains 'bn' are excluded by lars_filter.
    Expectation: graph compiles without error.
    """
    class Net(nn.Cell):
        def __init__(self, strategy1, strategy2, weight):
            super().__init__()
            self.weight = Parameter(weight, "w1")
            self.matmul = P.MatMul(transpose_a=False, transpose_b=True).shard(strategy1)
            self.relu = P.ReLU().shard(strategy2)

        def construct(self, x):
            out = self.matmul(x, self.weight)
            out = self.relu(out)
            return out

    context.set_auto_parallel_context(device_num=4, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    strategy1 = ((2, 1), (2, 1))
    strategy2 = ((4, 1),)
    strategy3 = ((4, 1), (4, 1))

    x = Tensor(np.ones([64, 32]), dtype=ms.float32)
    weight = Tensor(np.ones([64, 32]), dtype=ms.float32)
    b = Tensor(np.ones([64, 64]), dtype=ms.float32)

    net = Net(strategy1, strategy2, weight)

    lr = Tensor(np.ones([6]), dtype=ms.float32)
    sgd = Momentum(net.trainable_params(), lr, 0.9)
    optimizer = LARS(sgd, epsilon=1e-08, coefficient=0.02,
                     lars_filter=lambda x: 'bn' not in x.name)
    net_with_loss = NetWithLoss(net, strategy3)
    train_net = TrainOneStepCell(net_with_loss, optimizer)

    compile_net(train_net, x, b)
216