# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semi-auto-parallel compile tests for sharded ``P.Conv2DTranspose``."""

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Conv2DTranspose followed by Neg, each with an optional shard strategy.

    The transposed convolution is asked to produce ``output_shape``
    (``(32, 16, 8, 8)`` for this class; subclasses may override it).

    Args:
        conv2d_weight (Tensor): initial value of the weight parameter ``"w1"``.
        out_channel: forwarded to ``P.Conv2DTranspose``.
        kernel_size: forwarded to ``P.Conv2DTranspose``.
        pad_mode: forwarded to ``P.Conv2DTranspose``.
        stride: forwarded to ``P.Conv2DTranspose``.
        strategy1: shard strategy for the Conv2DTranspose operator.
        strategy2: shard strategy for the Neg operator.
    """

    # Passed as the third (output-shape) input of Conv2DTranspose.
    output_shape = (32, 16, 8, 8)

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d_transpose = P.Conv2DTranspose(out_channel=out_channel, kernel_size=kernel_size,
                                                  pad_mode=pad_mode, stride=stride).shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.weight = Parameter(conv2d_weight, "w1")

    def construct(self, x, b):
        # ``b`` is unused by the computation; it exists so the cell accepts the
        # (_x, _b) pair that compile_net feeds to the graph executor.
        out = self.conv2d_transpose(x, self.weight, self.output_shape)
        return self.neg(out)


class Net2(Net):
    """Same network as ``Net`` but the transposed convolution outputs (32, 16, 16, 16)."""

    output_shape = (32, 16, 16, 16)


# Network input: 32 x 8 x 8 x 8.
_x = Tensor(np.ones([32, 8, 8, 8]), dtype=ms.float32)
# Candidate Conv2DTranspose weights; only the trailing kernel dims differ.
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 4, 4]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 10, 10]), dtype=ms.float32)
_w4 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
# Second network input (ignored by construct).
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a Momentum TrainOneStepCell and compile it with (_x, _b).

    The auto-parallel context is reset in a ``finally`` block so that a
    compile which raises (several tests expect RuntimeError) does not leak
    its parallel settings into subsequent tests.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)
    finally:
        context.reset_auto_parallel_context()


def test_conv2d_transpose_data_parallel():
    """Split only the input batch dimension 8 ways on 8 devices; expect compile success."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_transpose_model_parallel1():
    """Split input N/C 2x2 and weight dims 0/1 2x2 on 8 devices; expect compile success."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_transpose_model_parallel2():
    """Split input N by 2 and W by 4 (4x4 kernel, stride 2, same mode) on 8 devices; expect success."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
    strategy2 = ((2, 1, 1, 4),)
    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_transpose_model_parallel3():
    """Split input N/C/W as 2/2/4 and weight dim 0 by 2 on 16 devices; expect success."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_transpose_all_rank_no_need_overlap():
    """2x2 kernel with stride 2, W split by 4 on 16 devices; expect compile success."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_transpose_split_h_or_w_in_pad_mode():
    """W split by 4 with pad_mode="pad"; expect compile to raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="pad", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_transpose_split_h_in_same_mode():
    """H split by 4 with pad_mode="same"; expect compile to raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 4, 1), (2, 1, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net2(_w1, out_channel=8, kernel_size=(2, 2), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_transpose_overlap_size_too_large():
    """10x10 kernel with W split 8 ways on 8 devices; expect compile to raise RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net2(_w3, out_channel=8, kernel_size=(10, 10), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_transpose_overlap_size_too_large2():
    """4x4 kernel with W split 8 ways on 16 devices; expect compile to raise RuntimeError."""
    # NOTE(review): strategy1 covers only 8 of the 16 devices while strategy2
    # covers 16 -- confirm the RuntimeError comes from the overlap-size check
    # rather than from a strategy/device-count mismatch.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net2(_w2, out_channel=8, kernel_size=(4, 4), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_transpose_rank0_no_need_overlap():
    """3x3 kernel with stride 2, W split by 4 on 16 devices; expect compile to raise RuntimeError."""
    # NOTE(review): the test name suggests a supported case, yet it expects a
    # RuntimeError -- confirm this matches the intended operator support.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((2, 2, 1, 4), (2, 1, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net2(_w4, out_channel=8, kernel_size=(3, 3), pad_mode="same", stride=2,
               strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)