# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Semi-auto / auto parallel compilation tests for ``P.StridedSlice``."""
import numpy as np
import pytest

import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, Momentum
from mindspore.ops import operations as P
from mindspore.train import Model
from tests.dataset_mock import MindData


class Dataset(MindData):
    """Mock dataset yielding the same ``(predict, label)`` pair ``length`` times.

    Args:
        predict: First column returned on every step.
        label: Second column returned on every step.
        length: Number of steps per pass before ``StopIteration``.
    """

    def __init__(self, predict, label, length=3):
        super().__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        return self.predict, self.label

    def reset(self):
        """Rewind the iterator so another pass can be made."""
        self.index = 0


class Net(Cell):
    """Slice a weight with StridedSlice, multiply it by the input, then by ``w2``.

    Args:
        weight: Tensor that is strided-sliced; wrapped as Parameter "w1" when
            ``is_parameter`` is True, otherwise used as a plain tensor.
        w2: Tensor wrapped as Parameter "w2" and multiplied into the result.
        begin, end, strides: StridedSlice arguments applied to ``weight``.
        strategy1: Shard strategy for the first Mul (input * sliced weight).
        strategy2: Shard strategy for StridedSlice.
        is_parameter: Whether ``weight`` becomes a trainable Parameter.
        mask: Value passed as StridedSlice's ``begin_mask``.
    """

    def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
        if is_parameter:
            self.weight = Parameter(weight, "w1")
        else:
            self.weight = weight
        self.mul2 = P.Mul()
        self.weight2 = Parameter(w2, "w2")
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, x, b):
        # ``b`` is the second dataset column; it is accepted but unused.
        out = self.strided_slice(self.weight, self.begin, self.end, self.strides)
        out = self.mul(x, out)
        out = self.mul2(out, self.weight2)
        return out


class Net2(Cell):
    """Multiply the input by ``weight2``, then strided-slice the product.

    Unlike ``Net``, StridedSlice here consumes an operator output rather than a
    weight, exercising slicing of an intermediate activation.
    """

    def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None):
        super().__init__()
        self.mul = P.Mul().shard(strategy1)
        self.strided_slice = P.StridedSlice().shard(strategy2)
        self.weight2 = Parameter(weight2, "w2")
        self.begin = begin
        self.end = end
        self.strides = strides

    def construct(self, x, b):
        # ``b`` is the second dataset column; it is accepted but unused.
        out = self.mul(x, self.weight2)
        out = self.strided_slice(out, self.begin, self.end, self.strides)
        return out


# NOTE(review): _x has leading dim 16 while _w2 (and the sliced _w1) have 128.
# This presumably relies on Model expanding the per-device batch by device_num
# (16 * 8 = 128) in (semi-)auto parallel mode -- confirm before changing shapes.
_x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32)
_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)


def compile_net(net):
    """Train ``net`` for two epochs on the mock dataset to force compilation.

    The auto-parallel context is reset in a ``finally`` clause so that tests
    which expect compilation to raise do not leak their parallel settings into
    later tests.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    try:
        dataset = Dataset(_x, _b)
        opt = Momentum(net.trainable_params(), learning_rate, momentum)
        model = Model(net, optimizer=opt, amp_level="O2")
        model.train(epoch_size, dataset, dataset_sink_mode=False)
    finally:
        context.reset_auto_parallel_context()


def test_stridedslice_no_fully_fetch_split_error():
    """Splitting dim 0 of _w1 (256) while the slice only fetches 0..128 must fail."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((2, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2, is_parameter=True)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_stridedslice_strides_no_1_split_error():
    """Splitting dim 2, which is sliced with stride 2, must fail."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2),
              strategy1, strategy2, is_parameter=True)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_stridedslice_mask_no_0_split_error():
    """A non-zero begin_mask combined with a sharding strategy must fail."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((2, 2, 2), (2, 2, 2))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2, is_parameter=True, mask=1)
    with pytest.raises(RuntimeError):
        compile_net(net)


def test_stridedslice_begin_size_smaller():
    """begin/end/strides shorter (rank 2) than the rank-3 input still compile."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1),
              strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_stridedslice_parameter():
    """Slice a Parameter weight with dims 1 and 2 split."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_stridedslice_tensor():
    """Same as test_stridedslice_parameter, but the sliced weight is a plain Tensor."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 4, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2, is_parameter=False)
    compile_net(net)


def test_stridedslice_parameter_no_full_split():
    """StridedSlice strategy uses 1*2*2 = 4 of the 8 devices."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 4, 1), (1, 4, 2))
    strategy2 = ((1, 2, 2),)
    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
              strategy1, strategy2, is_parameter=True)
    compile_net(net)


def test_stridedslice_output():
    """Slice an operator output (Net2) with dim 1 split across all 8 devices."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 8, 1),)
    net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
    compile_net(net)


def test_stridedslice_output_no_full_split():
    """Slice an operator output with a strategy using only 4 of the 8 devices."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = ((1, 4, 1),)
    net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
    compile_net(net)


def test_stridedslice_no_strategy():
    """StridedSlice with no explicit shard strategy in semi-auto parallel mode."""
    context.set_auto_parallel_context(
        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
    strategy1 = ((1, 8, 1), (1, 8, 1))
    strategy2 = None
    net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2)
    compile_net(net)


def test_stridedslice_auto_parallel():
    """Net2 compiled in auto_parallel mode with no user-provided strategies."""
    context.set_auto_parallel_context(
        parallel_mode="auto_parallel", device_num=8, global_rank=0)
    net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1))
    compile_net(net)