• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15import numpy as np
16
17import mindspore as ms
18from mindspore import context, Tensor, Parameter
19from mindspore.nn import Cell, Momentum
20from mindspore.ops import operations as P
21from mindspore.train import Model
22from tests.dataset_mock import MindData
23
24
class Dataset(MindData):
    """Mock dataset that hands out the same (predict, label) pair `length` times.

    Args:
        predict: Object returned as the first element of every sample.
        label: Object returned as the second element of every sample.
        length (int): Number of samples produced before iteration stops;
            also forwarded to the base class as `size`.
    """

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        # The dataset is its own iterator; the cursor is NOT rewound here,
        # only `reset()` does that.
        return self

    def __next__(self):
        # Serve the fixed pair while the cursor is below `length`.
        if self.index < self.length:
            self.index += 1
            return self.predict, self.label
        raise StopIteration

    def reset(self):
        """Rewind the cursor so the dataset can be iterated again."""
        self.index = 0
44
45
class Net(Cell):
    """Concatenate two weights along axis 0, then multiply the input by the result.

    Args:
        weight: First concat operand; wrapped in a `Parameter` named "w1"
            when `is_parameter` is true, otherwise stored as given.
        weight2: Second concat operand; always wrapped in a `Parameter`
            named "w2".
        strategy1: Value passed to `shard` on the Concat primitive.
        strategy2: Value passed to `shard` on the Mul primitive.
        is_parameter (bool): Whether `weight` becomes a `Parameter`.
    """

    def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).shard(strategy1)
        self.weight = Parameter(weight, "w1") if is_parameter else weight
        self.mul = P.Mul().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")

    def construct(self, x, b):
        # `b` is accepted because the dataset yields two items; it is unused.
        joined = self.concat((self.weight, self.weight2))
        return self.mul(x, joined)
61
62
class Net2(Cell):
    """Square the input with Mul, then concatenate the result with a weight.

    Args:
        weight: Tensor wrapped in a `Parameter` named "w"; second concat operand.
        strategy1: Value passed to `shard` on the Mul primitive.
        strategy2: Value passed to `shard` on the Concat primitive.
        axis (int): Axis along which Concat joins its operands.
    """

    def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.concat = P.Concat(axis=axis).shard(strategy2)
        self.weight = Parameter(weight, "w")

    def construct(self, x, b):
        # `b` is accepted because the dataset yields two items; it is unused.
        squared = self.mul(x, x)
        return self.concat((squared, self.weight))
74
75
class Net3(Cell):
    """Concatenate three weights along axis 0, then multiply the input by the result.

    Args:
        weight: First concat operand; wrapped in a `Parameter` named "w1"
            when `is_parameter` is true, otherwise stored as given.
        weight2: Second concat operand; wrapped in a `Parameter` named "w2".
        weight3: Third concat operand; wrapped in a `Parameter` named "w3".
        strategy1: Value passed to `shard` on the Concat primitive.
        strategy2: Value passed to `shard` on the Mul primitive.
        is_parameter (bool): Whether `weight` becomes a `Parameter`.
    """

    def __init__(self, weight, weight2, weight3, strategy1=None, strategy2=None, is_parameter=True):
        super().__init__()
        self.concat = P.Concat(axis=0).shard(strategy1)
        self.weight = Parameter(weight, "w1") if is_parameter else weight
        self.mul = P.Mul().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")
        self.weight3 = Parameter(weight3, "w3")

    def construct(self, x, b):
        # `b` is accepted because the dataset yields two items; it is unused.
        joined = self.concat((self.weight, self.weight2, self.weight3))
        return self.mul(x, joined)
92
93
# Sample pair served by the mock dataset in compile_net.
# NOTE(review): _x has a leading dimension of 16 while the concatenated
# weights below add up to 128 (96 + 32, and 48 + 16 + 64); presumably the
# framework scales the dataset batch by device_num (8) -- confirm.
_x = Tensor(np.ones((16, 64, 32)), dtype=ms.float32)
_b = Tensor(np.ones((16, 64, 32, 32)), dtype=ms.int32)

# Weights used by the Net / Net2 test cases.
_w1 = Tensor(np.ones((96, 64, 32)), dtype=ms.float32)
_w2 = Tensor(np.ones((32, 64, 32)), dtype=ms.float32)
_w3 = Tensor(np.ones((128, 16, 32)), dtype=ms.float32)

