• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import numpy as np
16import pytest
17
18import mindspore as ms
19from mindspore import context, Tensor, Parameter
20from mindspore.common.api import _cell_graph_executor
21from mindspore.nn import Cell, TrainOneStepCell, Momentum
22from mindspore.ops import operations as P
23
24
class Net(Cell):
    """Conv2DTranspose followed by Neg, requesting an output size of (32, 16, 8, 8).

    Args:
        conv2d_weight: initial value of the transposed-convolution weight,
            held in the Parameter named ``"w1"``.
        out_channel, kernel_size, pad_mode, stride: forwarded unchanged to
            ``P.Conv2DTranspose``.
        strategy1: shard strategy applied to the Conv2DTranspose primitive.
        strategy2: shard strategy applied to the Neg primitive.
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        deconv = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
                                   pad_mode=pad_mode, stride=stride)
        self.conv2d_transpose = deconv.shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.weight = Parameter(conv2d_weight, "w1")

    def construct(self, x, b):
        # `b` is never used; it is accepted only because compile_net feeds two inputs.
        transposed = self.conv2d_transpose(x, self.weight, (32, 16, 8, 8))
        return self.neg(transposed)
38
39
class Net2(Cell):
    """Conv2DTranspose followed by Neg, requesting an output size of (32, 16, 16, 16).

    Identical to ``Net`` except for the requested output size.

    Args:
        conv2d_weight: initial value of the transposed-convolution weight,
            held in the Parameter named ``"w1"``.
        out_channel, kernel_size, pad_mode, stride: forwarded unchanged to
            ``P.Conv2DTranspose``.
        strategy1: shard strategy applied to the Conv2DTranspose primitive.
        strategy2: shard strategy applied to the Neg primitive.
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        deconv = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
                                   pad_mode=pad_mode, stride=stride)
        self.conv2d_transpose = deconv.shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.weight = Parameter(conv2d_weight, "w1")

    def construct(self, x, b):
        # `b` is never used; it is accepted only because compile_net feeds two inputs.
        transposed = self.conv2d_transpose(x, self.weight, (32, 16, 16, 16))
        return self.neg(transposed)
53
54
def _ones(*shape):
    """Return an all-ones float32 Tensor of the given shape."""
    return Tensor(np.ones(shape), dtype=ms.float32)


