1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include <cmath>
#include <cstdint>
#include <string>
#include <vector>
#include "common/common_test.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"
22
23 using namespace mindspore;
24
namespace {
// MindIR graphs containing control flow, read from the shared test dataset directory.
constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir";
constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir";
constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir";
constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir";
constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir";
constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir";
constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir";

// Value every element of `input_data` is filled with.
constexpr float kConstValue = 0.1234;

// Float tensor payload of 2 * 3 * 4 * 5 elements, all equal to kConstValue.
const std::vector<float> input_data(2 * 3 * 4 * 5, kConstValue);
}  // namespace
34
35 class TestControl : public ST::Common {
36 public:
TestControl()37 TestControl() {}
38 };
39
// Loads ifbyif.mindir, checks the model's declared input/output signature,
// runs one inference and checks every output element against 24 * kConstValue.
TEST_F(TestControl, InferIfbyIf) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};
  const std::vector<int64_t> tensor_shape = {2, 3, 4, 5};
  const size_t tensor_bytes = sizeof(float) * input_data.size();

  // assert inputs: float, float, bool, bool scalars followed by one float tensor.
  // ASSERT_EQ reports both operands on failure; unsigned literals avoid signed/unsigned comparisons.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 5u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
  EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool);
  EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool);
  EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float));
  ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float));
  ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool));
  ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool));
  ASSERT_EQ(inputs_before[4].DataSize(), tensor_bytes);
  ASSERT_EQ(inputs_before[0].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[1].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[2].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[3].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[4].Shape(), tensor_shape);

  // assert outputs: one float tensor shaped like the tensor input.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_before[0].DataSize(), tensor_bytes);
  ASSERT_EQ(outputs_before[0].Shape(), tensor_shape);

  // prepare input; the source buffers live at function scope, so they stay valid through Predict.
  float x = 2.345678;
  float y = 1.234567;
  bool cond1 = true;
  bool cond2 = false;
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
                      sizeof(float));
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
                      sizeof(float));
  inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1,
                      sizeof(bool));
  inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2,
                      sizeof(bool));
  inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(),
                      input_data.data(), tensor_bytes);

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), tensor_bytes);
  auto out_data = out.Data();
  // Data() yields an untyped pointer; static_cast is the idiomatic void* -> T* conversion.
  const auto *p = static_cast<const float *>(out_data.get());
  for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
    ASSERT_NEAR(p[i], kConstValue * 24, 1e-3) << "at element " << i;
  }
}
116
// Loads simple_while.mindir, checks the declared input/output signature,
// runs one inference and checks every output element against 3 * kConstValue.
TEST_F(TestControl, InferSimpleWhile) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};
  const std::vector<int64_t> tensor_shape = {2, 3, 4, 5};
  const size_t tensor_bytes = sizeof(float) * input_data.size();

  // assert inputs: two bool scalars followed by one float tensor.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 3u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool);
  EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool);
  EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool));
  ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool));
  ASSERT_EQ(inputs_before[2].DataSize(), tensor_bytes);
  ASSERT_EQ(inputs_before[0].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[1].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[2].Shape(), tensor_shape);

  // assert outputs: one float tensor shaped like the tensor input.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_before[0].DataSize(), tensor_bytes);
  ASSERT_EQ(outputs_before[0].Shape(), tensor_shape);

  // prepare input. The scalars are declared at function scope: the original kept them in
  // a block that ended before Predict, which is only safe if MSTensor copies the data.
  bool x = true;
  bool y = false;
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
                      sizeof(bool));
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
                      sizeof(bool));
  inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(),
                      input_data.data(), tensor_bytes);

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), tensor_bytes);
  auto out_data = out.Data();
  const auto *p = static_cast<const float *>(out_data.get());
  for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
    ASSERT_NEAR(p[i], kConstValue * 3, 1e-3) << "at element " << i;
  }
}
181
// Loads fibonacci.mindir, checks the declared int32 scalar signature, feeds 7
// and expects the single int32 output to be 21.
TEST_F(TestControl, InferRecursive) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};

  // assert inputs: one int32 scalar.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 1u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
  ASSERT_EQ(inputs_before[0].Shape(), scalar_shape);

  // assert outputs: one int32 scalar.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(int32_t));
  ASSERT_EQ(outputs_before[0].Shape(), scalar_shape);

  // prepare input; declared at function scope so the buffer stays valid through Predict.
  int32_t x = 7;
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
                      sizeof(int32_t));

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(int32_t));
  auto out_data = out.Data();
  const auto *p = static_cast<const int32_t *>(out_data.get());
  ASSERT_EQ(*p, 21);
}
227
// Loads mix_while_if.mindir, checks the declared signature (five int32 scalar
// inputs, one int32 scalar output), runs one inference and expects 350.
TEST_F(TestControl, InferMixedWhileIf) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};

  // assert inputs: five int32 scalars.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 5u);
  for (size_t i = 0; i < inputs_before.size(); ++i) {
    EXPECT_EQ(inputs_before[i].DataType(), DataType::kNumberTypeInt32) << "input " << i;
    ASSERT_EQ(inputs_before[i].DataSize(), sizeof(int32_t)) << "input " << i;
    ASSERT_EQ(inputs_before[i].Shape(), scalar_shape) << "input " << i;
  }

  // assert outputs: one int32 scalar.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(int32_t));
  ASSERT_EQ(outputs_before[0].Shape(), scalar_shape);

  // prepare input; declared at function scope so the buffers stay valid through Predict.
  int32_t x = 2;
  int32_t y = 14;
  int32_t z = 1;
  int32_t c2 = 14;
  int32_t c4 = 0;
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4,
                      sizeof(int32_t));

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(int32_t));
  auto out_data = out.Data();
  const auto *p = static_cast<const int32_t *>(out_data.get());
  ASSERT_EQ(*p, 350);
}
296
// Loads single_for.mindir, checks the declared signature (three int32 scalar
// inputs, one int32 scalar output), runs one inference and expects 125.
TEST_F(TestControl, InferSingleFor) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};

  // assert inputs: three int32 scalars.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 3u);
  for (size_t i = 0; i < inputs_before.size(); ++i) {
    EXPECT_EQ(inputs_before[i].DataType(), DataType::kNumberTypeInt32) << "input " << i;
    ASSERT_EQ(inputs_before[i].DataSize(), sizeof(int32_t)) << "input " << i;
    ASSERT_EQ(inputs_before[i].Shape(), scalar_shape) << "input " << i;
  }

  // assert outputs: one int32 scalar.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(int32_t));
  ASSERT_EQ(outputs_before[0].Shape(), scalar_shape);

  // prepare input; declared at function scope so the buffers stay valid through Predict.
  int32_t x = 2;
  int32_t y = 5;
  int32_t z = 4;
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
                      sizeof(int32_t));

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(int32_t));
  auto out_data = out.Data();
  const auto *p = static_cast<const int32_t *>(out_data.get());
  ASSERT_EQ(*p, 125);
}
353
// Loads single_or.mindir, checks the declared signature (two float inputs of
// shape {2}, one float output), runs one inference and expects the output 1.
TEST_F(TestControl, InferSingleOr) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> pair_shape = {2};

  // assert inputs: two float tensors of two elements each.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 2u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2);
  ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2);
  ASSERT_EQ(inputs_before[0].Shape(), pair_shape);
  ASSERT_EQ(inputs_before[1].Shape(), pair_shape);

  // assert outputs: a single float value.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(float));

  // prepare input; declared at function scope so the buffers stay valid through Predict.
  const std::vector<float> input_data1 = {0, 1};
  const std::vector<float> input_data2 = {0, 0};
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
                      input_data1.data(), sizeof(float) * input_data1.size());
  // Both inputs hold floats, so the byte length must be computed from sizeof(float)
  // (the original used sizeof(int32_t) here, which only matched by coincidence).
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(),
                      input_data2.data(), sizeof(float) * input_data2.size());

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert the model still reports the same output signature after inference.
  std::vector<MSTensor> outputs_after = control_model.GetOutputs();
  ASSERT_EQ(outputs_after.size(), 1u);
  EXPECT_EQ(outputs_after[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_after[0].DataSize(), sizeof(float));
  EXPECT_EQ(outputs_after[0].Shape().size(), outputs_before[0].Shape().size());

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(float));
  auto out_data = out.Data();
  const auto *p = static_cast<const float *>(out_data.get());
  ASSERT_EQ(*p, 1.0f);
}
410
// Loads switch_layer_net.mindir, checks the declared signature (a float image
// tensor of shape {1, 1, 224, 224} plus two int32 index scalars), runs one
// inference on an all-ones image and expects every output element to be 1.
TEST_F(TestControl, InferSingleSwitch) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  constexpr size_t kImageElements = 224 * 224;
  const std::vector<int64_t> image_shape = {1, 1, 224, 224};
  const std::vector<int64_t> scalar_shape = {1};

  // assert inputs: one float image tensor followed by two int32 scalars.
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 3u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
  EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * kImageElements);
  ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
  ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
  ASSERT_EQ(inputs_before[0].Shape(), image_shape);
  ASSERT_EQ(inputs_before[1].Shape(), scalar_shape);
  ASSERT_EQ(inputs_before[2].Shape(), scalar_shape);

  // assert outputs: one float tensor shaped like the image input.
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(float) * kImageElements);
  ASSERT_EQ(outputs_before[0].Shape(), image_shape);

  // prepare input; declared at function scope so the buffers stay valid through Predict.
  const std::vector<float> input_data1(kImageElements, 1);
  int32_t index1 = 0;
  int32_t index2 = -1;
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
                      input_data1.data(), sizeof(float) * input_data1.size());
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1,
                      sizeof(int32_t));
  inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2,
                      sizeof(int32_t));

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  auto out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(float) * kImageElements);
  auto out_data = out.Data();
  const auto *p = static_cast<const float *>(out_data.get());
  for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
    ASSERT_EQ(p[i], 1.0f) << "at element " << i;
  }
}
477