• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16
17import mindspore as ms
18from mindspore import context, Tensor, Parameter
19from mindspore.common.api import _cell_graph_executor
20from mindspore.nn import Cell, TrainOneStepCell, Momentum
21from mindspore.ops import operations as P
22
class Net(Cell):
    """Concatenate two weights along axis 0 and multiply the result with the input.

    Args:
        weight: First operand of the concat. Wrapped in a Parameter named "w1"
            when ``is_parameter`` is true, otherwise stored as passed in.
        weight2: Second operand of the concat; always wrapped in a Parameter
            named "w2".
        strategy1: Value handed to ``shard`` on the Concat primitive.
        strategy2: Value handed to ``shard`` on the Mul primitive.
        is_parameter: Selects whether ``weight`` becomes a Parameter.
    """

    def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        # Operators first; their relative creation order matches the original.
        self.concat = P.Concat(axis=0).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        # Weights are assigned in the order w1, w2.
        self.weight = Parameter(weight, name="w1") if is_parameter else weight
        self.weight2 = Parameter(weight2, name="w2")

    def construct(self, x, b):
        """Return ``x * concat(weight, weight2)``; ``b`` is accepted but not used."""
        stacked = self.concat((self.weight, self.weight2))
        return self.mul(x, stacked)
38
39
class Net2(Cell):
    """Multiply the two inputs, then concatenate the product with a weight.

    Args:
        weight: Tensor wrapped in a Parameter named "w"; second operand of the concat.
        strategy1: Value handed to ``shard`` on the Mul primitive.
        strategy2: Value handed to ``shard`` on the Concat primitive.
        axis: Axis passed to the Concat primitive.
    """

    def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
        super().__init__()
        mul_op = P.Mul()
        self.mul = mul_op.shard(strategy1)
        concat_op = P.Concat(axis=axis)
        self.concat = concat_op.shard(strategy2)
        self.weight = Parameter(weight, name="w")

    def construct(self, x, b):
        """Return ``concat((x * b, weight))`` along the configured axis."""
        product = self.mul(x, b)
        return self.concat((product, self.weight))
51
52
class Net3(Cell):
    """Concatenate three weights along axis 0 and multiply the result with the input.

    Args:
        weight: First operand of the concat. Wrapped in a Parameter named "w1"
            when ``is_parameter`` is true, otherwise stored as passed in.
        weight2: Second operand; always wrapped in a Parameter named "w2".
        weight3: Third operand; always wrapped in a Parameter named "w3".
        strategy1: Value handed to ``shard`` on the Concat primitive.
        strategy2: Value handed to ``shard`` on the Mul primitive.
        is_parameter: Selects whether ``weight`` becomes a Parameter.
    """

    def __init__(self, weight, weight2, weight3, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        # Operators first; their relative creation order matches the original.
        self.concat = P.Concat(axis=0).shard(strategy1)
        self.mul = P.Mul().shard(strategy2)
        # Weights are assigned in the order w1, w2, w3.
        self.weight = Parameter(weight, name="w1") if is_parameter else weight
        self.weight2 = Parameter(weight2, name="w2")
        self.weight3 = Parameter(weight3, name="w3")

    def construct(self, x, b):
        """Return ``x * concat(weight, weight2, weight3)``; ``b`` is accepted but not used."""
        stacked = self.concat((self.weight, self.weight2, self.weight3))
        return self.mul(x, stacked)
69
70
# Network inputs handed to the executor in compile_net, plus the weights used
# by the two-weight and single-weight networks. All tensors are float32 ones.
_x = Tensor(np.ones((128, 64, 32)), dtype=ms.float32)
_w1 = Tensor(np.ones((96, 64, 32)), dtype=ms.float32)
_w2 = Tensor(np.ones((32, 64, 32)), dtype=ms.float32)
_w3 = Tensor(np.ones((128, 16, 32)), dtype=ms.float32)
_b = Tensor(np.ones((128, 64, 32)), dtype=ms.float32)

# Weights for the three-way concat case; first dimensions are 48, 16 and 64.
w1 = Tensor(np.ones((48, 64, 32)), dtype=ms.float32)
w2 = Tensor(np.ones((16, 64, 32)), dtype=ms.float32)
w3 = Tensor(np.ones((64, 64, 32)), dtype=ms.float32)
80
81
def compile_net(net):
    """Wrap ``net`` in a Momentum training step and compile it with ``_x`` and ``_b``.

    The caller is expected to have configured the auto-parallel context
    beforehand. That context is reset before this function returns, whether
    compilation succeeds or raises, so a failing case cannot leak its parallel
    settings into whatever runs next.

    Args:
        net: The Cell under test; its trainable parameters feed the optimizer.

    Raises:
        Whatever optimizer construction or graph compilation raises; the
        exception propagates unchanged after the context has been reset.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)
    finally:
        # Always restore the global parallel context, even on failure.
        context.reset_auto_parallel_context()
89
90
def test_concat_parameter():
    """
    Feature: Concat under semi_auto_parallel.
    Description: Net concatenates two Parameters; Concat and Mul both use the
        (1, 4, 2) strategy for each input, with device_num=8 and global_rank=0.
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    concat_strategy = ((1, 4, 2), (1, 4, 2))
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net(_w1, _w2, concat_strategy, mul_strategy, is_parameter=True))
97
98
def test_concat_parameter_no_full_split():
    """
    Feature: Concat under semi_auto_parallel.
    Description: Net concatenates two Parameters; Concat uses the (1, 2, 2)
        strategy per input (4 slices on 8 devices) while Mul uses (1, 4, 2).
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    concat_strategy = ((1, 2, 2), (1, 2, 2))
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net(_w1, _w2, concat_strategy, mul_strategy, is_parameter=True))
105
106
def test_concat_tensor_and_parameter():
    """
    Feature: Concat under semi_auto_parallel.
    Description: Net is built with is_parameter=False, so the first concat
        operand stays a plain Tensor and the second is a Parameter; Concat uses
        the (1, 2, 2) strategy per input and Mul uses (1, 4, 2).
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    concat_strategy = ((1, 2, 2), (1, 2, 2))
    mul_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net(_w1, _w2, concat_strategy, mul_strategy, is_parameter=False))
