• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import numpy as np
16import pytest
17
18import mindspore as ms
19from mindspore import context, Tensor, Parameter
20from mindspore.common.api import _cell_graph_executor
21from mindspore.nn import Cell, TrainOneStepCell, Momentum
22from mindspore.ops import operations as P
23
24
class Net(Cell):
    """Conv2D followed by MaxPool, each primitive sharded with its own strategy.

    Args:
        conv2d_weight: Initial value wrapped in the ``w1`` Parameter and fed to Conv2D as its weight.
        out_channel: Forwarded to ``P.Conv2D``.
        kernel_size: Forwarded to ``P.Conv2D``.
        pad_mode: Forwarded to ``P.Conv2D``.
        stride: Forwarded to ``P.Conv2D``.
        pool_kernel_size: Forwarded to ``P.MaxPool`` as ``kernel_size``.
        pool_strides: Forwarded to ``P.MaxPool`` as ``strides``.
        strategy1: Argument passed to ``shard`` on the Conv2D primitive.
        strategy2: Argument passed to ``shard`` on the MaxPool primitive.
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
                 strategy1=None, strategy2=None):
        super().__init__()
        conv_op = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode, stride=stride)
        self.conv2d = conv_op.shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        pool_op = P.MaxPool(kernel_size=pool_kernel_size, strides=pool_strides)
        self.max_pool = pool_op.shard(strategy2)

    def construct(self, x, b):
        # `b` is accepted as a second input but is not used in the computation.
        conv_out = self.conv2d(x, self.conv2d_weight)
        return self.max_pool(conv_out)
38
39
class Net2(Cell):
    """Conv2D followed by AvgPool, each primitive sharded with its own strategy.

    Args:
        conv2d_weight: Initial value wrapped in the ``w1`` Parameter and fed to Conv2D as its weight.
        out_channel: Forwarded to ``P.Conv2D``.
        kernel_size: Forwarded to ``P.Conv2D``.
        pad_mode: Forwarded to ``P.Conv2D``.
        stride: Forwarded to ``P.Conv2D``.
        pool_kernel_size: Forwarded to ``P.AvgPool`` as ``kernel_size``.
        pool_strides: Forwarded to ``P.AvgPool`` as ``strides``.
        strategy1: Argument passed to ``shard`` on the Conv2D primitive.
        strategy2: Argument passed to ``shard`` on the AvgPool primitive.
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
                 strategy1=None, strategy2=None):
        super().__init__()
        conv_op = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, pad_mode=pad_mode, stride=stride)
        self.conv2d = conv_op.shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        pool_op = P.AvgPool(kernel_size=pool_kernel_size, strides=pool_strides)
        self.avg_pool = pool_op.shard(strategy2)

    def construct(self, x, b):
        # `b` is accepted as a second input but is not used in the computation.
        conv_out = self.conv2d(x, self.conv2d_weight)
        return self.avg_pool(conv_out)
53
54
# Shared float32 all-ones fixtures used by the tests in this module.
# _x0: alternative 32x16x10x10 input, used by the "input slice is not divisible by stride" cases.
_x0 = Tensor(np.ones(shape=(32, 16, 10, 10)), dtype=ms.float32)
# _x: default 32x16x8x8 input handed to compile_net.
_x = Tensor(np.ones(shape=(32, 16, 8, 8)), dtype=ms.float32)
# _w1: 8x16x2x2 weight passed to Net / Net2 as the Conv2D weight.
_w1 = Tensor(np.ones(shape=(8, 16, 2, 2)), dtype=ms.float32)
# _b: second network input supplied at compile time; the networks' construct() does not read it.
_b = Tensor(np.ones(shape=(32, 16, 8, 8)), dtype=ms.float32)
59
60
def compile_net(net, inputs=_x):
    """Wrap ``net`` in a Momentum training step and compile it in auto-parallel mode.

    Args:
        net: Cell to compile; its trainable parameters are handed to the Momentum optimizer.
        inputs: First input used for compilation (defaults to ``_x``). ``_b`` is always
            supplied as the second input.

    Raises:
        Whatever the optimizer/cell construction or graph compilation raises; the error is
        propagated unchanged.

    The auto-parallel context is reset on every exit path. Without the ``finally``, a case that
    is expected to fail compilation (the ``pytest.raises`` tests) would skip the reset and leave
    its parallel settings in place for whichever test runs next.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, inputs, _b)
    finally:
        context.reset_auto_parallel_context()
68
69
def test_maxpool_data_parallel():
    """Compile Conv2D + MaxPool in semi-auto-parallel mode, splitting only the first input axis 8 ways."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((8, 1, 1, 1),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    compile_net(net)