# Weights used by the three-operand Net3 test case.
w1 = Tensor(np.ones((48, 64, 32)), dtype=ms.float32)
w2 = Tensor(np.ones((16, 64, 32)), dtype=ms.float32)
w3 = Tensor(np.ones((64, 64, 32)), dtype=ms.float32)
103
104
def compile_net(net):
    """Train `net` for two epochs on the mock dataset, then reset the parallel context.

    Builds a `Dataset` from the module-level `_x` / `_b` tensors, a `Momentum`
    optimizer over `net.trainable_params()`, and a `Model` created with
    `amp_level="O2"`, then calls `Model.train` with `dataset_sink_mode=False`.

    The auto-parallel context is reset in a `finally` clause so that a failure
    while building or training the model does not leave this test's parallel
    settings active for whichever test runs next.

    Args:
        net (Cell): Network to compile and train.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    try:
        dataset = Dataset(_x, _b)
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt, amp_level="O2")
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        # Always undo the caller's set_auto_parallel_context, even on error.
        context.reset_auto_parallel_context()
114
115
def test_concat_parameter():
    """Net with both concat operands as Parameters; Concat and Mul both sharded (1, 4, 2)."""
    parallel_cfg = {"parallel_mode": "semi_auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    layout = (1, 4, 2)
    net = Net(_w1, _w2,
              strategy1=(layout, layout),
              strategy2=(layout, layout),
              is_parameter=True)
    compile_net(net)
123
124
def test_concat_parameter_no_full_split():
    """Net with Parameter operands; Concat sharded (1, 2, 2), Mul sharded (1, 4, 2)."""
    parallel_cfg = {"parallel_mode": "semi_auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    concat_layout = (1, 2, 2)
    mul_layout = (1, 4, 2)
    net = Net(_w1, _w2,
              strategy1=(concat_layout, concat_layout),
              strategy2=(mul_layout, mul_layout),
              is_parameter=True)
    compile_net(net)
132
133
def test_concat_tensor_and_parameter():
    """Net whose first concat operand stays a plain Tensor (is_parameter=False)."""
    parallel_cfg = {"parallel_mode": "semi_auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    concat_layout = (1, 2, 2)
    mul_layout = (1, 4, 2)
    net = Net(_w1, _w2,
              strategy1=(concat_layout, concat_layout),
              strategy2=(mul_layout, mul_layout),
              is_parameter=False)
    compile_net(net)
141
142
def test_concat_output():
    """Net2 on _w1; Mul sharded (2, 2, 2), Concat sharded (1, 4, 2)."""
    parallel_cfg = {"parallel_mode": "semi_auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    mul_layout = (2, 2, 2)
    concat_layout = (1, 4, 2)
    net = Net2(_w1,
               strategy1=(mul_layout, mul_layout),
               strategy2=(concat_layout, concat_layout))
    compile_net(net)
150
151
def test_concat_output_no_full_split():
    """Net2 on _w1; Mul sharded (2, 2, 2), Concat sharded (1, 2, 2)."""
    parallel_cfg = {"parallel_mode": "semi_auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    mul_layout = (2, 2, 2)
    concat_layout = (1, 2, 2)
    net = Net2(_w1,
               strategy1=(mul_layout, mul_layout),
               strategy2=(concat_layout, concat_layout))
    compile_net(net)
159
160
def test_concat_no_strategy():
    """Net2 on _w3 with axis=1; Mul sharded (2, 2, 2), Concat given no strategy."""
    parallel_cfg = {"parallel_mode": "semi_auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    mul_layout = (2, 2, 2)
    net = Net2(_w3,
               strategy1=(mul_layout, mul_layout),
               strategy2=None,
               axis=1)
    compile_net(net)
168
169
def test_concat_auto_parallel():
    """Net2 on _w2 in auto_parallel mode with no explicit strategies."""
    parallel_cfg = {"parallel_mode": "auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    compile_net(Net2(_w2))
175
176
def test_concat_auto_parallel2():
    """Net2 on _w3 with axis=1 in auto_parallel mode with no explicit strategies."""
    parallel_cfg = {"parallel_mode": "auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    # Both strategies are left at their None defaults.
    net = Net2(_w3, axis=1)
    compile_net(net)
184
185
def test_concat_auto_parallel_3_tensor():
    """Net3 on w1, w2, w3 in auto_parallel mode with no explicit strategies."""
    parallel_cfg = {"parallel_mode": "auto_parallel", "device_num": 8, "global_rank": 0}
    context.set_auto_parallel_context(**parallel_cfg)
    compile_net(Net3(w1, w2, w3))
191