# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


# Compile-only tests for Conv2D shard strategies (Conv2D followed by Neg) under
# "semi_auto_parallel" and "auto_parallel" modes. Tests wrapped in
# pytest.raises(RuntimeError) expect graph compilation to reject the configuration.


class Net(Cell):
    """Conv2D followed by Neg, each primitive sharded with an optional strategy.

    Args:
        conv2d_weight (Tensor): initial convolution weight, stored as Parameter "w1".
        out_channel, kernel_size, pad_mode, stride, dilation, group: forwarded unchanged to P.Conv2D.
        strategy1: strategy passed to Conv2D's `.shard()` (left as None in the auto-parallel test).
        strategy2: strategy passed to Neg's `.shard()`.
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, dilation=1, group=1,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                               pad_mode=pad_mode, stride=stride, dilation=dilation, group=group).shard(strategy1)
        self.neg = P.Neg().shard(strategy2)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")

    def construct(self, x, b):
        # `b` is intentionally unused: compile_net always feeds `_b` as a second network input.
        out = self.conv2d(x, self.conv2d_weight)
        out = self.neg(out)
        return out


# Network inputs: _x is the default input, _x2 has a 10x10 spatial size for the stride-3 cases.
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
_x2 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
# Conv2D weights for kernel sizes 1, 2, 3 and 5; _w4 has 8 input channels and is used with group=2.
_w0 = Tensor(np.ones([8, 16, 1, 1]), dtype=ms.float32)
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
_w2 = Tensor(np.ones([8, 16, 3, 3]), dtype=ms.float32)
_w3 = Tensor(np.ones([8, 16, 5, 5]), dtype=ms.float32)
_w4 = Tensor(np.ones([8, 8, 2, 2]), dtype=ms.float32)
# Second network input; ignored by Net.construct.
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net, input_x=_x):
    """Wrap `net` in a Momentum TrainOneStepCell and compile it under the current parallel context.

    The auto-parallel context is reset in a `finally` clause so that it is cleared even
    when compilation raises (many tests below expect a RuntimeError). Otherwise a
    failing case would leak its parallel settings into subsequent tests.

    Args:
        net (Cell): network to compile.
        input_x (Tensor): first network input; defaults to `_x`. `_b` is always passed as the second input.

    Raises:
        Whatever `_cell_graph_executor.compile` raises; tests expect RuntimeError for rejected strategies.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, input_x, _b)
    finally:
        context.reset_auto_parallel_context()


def test_conv2d_data_parallel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_data_parallel_invalid_stride():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=(2, 2, 1, 1),
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_data_parallel_dilation():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_data_parallel_group():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_model_parallel1():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_model_parallel_dilation():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, dilation=2,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_model_parallel_group():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w4, out_channel=8, kernel_size=2, pad_mode="same", stride=1, group=2,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_model_parallel2():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    strategy1 = ((2, 2, 2, 2), (2, 2, 1, 1))
    strategy2 = ((32, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_model_parallel3():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
    strategy2 = ((2, 1, 1, 4),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_auto_parallel():
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1)
    compile_net(net)


def test_conv2d_model_parallel4():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    strategy1 = ((2, 2, 1, 4), (2, 2, 1, 1))
    strategy2 = ((2, 2, 1, 4),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_conv2d_left_and_right_no_need_to_send():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1, 1, 4), (1, 1, 1, 1))
    strategy2 = ((2, 1, 1, 4),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_kernel_size_larger_than_stride_and_split_h():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=32, global_rank=0)
    strategy1 = ((2, 2, 4, 1), (2, 2, 1, 1))
    strategy2 = ((2, 2, 4, 1),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_valid_mode_kernel_size_larger_than_stride():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 1, 1, 2), (1, 1, 1, 1))
    strategy2 = ((2, 1, 1, 4),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="valid", stride=1, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_output_can_not_divisible_by_strategy():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_output_can_not_divisible_by_strategy2():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 8, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_split_kernel():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 1), (1, 1, 2, 2))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=2, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_same_mode():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x2)


def test_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x2)


def test_h_dimension_kernel_size_smaller_than_stride_and_slice_is_not_divisible_by_stride_same_mode():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x2)


def test_h_dimension_kernel_size_smaller_than_stride_and_slice_can_not_divisible_by_stride_valid_mode():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w0, out_channel=8, kernel_size=1, pad_mode="valid", stride=3, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x2)


def test_split_h_dimension_and_pad_mode_is_pad():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 2, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="pad", stride=2, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_kernel_size_larger_than_stride_and_input_can_not_divisible_by_stride():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 2), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=3, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, _x2)


def test_kernel_size_larger_than_stride_and_slice_too_small():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w3, out_channel=8, kernel_size=5, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_conv2d_same_mode_overlap_size_equal_to_slice_shape():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 8), (1, 1, 1, 1))
    strategy2 = ((2, 1, 1, 4),)
    net = Net(_w2, out_channel=8, kernel_size=3, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_kernel_size_larger_than_stride_and_left_pad_is_0():
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 1, 1, 4), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)