# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parallel-compilation tests for sharded ``StridedSlice``.

Each test configures the auto-parallel context, builds a small network that
uses ``P.StridedSlice`` (optionally with a shard strategy), and compiles it
through ``TrainOneStepCell``. Tests that use ``pytest.raises(RuntimeError)``
expect the strategy/slice combination to be rejected at compile time.
"""
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _cell_graph_executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P


class Net(Cell):
    """Slice a weight with StridedSlice, then multiply by ``x`` and by ``w2``.

    Args:
        weight: Tensor that is strided-sliced. Wrapped in a ``Parameter``
            named ``"w1"`` when ``is_parameter`` is True, otherwise used
            as-is (a constant tensor input to the slice).
        w2: Tensor wrapped in a ``Parameter`` named ``"w2"``; multiplied with
            the first product (this multiply has no shard strategy).
        begin, end, strides: Tuples passed to ``StridedSlice``.
        strategy1: Shard strategy for the first ``Mul`` (``x * sliced``).
        strategy2: Shard strategy for ``StridedSlice``.
        is_parameter: Whether ``weight`` becomes a trainable ``Parameter``.
        mask: Value passed as ``begin_mask`` to ``StridedSlice``.
    """

    def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul2 = P.Mul()
        self.weight2 = Parameter(w2, "w2")
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, x, b):
        # ``b`` is unused; it exists so the cell accepts the two inputs
        # (_x, _b) that compile_net always passes.
        out = self.strided_slice(self.weight, self.begin, self.end, self.strides)
        out = self.mul(x, out)
        out = self.mul2(out, self.weight2)
        return out


class Net2(Cell):
    """Multiply ``x`` by a parameter, then StridedSlice the product.

    Unlike ``Net``, the slice is applied to an intermediate output rather
    than to a weight.

    Args:
        weight2: Tensor wrapped in a ``Parameter`` named ``"w2"``.
        begin, end, strides: Tuples passed to ``StridedSlice``.
        strategy1: Shard strategy for ``Mul``.
        strategy2: Shard strategy for ``StridedSlice`` (None = unspecified).
    """

    def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.strided_slice = P.StridedSlice().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, x, b):
        # ``b`` is unused; kept to match compile_net's two-input call.
        out = self.mul(x, self.weight2)
        out = self.strided_slice(out, self.begin, self.end, self.strides)
        return out


# Shared inputs: _x and _b are the two network inputs; _w1/_w2 are weights.
_x = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)


def compile_net(net):
    """Wrap ``net`` in a Momentum training cell and compile it with (_x, _b).

    The auto-parallel context is reset in a ``finally`` clause so that a
    compilation failure (expected by the ``pytest.raises`` tests) does not
    leak the current parallel configuration into subsequent tests.
    """
    try:
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        train_net = TrainOneStepCell(net, optimizer)
        train_net.set_auto_parallel()
        train_net.set_train()
        _cell_graph_executor.compile(train_net, _x, _b)
    finally:
        context.reset_auto_parallel_context()


def test_stridedslice_no_fully_fetch_split_error():
    """Dim 0 of w1 (size 256) is sliced to 0..128 while strategy2 splits it by 2; expect RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_stridedslice_strides_no_1_split_error():
    """Stride 2 on dim 2 while strategy2 splits dim 2 by 2; expect RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2), strategy1, strategy2, is_parameter=True)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_stridedslice_mask_no_0_split_error():
    """A non-zero begin_mask (1) together with a shard strategy; expect RuntimeError."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True, mask=1)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_stridedslice_begin_size_smaller():
    """begin/end/strides cover only the first two dims of the rank-3 weight."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1), strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_stridedslice_parameter():
    """Slice a sharded Parameter weight; dims 1 and 2 are fully fetched."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_stridedslice_tensor():
    """Same as test_stridedslice_parameter, but the weight is a plain Tensor."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_stridedslice_parameter_no_full_split():
    """strategy2 (1, 2, 2) uses 4 of the 8 configured devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1), strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_stridedslice_output():
    """Slice the Mul output (dim 0: 0..64 of 128) with dim 1 split by 8."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 8, 1),)
    net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
    compile_net(net)


def test_stridedslice_output_no_full_split():
    """Slice the Mul output with strategy2 (1, 4, 1), i.e. 4 of 8 devices."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 4, 1),)
    net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
    compile_net(net)


def test_stridedslice_no_strategy():
    """StridedSlice without an explicit shard strategy (strategy2=None)."""
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = None
    net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2)
    compile_net(net)


def test_stridedslice_auto_parallel():
    """auto_parallel mode with no strategies supplied for either operator."""
    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1))
    compile_net(net)