1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <string>
18 #include <list>
19 #include <vector>
20 #include "common/common_test.h"
21 #include "frontend/parallel/strategy.h"
22 #include "frontend/parallel/ops_info/reshape_info.h"
23 #include "frontend/parallel/device_manager.h"
24 #include "frontend/parallel/step_parallel.h"
25
26 namespace mindspore {
27 namespace parallel {
28
class ReshapeInfo;
using ReshapeInfoPtr = std::shared_ptr<ReshapeInfo>;

// Operator under test, shared by all test cases of the TestReshapeInfo fixture;
// the fixture's SetUp() assigns a freshly constructed instance before each test.
// Kept in an unnamed namespace so this generically named mutable global has
// internal linkage and cannot collide at link time with a same-named global
// defined in another translation unit of the unit-test binary.
namespace {
ReshapeInfoPtr reshape;
}  // namespace
32
33 class TestReshapeInfo : public UT::Common {
34 public:
TestReshapeInfo()35 TestReshapeInfo() {}
36 void SetUp();
TearDown()37 void TearDown() {}
38 };
39
SetUp()40 void TestReshapeInfo::SetUp() {
41 RankList dev_list;
42
43 for (int32_t i = 0; i < 34; i++) {
44 dev_list.push_back(i);
45 }
46
47 RankList stage_map;
48 stage_map.push_back(32);
49 stage_map.push_back(2);
50
51 int32_t local_dev = 0;
52
53 // create a new g_device_manager
54 g_device_manager = std::make_shared<DeviceManager>();
55 g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
56
57 std::unordered_map<std::string, ValuePtr> attr;
58
59 Shapes inputs_shape = {{32, 512, 7, 7}};
60 Shapes outputs_shape = {{32, 25088}};
61 std::vector<int64_t> axis = {32, 25088};
62 ValuePtr val0;
63 ValuePtr val1 = MakeValue(axis);
64 std::vector<ValuePtr> val = {val0, val1};
65
66 reshape = std::make_shared<ReshapeInfo>("reshape_info", inputs_shape, outputs_shape, attr);
67 reshape->set_input_value(val);
68 }
69
// Strategy {4, 1, 1, 1} on stage 0 (32 devices): the expected device matrix is the
// strategy plus a trailing 8, presumably the repeat factor 32 / 4.
TEST_F(TestReshapeInfo, InferDevMatrixShape1) {
  Strategys inputs = {{4, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Fail fast with a clear message if the strategy is rejected, instead of
  // reporting a confusing mismatch against an uninitialized device matrix.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  Shape dev_matrix_shape = reshape->dev_matrix_shape();

  Shape expect = {4, 1, 1, 1, 8};
  ASSERT_EQ(dev_matrix_shape, expect);
}
80
// Strategy {32, 1, 1, 1} uses all 32 devices of stage 0 on dimension 0, so the
// expected device matrix equals the strategy with no trailing repeat dimension.
TEST_F(TestReshapeInfo, InferDevMatrixShape2) {
  Strategys inputs = {{32, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  Shape dev_matrix_shape = reshape->dev_matrix_shape();

  Shape expect = {32, 1, 1, 1};
  ASSERT_EQ(dev_matrix_shape, expect);
}
91
// Strategy {4, 1, 1, 1}: input dimension 0 (32) is split 4 ways, giving an input
// slice of {8, 512, 7, 7}; the output slice is expected to stay {32, 25088}.
TEST_F(TestReshapeInfo, InferSliceShape1) {
  Strategys str = {{4, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, str);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  // Guard the .at(0) accesses so an empty result is reported as an assertion
  // failure rather than as a thrown std::out_of_range.
  ASSERT_FALSE(inputs.empty());
  ASSERT_FALSE(outputs.empty());

  Shape input_slice_shape_expect = {8, 512, 7, 7};
  Shape output_slice_shape_expect = {32, 25088};

  TensorInfo input_tensor_info = inputs.at(0);
  TensorInfo output_tensor_info = outputs.at(0);

  Shape input_slice_shape = input_tensor_info.slice_shape();
  Shape output_slice_shape = output_tensor_info.slice_shape();

  ASSERT_EQ(input_slice_shape, input_slice_shape_expect);
  ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
}
112
// Strategy {32, 1, 1, 1}: input dimension 0 (32) is split 32 ways, giving an input
// slice of {1, 512, 7, 7}; the output slice is expected to stay {32, 25088}.
TEST_F(TestReshapeInfo, InferSliceShape2) {
  Strategys str = {{32, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, str);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  // Guard the .at(0) accesses so an empty result is reported as an assertion
  // failure rather than as a thrown std::out_of_range.
  ASSERT_FALSE(inputs.empty());
  ASSERT_FALSE(outputs.empty());

  Shape input_slice_shape_expect = {1, 512, 7, 7};
  Shape output_slice_shape_expect = {32, 25088};

  TensorInfo input_tensor_info = inputs.at(0);
  TensorInfo output_tensor_info = outputs.at(0);

  Shape input_slice_shape = input_tensor_info.slice_shape();
  Shape output_slice_shape = output_tensor_info.slice_shape();

  ASSERT_EQ(input_slice_shape, input_slice_shape_expect);
  ASSERT_EQ(output_slice_shape, output_slice_shape_expect);
}
133
// Checks the origin tensor maps produced for strategy {4, 1, 1, 1}.
TEST_F(TestReshapeInfo, GetTensorLayout1) {
  Strategys str = {{4, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, str);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  // Guard the .at(0) accesses so an empty result is reported as an assertion
  // failure rather than as a thrown std::out_of_range.
  ASSERT_FALSE(inputs.empty());
  ASSERT_FALSE(outputs.empty());

  TensorMap input_expect = {4, 3, 2, 1};
  TensorMap output_expect = {-1, -1};

  TensorInfo input_tensor_info = inputs.at(0);
  TensorInfo output_tensor_info = outputs.at(0);

  Map input_tensor_map = input_tensor_info.tensor_layout().origin_tensor_map();
  Map output_tensor_map = output_tensor_info.tensor_layout().origin_tensor_map();

  ASSERT_EQ(input_tensor_map.array(), input_expect);
  ASSERT_EQ(output_tensor_map.array(), output_expect);
}
154
// Checks the origin tensor maps produced for strategy {32, 1, 1, 1}.
TEST_F(TestReshapeInfo, GetTensorLayout2) {
  Strategys str = {{32, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, str);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  std::vector<TensorInfo> inputs = reshape->inputs_tensor_info();
  std::vector<TensorInfo> outputs = reshape->outputs_tensor_info();
  // Guard the .at(0) accesses so an empty result is reported as an assertion
  // failure rather than as a thrown std::out_of_range.
  ASSERT_FALSE(inputs.empty());
  ASSERT_FALSE(outputs.empty());

  TensorMap input_expect = {3, 2, 1, 0};
  TensorMap output_expect = {-1, -1};

  TensorInfo input_tensor_info = inputs.at(0);
  TensorInfo output_tensor_info = outputs.at(0);

  Map input_tensor_map = input_tensor_info.tensor_layout().origin_tensor_map();
  Map output_tensor_map = output_tensor_info.tensor_layout().origin_tensor_map();

  ASSERT_EQ(input_tensor_map.array(), input_expect);
  ASSERT_EQ(output_tensor_map.array(), output_expect);
}
175
// With strategy {4, 1, 1, 1}, ReshapeInfo is expected to produce no forward operators.
TEST_F(TestReshapeInfo, GetForwardOp1) {
  Strategys inputs = {{4, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  OperatorVector forward_op = reshape->forward_op();

  // Compare against an unsigned literal: ASSERT_EQ(size_t, int) performs a
  // signed/unsigned comparison inside GoogleTest's helper (-Wsign-compare).
  ASSERT_EQ(forward_op.size(), 0U);
}
186
// With strategy {4, 1, 1, 1}, ReshapeInfo is expected to produce two mirror-op
// entries (presumably one per operator input: the tensor and the shape value -- confirm).
TEST_F(TestReshapeInfo, GetMirrorOPs1) {
  Strategys inputs = {{4, 1, 1, 1}};
  StrategyPtr strategy = NewStrategy(0, inputs);

  // Fail fast with a clear message if the strategy is rejected.
  ASSERT_EQ(reshape->Init(strategy), SUCCESS);
  MirrorOps mirror_ops = reshape->mirror_ops();

  // Compare against an unsigned literal: ASSERT_EQ(size_t, int) performs a
  // signed/unsigned comparison inside GoogleTest's helper (-Wsign-compare).
  ASSERT_EQ(mirror_ops.size(), 2U);
}
198
// A 3-entry strategy for the 4-D input {32, 512, 7, 7} must be rejected by Init
// (presumably due to the rank mismatch).
TEST_F(TestReshapeInfo, CheckStrategy1) {
  Strategys bad_strategy = {{1, 4, 8}};
  ASSERT_EQ(reshape->Init(NewStrategy(0, bad_strategy)), FAILED);
}
206
// Two 3-entry strategies for an operator built with a single 4-D input must be
// rejected by Init.
TEST_F(TestReshapeInfo, CheckStrategy2) {
  Strategys bad_strategy = {{2, 4, 8}, {2, 4, 8}};
  ASSERT_EQ(reshape->Init(NewStrategy(0, bad_strategy)), FAILED);
}
214
// A 4-entry strategy splitting only dimension 0 by 4 is accepted by Init.
TEST_F(TestReshapeInfo, CheckStrategy3) {
  Strategys good_strategy = {{4, 1, 1, 1}};
  ASSERT_EQ(reshape->Init(NewStrategy(0, good_strategy)), SUCCESS);
}
222 } // namespace parallel
223 } // namespace mindspore
224