77
78
def test_maxpool_model_parallel1():
    """Compile Conv2D + MaxPool (kernel 2, stride 2) with MaxPool split 2 ways on axes 0, 2 and 3."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    pool_strategy = ((2, 1, 2, 2),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    compile_net(net)
86
87
def test_maxpool_model_parallel2():
    """Compile Conv2D + MaxPool (kernel 2, stride 4) with MaxPool split 2 ways on axes 0, 2 and 3."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    pool_strategy = ((2, 1, 2, 2),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=4,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    compile_net(net)
95
96
def test_maxpool_auto_parallel():
    """Compile Conv2D + MaxPool in auto_parallel mode with no explicit shard strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=4,
    )
    compile_net(net)
101
102
def test_maxpool_output_is_not_divisible_by_strategy_w_dimension():
    """Expect RuntimeError when MaxPool (kernel 2, stride 2) is split 8 ways along its last axis."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((1, 1, 1, 8),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    with pytest.raises(RuntimeError):
        compile_net(net)
111
112
def test_maxpool_output_is_not_divisible_by_strategy_h_dimension():
    """Expect RuntimeError when MaxPool (kernel 2, stride 2) is split 8 ways along its third axis."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((1, 1, 8, 1),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    with pytest.raises(RuntimeError):
        compile_net(net)
121
122
def test_maxpool_shard_h_and_kernel_size_larger_than_stride():
    """Expect RuntimeError when MaxPool with kernel 3 and stride 2 is split 2 ways along its third axis."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((1, 1, 2, 1),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=3,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    with pytest.raises(RuntimeError):
        compile_net(net)
131
132
def test_maxpool_shard_w_and_kernel_size_larger_than_stride():
    """Expect RuntimeError when MaxPool with kernel 3 and stride 2 is split 2 ways along its last axis."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((1, 1, 1, 2),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=3,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    with pytest.raises(RuntimeError):
        compile_net(net)
141
142
def test_maxpool_shard_h_and_input_slice_is_not_divisible_by_stride():
    """Expect RuntimeError for MaxPool (kernel 1, stride 3) split 2 ways on its third axis, using the ``_x0`` input."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((1, 1, 2, 1),)
    net = Net(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=1,
        pool_strides=3,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    with pytest.raises(RuntimeError):
        compile_net(net, inputs=_x0)
151
152
def test_maxpool_shard_w_and_input_slice_is_not_divisible_by_stride():
    """Expect RuntimeError for MaxPool (kernel 1, stride 3) split 2 ways on its last axis, using the ``_x0`` input."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    # Split the last axis (W, per the test name). Using (1, 1, 2, 1) here would only repeat the
    # H-axis case and leave the W-axis check unexercised.
    strategy2 = ((1, 1, 1, 2),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, inputs=_x0)
161
162
def test_avgpool_data_parallel():
    """Compile Conv2D + AvgPool in semi-auto-parallel mode, splitting only the first input axis 8 ways."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    pool_strategy = ((8, 1, 1, 1),)
    net = Net2(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    compile_net(net)
170
171
def test_avgpool_model_parallel1():
    """Compile Conv2D + AvgPool (kernel 2, stride 2) with AvgPool split 2 ways on axes 0, 2 and 3."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    pool_strategy = ((2, 1, 2, 2),)
    net = Net2(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=2,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    compile_net(net)
179
180
def test_avgpool_model_parallel2():
    """Compile Conv2D + AvgPool (kernel 2, stride 4) with AvgPool split 2 ways on axes 0, 2 and 3."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    pool_strategy = ((2, 1, 2, 2),)
    net = Net2(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=4,
        strategy1=conv_strategy,
        strategy2=pool_strategy,
    )
    compile_net(net)
188
189
def test_avgpool_auto_parallel():
    """Compile Conv2D + AvgPool in auto_parallel mode with no explicit shard strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(
        _w1,
        out_channel=8,
        kernel_size=2,
        pad_mode="same",
        stride=1,
        pool_kernel_size=2,
        pool_strides=4,
    )
    compile_net(net)
194