# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compile-only tests for sharding strategies of Conv2D followed by MaxPool / AvgPool."""

import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Conv2D followed by MaxPool, each with a configurable shard strategy.

    Args:
        conv2d_weight: initial value of the Conv2D weight, stored as Parameter "w1".
        out_channel, kernel_size, pad_mode, stride: forwarded to ``P.Conv2D``.
        pool_kernel_size, pool_strides: forwarded to ``P.MaxPool``.
        strategy1: passed to ``Conv2D.shard`` (one entry per input: data, weight).
        strategy2: passed to ``MaxPool.shard`` (one entry for the single input).
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                               pad_mode=pad_mode, stride=stride).shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        self.max_pool = P.MaxPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)

    def construct(self, x, b):
        # `b` is unused; it exists because compile_net feeds two inputs (inputs, _b).
        out = self.conv2d(x, self.conv2d_weight)
        out = self.max_pool(out)
        return out


class Net2(Cell):
    """Conv2D followed by AvgPool, each with a configurable shard strategy.

    Identical to ``Net`` except that the pooling op is ``P.AvgPool``.

    Args:
        conv2d_weight: initial value of the Conv2D weight, stored as Parameter "w1".
        out_channel, kernel_size, pad_mode, stride: forwarded to ``P.Conv2D``.
        pool_kernel_size, pool_strides: forwarded to ``P.AvgPool``.
        strategy1: passed to ``Conv2D.shard`` (one entry per input: data, weight).
        strategy2: passed to ``AvgPool.shard`` (one entry for the single input).
    """

    def __init__(self, conv2d_weight, out_channel, kernel_size, pad_mode, stride, pool_kernel_size, pool_strides,
                 strategy1=None, strategy2=None):
        super().__init__()
        self.conv2d = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size,
                               pad_mode=pad_mode, stride=stride).shard(strategy1)
        self.conv2d_weight = Parameter(conv2d_weight, "w1")
        self.avg_pool = P.AvgPool(kernel_size=pool_kernel_size, strides=pool_strides).shard(strategy2)

    def construct(self, x, b):
        # `b` is unused; it exists because compile_net feeds two inputs (inputs, _b).
        out = self.conv2d(x, self.conv2d_weight)
        out = self.avg_pool(out)
        return out


# Alternative data input with a 10x10 spatial size, used by the
# "input slice is not divisible by stride" tests.
_x0 = Tensor(np.ones([32, 16, 10, 10]), dtype=ms.float32)
# Default data input (8x8 spatial size).
_x = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)
# Conv2D weight: 8 output channels, 16 input channels, 2x2 kernel.
_w1 = Tensor(np.ones([8, 16, 2, 2]), dtype=ms.float32)
# Second positional input passed to construct(); the networks ignore it.
_b = Tensor(np.ones([32, 16, 8, 8]), dtype=ms.float32)


def compile_net(net, inputs=_x):
    """Wrap ``net`` in a Momentum training step and compile it under the current parallel context.

    The auto-parallel context is always reset afterwards, including when compilation
    raises (several tests expect a RuntimeError), so that one test's settings cannot
    leak into the next.

    Args:
        net: the cell to compile.
        inputs: the data tensor fed as the first network input; ``_b`` is fed as the second.

    Raises:
        Whatever ``_cell_graph_executor.compile`` raises (e.g. RuntimeError for an invalid strategy).
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, inputs, _b)
    finally:
        context.reset_auto_parallel_context()


def test_maxpool_data_parallel():
    """MaxPool: shard only the batch dimension across all 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_maxpool_model_parallel1():
    """MaxPool: shard batch, H and W with kernel size == stride."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((2, 1, 2, 2),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_maxpool_model_parallel2():
    """MaxPool: shard batch, H and W with stride larger than kernel size."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((2, 1, 2, 2),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
              strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_maxpool_auto_parallel():
    """MaxPool: let auto_parallel mode choose the strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
    compile_net(net)


def test_maxpool_output_is_not_divisible_by_strategy_w_dimension():
    """MaxPool: an 8-way split of the W dimension (NCHW axis 3) must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 8),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_maxpool_output_is_not_divisible_by_strategy_h_dimension():
    """MaxPool: an 8-way split of the H dimension (NCHW axis 2) must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 8, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_maxpool_shard_h_and_kernel_size_larger_than_stride():
    """MaxPool: sharding H when pool kernel (3) > stride (2) must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 2, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_maxpool_shard_w_and_kernel_size_larger_than_stride():
    """MaxPool: sharding W when pool kernel (3) > stride (2) must be rejected."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 1, 2),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=3, pool_strides=2,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_maxpool_shard_h_and_input_slice_is_not_divisible_by_stride():
    """MaxPool: sharding H when the per-device H slice is not a multiple of the stride must be rejected.

    Uses the 10x10 input; a 2-way split gives slices of 5, which is not a multiple of stride 3
    (assuming the "same"-padded, stride-1 conv keeps the spatial size).
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((1, 1, 2, 1),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, inputs=_x0)


def test_maxpool_shard_w_and_input_slice_is_not_divisible_by_stride():
    """MaxPool: sharding W when the per-device W slice is not a multiple of the stride must be rejected.

    Mirror of the H test above, but splits NCHW axis 3 (W) instead of axis 2 (H).
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    # Shard W (axis 3); previously this was (1, 1, 2, 1), which duplicated the H test.
    strategy2 = ((1, 1, 1, 2),)
    net = Net(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=1, pool_strides=3,
              strategy1=strategy1, strategy2=strategy2)
    with pytest.raises(RuntimeError):
        compile_net(net, inputs=_x0)


def test_avgpool_data_parallel():
    """AvgPool: shard only the batch dimension across all 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((8, 1, 1, 1), (1, 1, 1, 1))
    strategy2 = ((8, 1, 1, 1),)
    net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
               strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_avgpool_model_parallel1():
    """AvgPool: shard batch, H and W with kernel size == stride."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((2, 1, 2, 2),)
    net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=2,
               strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_avgpool_model_parallel2():
    """AvgPool: shard batch, H and W with stride larger than kernel size."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 1, 1), (2, 2, 1, 1))
    strategy2 = ((2, 1, 2, 2),)
    net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4,
               strategy1=strategy1, strategy2=strategy2)
    compile_net(net)


def test_avgpool_auto_parallel():
    """AvgPool: let auto_parallel mode choose the strategies."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w1, out_channel=8, kernel_size=2, pad_mode="same", stride=1, pool_kernel_size=2, pool_strides=4)
    compile_net(net)