• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import numpy as np
16import mindspore as ms
17import mindspore.context as context
18from mindspore import Tensor, Parameter
19import mindspore.nn as nn
20from mindspore.common.api import _cell_graph_executor
21from mindspore.nn import TrainOneStepCell, Momentum
22from mindspore.ops import operations as P
23
class Net(nn.Cell):
    """ReLU -> Transpose -> BatchMatMul x2 -> inverse Transpose -> ReLU.

    Every operator except the two reshapes accepts an optional shard strategy
    (``stra1`` .. ``stra6``), so the same graph can be compiled under
    different semi-auto-parallel layouts.

    Args:
        wi: Initial value for the parameter named ``"wi"`` (right operand of
            the first BatchMatMul).
        wo: Initial value for the parameter named ``"wo"`` (right operand of
            the second BatchMatMul).
        stra1: Strategy for the leading ReLU.
        stra2: Strategy for the leading Transpose.
        stra3: Strategy for the BatchMatMul with ``wi``.
        stra4: Strategy for the BatchMatMul with ``wo``.
        stra5: Strategy for the trailing Transpose.
        stra6: Strategy for the trailing ReLU.
    """

    def __init__(self, wi, wo, stra1=None, stra2=None, stra3=None, stra4=None,
                 stra5=None, stra6=None):
        super().__init__()
        # Input side: activation followed by a layout change.
        self.relu = P.ReLU().shard(stra1)
        self.transpose = P.Transpose().shard(stra2)
        # Two batched projections, each paired with its weight.
        self.wi = Parameter(wi, "wi")
        self.batch_mm = P.BatchMatMul().shard(stra3)
        self.wo = Parameter(wo, "wo")
        self.batch_mm2 = P.BatchMatMul().shard(stra4)
        # Output side: undo the layout change, then activation.
        self.transpose2 = P.Transpose().shard(stra5)
        self.relu2 = P.ReLU().shard(stra6)
        # The reshapes carry no explicit strategy.
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()

    def construct(self, x):
        activated = self.relu(x)
        # Axes (d0, d1, d2, d3) -> (d2, d0, d3, d1).
        permuted = self.transpose(activated, (2, 0, 3, 1))
        dims = permuted.shape
        # Fold the two middle axes together so BatchMatMul gets a rank-3 input.
        merged = self.reshape(permuted, (dims[0], dims[1] * dims[2], dims[3]))
        projected = self.batch_mm2(self.batch_mm(merged, self.wi), self.wo)
        # Go back to the 4-D permuted shape; (1, 3, 0, 2) inverts (2, 0, 3, 1).
        restored = self.reshape2(projected, dims)
        unpermuted = self.transpose2(restored, (1, 3, 0, 2))
        return self.relu2(unpermuted)
50
# Shared input and weight values for the tests below (all ones, float32).
# Net.construct transposes _x with (2, 0, 3, 1) to [48, 32, 128, 16] and then
# reshapes it to [48, 4096, 16]. The leading 48 is therefore the batch axis of
# both BatchMatMuls: _wi maps the last axis 16 -> 64 and _wo maps 64 -> 16.
_x = Tensor(np.ones((32, 16, 48, 128)), dtype=ms.float32)
_wi = Tensor(np.ones((48, 16, 64)), dtype=ms.float32)
_wo = Tensor(np.ones((48, 64, 16)), dtype=ms.float32)
54
55
def compile_net(net):
    """Wrap ``net`` in a Momentum training step and compile it on ``_x``.

    Graph mode is selected before compiling. The auto-parallel context that
    the caller configured is always reset afterwards, even when compilation
    raises, so a failing case does not leak its parallel settings (device
    count, strategies, alltoall flag) into the next test.

    Args:
        net: The cell to compile. Its trainable parameters are optimized by
            ``Momentum(learning_rate=0.1, momentum=0.9)``.
    """
    try:
        context.set_context(mode=context.GRAPH_MODE)
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x)
    finally:
        # Previously skipped on error, leaving global parallel state dirty.
        context.reset_auto_parallel_context()
64
65
def test_batchmm():
    """Compile Net on 8 devices with every operator split 8-way on axis 0.

    Only the leading axis is partitioned. The contracted axis of either
    BatchMatMul is never split.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8,
                                      enable_alltoall=True, global_rank=0)
    split_4d = ((8, 1, 1, 1),)
    split_3d_pair = ((8, 1, 1), (8, 1, 1))
    net = Net(_wi, _wo,
              stra1=split_4d, stra2=split_4d,
              stra3=split_3d_pair, stra4=split_3d_pair,
              stra5=split_4d, stra6=split_4d)
    compile_net(net)
77
78
def test_batchmm2():
    """Compile Net on 32 devices, mixing a 4-way batch split with 8-way model splits.

    The ReLU/Transpose operators use only the 4-way split of axis 0. The first
    BatchMatMul also splits the last axis of ``wi`` 8 ways. The second
    BatchMatMul splits its contracted axis 8 ways, which matches that output
    layout.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", enable_alltoall=True,
                                      device_num=32, global_rank=0)
    batch_only = ((4, 1, 1, 1),)
    column_split = ((4, 1, 1), (4, 1, 8))
    contracted_split = ((4, 1, 8), (4, 8, 1))
    net = Net(_wi, _wo,
              stra1=batch_only, stra2=batch_only,
              stra3=column_split, stra4=contracted_split,
              stra5=batch_only, stra6=batch_only)
    compile_net(net)
90