113
114
def test_concat_output():
    """
    Feature: Concat under semi_auto_parallel.
    Description: Net2 concatenates the output of Mul with a Parameter; Mul uses
        the (2, 2, 2) strategy per input and Concat uses (1, 4, 2).
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    mul_strategy = ((2, 2, 2), (2, 2, 2))
    concat_strategy = ((1, 4, 2), (1, 4, 2))
    compile_net(Net2(_w1, mul_strategy, concat_strategy))
121
122
def test_concat_output_no_full_split():
    """
    Feature: Concat under semi_auto_parallel.
    Description: Net2 concatenates the output of Mul with a Parameter; Mul uses
        the (2, 2, 2) strategy per input and Concat uses (1, 2, 2), i.e. 4
        slices on 8 devices.
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    mul_strategy = ((2, 2, 2), (2, 2, 2))
    concat_strategy = ((1, 2, 2), (1, 2, 2))
    compile_net(Net2(_w1, mul_strategy, concat_strategy))
129
130
def test_concat_no_strategy():
    """
    Feature: Concat under semi_auto_parallel.
    Description: Net2 concatenates along axis 1; Mul uses the (2, 2, 2)
        strategy per input and Concat is given no strategy (None).
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    mul_strategy = ((2, 2, 2), (2, 2, 2))
    compile_net(Net2(_w3, mul_strategy, None, axis=1))
137
138
def test_concat_auto_parallel():
    """
    Feature: Concat under auto_parallel.
    Description: Net2 is built with its default arguments (no strategies,
        axis 0) and the _w2 weight, with device_num=8 and global_rank=0.
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    compile_net(Net2(_w2))
143
144
def test_concat_auto_parallel2():
    """
    Feature: Concat under auto_parallel.
    Description: Net2 concatenates along axis 1 with the _w3 weight; both
        strategies are passed explicitly as None.
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    compile_net(Net2(_w3, None, None, axis=1))
151
152
def test_concat_auto_parallel_3_tensor():
    """
    Feature: Concat under auto_parallel.
    Description: Net3 concatenates three Parameters (w1, w2, w3) with no
        strategies supplied, with device_num=8 and global_rank=0.
    Expectation: compile_net returns without raising.
    """
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    compile_net(Net3(w1, w2, w3))
157