• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import numpy as np
16import pytest
17
18import mindspore as ms
19from mindspore import context, Tensor, Parameter
20from mindspore.nn import Cell, Momentum
21from mindspore.ops import operations as P
22from mindspore.train import Model
23from tests.dataset_mock import MindData
24
25
26class Dataset(MindData):
27    def __init__(self, predict, label, length=3):
28        super(Dataset, self).__init__(size=length)
29        self.predict = predict
30        self.label = label
31        self.index = 0
32        self.length = length
33
34    def __iter__(self):
35        return self
36
37    def __next__(self):
38        if self.index >= self.length:
39            raise StopIteration
40        self.index += 1
41        return self.predict, self.label
42
43    def reset(self):
44        self.index = 0
45
46
47class Net(Cell):
48    def __init__(self, weight, w2, begin, end, strides, strategy1=None, strategy2=None, is_parameter=True, mask=0):
49        super().__init__()
50        self.mul = P.Mul().shard(strategy1)
51        self.strided_slice = P.StridedSlice(begin_mask=mask).shard(strategy2)
52        if is_parameter:
53            self.weight = Parameter(weight, "w1")
54        else:
55            self.weight = weight
56        self.mul2 = P.Mul()
57        self.weight2 = Parameter(w2, "w2")
58        self.begin = begin
59        self.end = end
60        self.strides = strides
61
62    def construct(self, x, b):
63        out = self.strided_slice(
64            self.weight, self.begin, self.end, self.strides)
65        out = self.mul(x, out)
66        out = self.mul2(out, self.weight2)
67        return out
68
69
70class Net2(Cell):
71    def __init__(self, weight2, begin, end, strides, strategy1=None, strategy2=None):
72        super().__init__()
73        self.mul = P.Mul().shard(strategy1)
74        self.strided_slice = P.StridedSlice().shard(strategy2)
75        self.weight2 = Parameter(weight2, "w2")
76        self.begin = begin
77        self.end = end
78        self.strides = strides
79
80    def construct(self, x, b):
81        out = self.mul(x, self.weight2)
82        out = self.strided_slice(out, self.begin, self.end, self.strides)
83        return out
84
85
86_x = Tensor(np.ones([16, 64, 1]), dtype=ms.float32)
87_b = Tensor(np.ones([16, 64, 32]), dtype=ms.float32)
88_w1 = Tensor(np.ones([256, 64, 32]), dtype=ms.float32)
89_w2 = Tensor(np.ones([128, 64, 1]), dtype=ms.float32)
90
91
92def compile_net(net):
93    learning_rate = 0.1
94    momentum = 0.9
95    epoch_size = 2
96    dataset = Dataset(_x, _b)
97    opt = Momentum(net.trainable_params(), learning_rate, momentum)
98    model = Model(net, optimizer=opt, amp_level="O2")
99    model.train(epoch_size, dataset, dataset_sink_mode=False)
100    context.reset_auto_parallel_context()
101
102
103def test_stridedslice_no_fully_fetch_split_error():
104    context.set_auto_parallel_context(
105        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
106    strategy1 = ((2, 2, 2), (2, 2, 2))
107    strategy2 = ((2, 2, 2),)
108    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
109              strategy1, strategy2, is_parameter=True)
110    with pytest.raises(RuntimeError):
111        compile_net(net)
112
113
114def test_stridedslice_strides_no_1_split_error():
115    context.set_auto_parallel_context(
116        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
117    strategy1 = ((2, 2, 2), (2, 2, 2))
118    strategy2 = ((1, 2, 2),)
119    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 2),
120              strategy1, strategy2, is_parameter=True)
121    with pytest.raises(RuntimeError):
122        compile_net(net)
123
124
125def test_stridedslice_mask_no_0_split_error():
126    context.set_auto_parallel_context(
127        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
128    strategy1 = ((2, 2, 2), (2, 2, 2))
129    strategy2 = ((1, 2, 2),)
130    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
131              strategy1, strategy2, is_parameter=True, mask=1)
132    with pytest.raises(RuntimeError):
133        compile_net(net)
134
135
136def test_stridedslice_begin_size_smaller():
137    context.set_auto_parallel_context(
138        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
139    strategy1 = ((1, 4, 1), (1, 4, 2))
140    strategy2 = ((1, 4, 2),)
141    net = Net(_w1, _w2, (0, 0), (128, 64), (1, 1),
142              strategy1, strategy2, is_parameter=True)
143    compile_net(net)
144
145
146def test_stridedslice_parameter():
147    context.set_auto_parallel_context(
148        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
149    strategy1 = ((1, 4, 1), (1, 4, 2))
150    strategy2 = ((1, 4, 2),)
151    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
152              strategy1, strategy2, is_parameter=True)
153    compile_net(net)
154
155
156def test_stridedslice_tensor():
157    context.set_auto_parallel_context(
158        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
159    strategy1 = ((1, 4, 1), (1, 4, 2))
160    strategy2 = ((1, 4, 2),)
161    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
162              strategy1, strategy2, is_parameter=False)
163    compile_net(net)
164
165
166def test_stridedslice_parameter_no_full_split():
167    context.set_auto_parallel_context(
168        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
169    strategy1 = ((1, 4, 1), (1, 4, 2))
170    strategy2 = ((1, 2, 2),)
171    net = Net(_w1, _w2, (0, 0, 0), (128, 64, 32), (1, 1, 1),
172              strategy1, strategy2, is_parameter=True)
173    compile_net(net)
174
175
176def test_stridedslice_output():
177    context.set_auto_parallel_context(
178        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
179    strategy1 = ((1, 8, 1), (1, 8, 1))
180    strategy2 = ((1, 8, 1),)
181    net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
182    compile_net(net)
183
184
185def test_stridedslice_output_no_full_split():
186    context.set_auto_parallel_context(
187        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
188    strategy1 = ((1, 8, 1), (1, 8, 1))
189    strategy2 = ((1, 4, 1),)
190    net = Net2(_w2, (0, 0, 0), (64, 64, 1), (1, 1, 1), strategy1, strategy2)
191    compile_net(net)
192
193
194def test_stridedslice_no_strategy():
195    context.set_auto_parallel_context(
196        parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
197    strategy1 = ((1, 8, 1), (1, 8, 1))
198    strategy2 = None
199    net = Net2(_w2, (0, 0, 0), (128, 64, 1), (1, 1, 1), strategy1, strategy2)
200    compile_net(net)
201
202
203def test_stridedslice_auto_parallel():
204    context.set_auto_parallel_context(
205        parallel_mode="auto_parallel", device_num=8, global_rank=0)
206    net = Net2(_w2, (0, 0, 0), (32, 64, 1), (1, 1, 1))
207    compile_net(net)
208