# First network input; its second dimension (8) equals the out_channel=8 used by every test.
_x = _ones(32, 8, 8, 8)
# Transposed-convolution weights, one per kernel size exercised by the tests:
# _w1 -> 2x2, _w2 -> 4x4, _w3 -> 10x10, _w4 -> 3x3.
_w1 = _ones(8, 16, 2, 2)
_w2 = _ones(8, 16, 4, 4)
_w3 = _ones(8, 16, 10, 10)
_w4 = _ones(8, 16, 3, 3)
# Second network input; Net and Net2 accept it but never read it.
_b = _ones(32, 16, 8, 8)
61
62
def compile_net(net):
    """Wrap ``net`` in a one-step training cell and compile it for auto parallel.

    The network is trained with Momentum (learning_rate=0.1, momentum=0.9)
    inside ``TrainOneStepCell`` and compiled against the module-level inputs
    ``_x`` and ``_b``.

    The auto-parallel context is always reset, even when compilation raises.
    Several tests expect a RuntimeError here. Without the ``finally``, those
    tests would leave their parallel settings behind for whichever test runs next.

    Raises:
        Whatever the optimizer/cell construction or graph compilation raises
        (the tests in this file expect RuntimeError for invalid strategies).
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)
    finally:
        context.reset_auto_parallel_context()
70
71
def test_conv2d_transpose_data_parallel():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 8 devices.
    Description: only the batch dimension is split (8 ways); the weight is not split.
        Kernel 2, stride 1, "same" padding, output size (32, 16, 8, 8).
    Expectation: the network compiles without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((8, 1, 1, 1), (1, 1, 1, 1))
    neg_strategy = ((8, 1, 1, 1),)
    network = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
                  strategy1=conv_strategy, strategy2=neg_strategy)
    compile_net(network)
78
79
def test_conv2d_transpose_model_parallel1():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 8 devices.
    Description: the input is split 2 x 2 on batch and channel, and the weight is
        split 2 x 2 on its first two dimensions; Neg splits the batch 8 ways.
    Expectation: the network compiles without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 2, 1, 1), (2, 2, 1, 1))
    neg_strategy = ((8, 1, 1, 1),)
    network = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1,
                  strategy1=conv_strategy, strategy2=neg_strategy)
    compile_net(network)
86
87
def test_conv2d_transpose_model_parallel2():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 8 devices.
    Description: the input is split 2 ways on batch and 4 ways on width.
        Kernel 4x4, stride 2, "same" padding, output size (32, 16, 16, 16).
    Expectation: the network compiles without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((2, 1, 1, 4), (1, 1, 1, 1))
    neg_strategy = ((2, 1, 1, 4),)
    network = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    compile_net(network)
95
96
def test_conv2d_transpose_model_parallel3():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 16 devices.
    Description: the input is split 2 (batch) x 2 (channel) x 4 (width), and the
        weight is split 2 ways on its first dimension.
        Kernel 4x4, stride 2, "same" padding.
    Expectation: the network compiles without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    conv_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    neg_strategy = ((2, 2, 1, 4),)
    network = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    compile_net(network)
104
105
def test_conv2d_transpose_all_rank_no_need_overlap():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 16 devices.
    Description: same split as model_parallel3, but with a 2x2 kernel and stride 2.
        Per the test name, no rank should need overlap (halo) data from a neighbour.
    Expectation: the network compiles without error.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    conv_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    neg_strategy = ((2, 2, 1, 4),)
    network = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    compile_net(network)
113
114
def test_conv2d_transpose_split_h_or_w_in_pad_mode():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 16 devices.
    Description: the input width is split 4 ways while pad_mode is "pad".
    Expectation: compilation raises RuntimeError.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    conv_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    neg_strategy = ((2, 2, 1, 4),)
    network = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="pad", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    with pytest.raises(RuntimeError):
        compile_net(network)
123
124
def test_conv2d_transpose_split_h_in_same_mode():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 16 devices.
    Description: the input height (dim 2) is split 4 ways with pad_mode "same".
    Expectation: compilation raises RuntimeError.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    conv_strategy = ((2, 2, 4, 1), (2, 1, 1, 1))
    neg_strategy = ((2, 2, 1, 4),)
    network = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    with pytest.raises(RuntimeError):
        compile_net(network)
133
134
def test_conv2d_transpose_overlap_size_too_large():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 8 devices.
    Description: a 10x10 kernel with stride 2, while the width-8 input is split 8 ways.
        Per the test name, the required overlap exceeds what a slice can supply.
    Expectation: compilation raises RuntimeError.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    conv_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    neg_strategy = ((1, 1, 1, 8),)
    network = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    with pytest.raises(RuntimeError):
        compile_net(network)
143
144
def test_conv2d_transpose_overlap_size_too_large2():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 16 devices.
    Description: a 4x4 kernel with stride 2, while the width-8 input is split 8 ways.
    Expectation: compilation raises RuntimeError.
    """
    # NOTE(review): conv_strategy uses only 8 of the 16 devices and differs from
    # neg_strategy; confirm the RuntimeError comes from the overlap check and not
    # from this mismatch.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    conv_strategy = ((1, 1, 1, 8), (1, 1, 1, 1))
    neg_strategy = ((2, 2, 1, 4),)
    network = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    with pytest.raises(RuntimeError):
        compile_net(network)
153
154
def test_conv2d_transpose_rank0_no_need_overlap():
    """
    Feature: Conv2DTranspose + Neg under semi-auto parallel on 16 devices.
    Description: same split as model_parallel3, but with a 3x3 kernel and stride 2.
    Expectation: compilation raises RuntimeError.
    """
    # NOTE(review): the name says rank 0 needs no overlap, yet the test expects a
    # RuntimeError; confirm this expectation is intentional.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    conv_strategy = ((2, 2, 1, 4), (2, 1, 1, 1))
    neg_strategy = ((2, 2, 1, 4),)
    network = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2,
                   strategy1=conv_strategy, strategy2=neg_strategy)
    with pytest.raises(RuntimeError):
        compile_net(network)
163