• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <cstdint>
#include <string>
#include <vector>
#include "common/common_test.h"
#include "include/api/model.h"
#include "include/api/serialization.h"
#include "include/api/context.h"

using namespace mindspore;
24 
// File-local fixtures shared by the control-flow inference tests below.
namespace {
// Absolute locations of the exported MindIR control-flow models on the test host.
// NOTE(review): the tests assert that Serialization::Load succeeds, so they depend on
// this dataset directory being present on the machine that runs them.
constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir";
constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir";
constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir";
constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir";
constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir";
constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir";
constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir";

// Value written into every element of the shared float tensor input.
constexpr float kConstValue = static_cast<float>(0.1234);

// 2 * 3 * 4 * 5 = 120 floats, each equal to kConstValue; fed to the {2, 3, 4, 5} inputs.
const std::vector<float> input_data(2 * 3 * 4 * 5, kConstValue);
}  // namespace
34 
35 class TestControl : public ST::Common {
36  public:
TestControl()37   TestControl() {}
38 };
39 
TEST_F(TestControl,InferIfbyIf)40 TEST_F(TestControl, InferIfbyIf) {
41   auto context = ContextAutoSet();
42 
43   Graph graph;
44   ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph));
45   Model control_model;
46   ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
47 
48   // assert inputs
49   std::vector<MSTensor> inputs_before = control_model.GetInputs();
50   ASSERT_EQ(5, inputs_before.size());
51   EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
52   EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
53   EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool);
54   EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool);
55   EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32);
56   ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float));
57   ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float));
58   ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool));
59   ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool));
60   ASSERT_EQ(inputs_before[4].DataSize(), sizeof(float) * input_data.size());
61   ASSERT_EQ(inputs_before[0].Shape().size(), 1);
62   EXPECT_EQ(inputs_before[0].Shape()[0], 1);
63   ASSERT_EQ(inputs_before[1].Shape().size(), 1);
64   EXPECT_EQ(inputs_before[1].Shape()[0], 1);
65   ASSERT_EQ(inputs_before[2].Shape().size(), 1);
66   EXPECT_EQ(inputs_before[2].Shape()[0], 1);
67   ASSERT_EQ(inputs_before[3].Shape().size(), 1);
68   EXPECT_EQ(inputs_before[3].Shape()[0], 1);
69   ASSERT_EQ(inputs_before[4].Shape().size(), 4);
70   EXPECT_EQ(inputs_before[4].Shape()[0], 2);
71   EXPECT_EQ(inputs_before[4].Shape()[1], 3);
72   EXPECT_EQ(inputs_before[4].Shape()[2], 4);
73   EXPECT_EQ(inputs_before[4].Shape()[3], 5);
74 
75   // assert outputs
76   std::vector<MSTensor> outputs_before = control_model.GetOutputs();
77   ASSERT_EQ(1, outputs_before.size());
78   EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
79   ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size());
80   ASSERT_EQ(outputs_before[0].Shape().size(), 4);
81   EXPECT_EQ(outputs_before[0].Shape()[0], 2);
82   EXPECT_EQ(outputs_before[0].Shape()[1], 3);
83   EXPECT_EQ(outputs_before[0].Shape()[2], 4);
84   EXPECT_EQ(outputs_before[0].Shape()[3], 5);
85 
86   // prepare input
87   std::vector<MSTensor> outputs;
88   std::vector<MSTensor> inputs;
89 
90   float x = 2.345678, y = 1.234567;
91   bool cond1 = true, cond2 = false;
92   inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
93                       sizeof(float));
94   inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
95                       sizeof(float));
96   inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1,
97                       sizeof(bool));
98   inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2,
99                       sizeof(bool));
100   inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), input_data.data(),
101                       sizeof(float) * input_data.size());
102 
103   // infer
104   ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
105 
106   // assert output
107   ASSERT_TRUE(outputs.size() == 1);
108   auto out = outputs[0];
109   ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
110   auto out_data = out.Data();
111   auto p = reinterpret_cast<const float *>(out_data.get());
112   for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
113     ASSERT_LE(std::abs(p[i] - kConstValue * 24), 1e-3);
114   }
115 }
116 
TEST_F(TestControl,InferSimpleWhile)117 TEST_F(TestControl, InferSimpleWhile) {
118   auto context = ContextAutoSet();
119 
120   Graph graph;
121   ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph));
122   Model control_model;
123   ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
124 
125   // assert inputs
126   std::vector<MSTensor> inputs_before = control_model.GetInputs();
127   ASSERT_EQ(3, inputs_before.size());
128   EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool);
129   EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool);
130   EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32);
131   ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool));
132   ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool));
133   ASSERT_EQ(inputs_before[2].DataSize(), sizeof(float) * input_data.size());
134   ASSERT_EQ(inputs_before[0].Shape().size(), 1);
135   EXPECT_EQ(inputs_before[0].Shape()[0], 1);
136   ASSERT_EQ(inputs_before[1].Shape().size(), 1);
137   EXPECT_EQ(inputs_before[1].Shape()[0], 1);
138   ASSERT_EQ(inputs_before[2].Shape().size(), 4);
139   EXPECT_EQ(inputs_before[2].Shape()[0], 2);
140   EXPECT_EQ(inputs_before[2].Shape()[1], 3);
141   EXPECT_EQ(inputs_before[2].Shape()[2], 4);
142   EXPECT_EQ(inputs_before[2].Shape()[3], 5);
143 
144   // assert outputs
145   std::vector<MSTensor> outputs_before = control_model.GetOutputs();
146   ASSERT_EQ(1, outputs_before.size());
147   EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
148   ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size());
149   ASSERT_EQ(outputs_before[0].Shape().size(), 4);
150   EXPECT_EQ(outputs_before[0].Shape()[0], 2);
151   EXPECT_EQ(outputs_before[0].Shape()[1], 3);
152   EXPECT_EQ(outputs_before[0].Shape()[2], 4);
153   EXPECT_EQ(outputs_before[0].Shape()[3], 5);
154 
155   // prepare input
156   std::vector<MSTensor> outputs;
157   std::vector<MSTensor> inputs;
158   {
159     bool x = true, y = false;
160     inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
161                         sizeof(bool));
162     inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
163                         sizeof(bool));
164     inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(),
165                         input_data.data(), sizeof(float) * input_data.size());
166   }
167 
168   // infer
169   ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
170 
171   // assert output
172   ASSERT_TRUE(outputs.size() == 1);
173   auto out = outputs[0];
174   ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size());
175   auto out_data = out.Data();
176   auto p = reinterpret_cast<const float *>(out_data.get());
177   for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
178     ASSERT_LE(std::abs(p[i] - kConstValue * 3), 1e-3);
179   }
180 }
181 
// Loads fibonacci.mindir (a recursive graph), checks its int32 {1} -> int32 {1}
// signature, and checks that input 7 yields 21. The indexing convention that makes
// 7 map to 21 is defined by the exported model, which is not visible here.
TEST_F(TestControl, InferRecursive) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};

  // assert inputs
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 1u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
  EXPECT_EQ(inputs_before[0].Shape(), scalar_shape);

  // assert outputs
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(int32_t));
  EXPECT_EQ(outputs_before[0].Shape(), scalar_shape);

  // prepare input
  std::vector<MSTensor> inputs;
  {
    int32_t x = 7;
    inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
                        sizeof(x));
  }

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  const auto &out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(int32_t));
  auto out_data = out.Data();
  const auto *p = static_cast<const int32_t *>(out_data.get());
  // Fail this test instead of crashing the whole binary if no data was produced.
  ASSERT_TRUE(p != nullptr);
  ASSERT_EQ(*p, 21);
}
227 
TEST_F(TestControl,InferMixedWhileIf)228 TEST_F(TestControl, InferMixedWhileIf) {
229   auto context = ContextAutoSet();
230 
231   Graph graph;
232   ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph));
233   Model control_model;
234   ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
235 
236   // assert inputs
237   std::vector<MSTensor> inputs_before = control_model.GetInputs();
238   ASSERT_EQ(inputs_before.size(), 5);
239   EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32);
240   EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
241   EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
242   EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeInt32);
243   EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeInt32);
244   ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t));
245   ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
246   ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
247   ASSERT_EQ(inputs_before[3].DataSize(), sizeof(int32_t));
248   ASSERT_EQ(inputs_before[4].DataSize(), sizeof(int32_t));
249   ASSERT_EQ(inputs_before[0].Shape().size(), 1);
250   EXPECT_EQ(inputs_before[0].Shape()[0], 1);
251   ASSERT_EQ(inputs_before[1].Shape().size(), 1);
252   EXPECT_EQ(inputs_before[1].Shape()[0], 1);
253   ASSERT_EQ(inputs_before[2].Shape().size(), 1);
254   EXPECT_EQ(inputs_before[2].Shape()[0], 1);
255   ASSERT_EQ(inputs_before[3].Shape().size(), 1);
256   EXPECT_EQ(inputs_before[3].Shape()[0], 1);
257   ASSERT_EQ(inputs_before[4].Shape().size(), 1);
258   EXPECT_EQ(inputs_before[4].Shape()[0], 1);
259 
260   // assert outputs
261   std::vector<MSTensor> outputs_before = control_model.GetOutputs();
262   ASSERT_EQ(1, outputs_before.size());
263   EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
264   ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
265   ASSERT_EQ(outputs_before[0].Shape().size(), 1);
266   EXPECT_EQ(outputs_before[0].Shape()[0], 1);
267 
268   // prepare input
269   std::vector<MSTensor> outputs;
270   std::vector<MSTensor> inputs;
271   {
272     int32_t x = 2, y = 14, z = 1, c2 = 14, c4 = 0;
273     inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x,
274                         sizeof(int32_t));
275     inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y,
276                         sizeof(int32_t));
277     inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z,
278                         sizeof(int32_t));
279     inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2,
280                         sizeof(int32_t));
281     inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4,
282                         sizeof(int32_t));
283   }
284 
285   // infer
286   ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
287 
288   // assert output
289   ASSERT_TRUE(outputs.size() == 1);
290   auto out = outputs[0];
291   ASSERT_TRUE(out.DataSize() == sizeof(int32_t));
292   auto out_data = out.Data();
293   auto p = reinterpret_cast<const int32_t *>(out_data.get());
294   ASSERT_EQ(*p, 350);
295 }
296 
// Loads single_for.mindir, checks that it takes three int32 {1} inputs and returns one
// int32 {1} output, and checks that inputs (2, 5, 4) yield 125. The expected value is
// what the exported model produces; its body is not visible here.
TEST_F(TestControl, InferSingleFor) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> scalar_shape = {1};
  // Values for inputs x, y, z, in model-input order. Kept alive for the whole test.
  const std::vector<int32_t> input_values = {2, 5, 4};

  // assert inputs: every input is an int32 scalar of shape {1}
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), input_values.size());
  for (const auto &input : inputs_before) {
    EXPECT_EQ(input.DataType(), DataType::kNumberTypeInt32);
    ASSERT_EQ(input.DataSize(), sizeof(int32_t));
    EXPECT_EQ(input.Shape(), scalar_shape);
  }

  // assert outputs
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(int32_t));
  EXPECT_EQ(outputs_before[0].Shape(), scalar_shape);

  // prepare input
  std::vector<MSTensor> inputs;
  for (size_t i = 0; i < input_values.size(); ++i) {
    inputs.emplace_back(inputs_before[i].Name(), inputs_before[i].DataType(), inputs_before[i].Shape(),
                        &input_values[i], sizeof(int32_t));
  }

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  const auto &out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(int32_t));
  auto out_data = out.Data();
  const auto *p = static_cast<const int32_t *>(out_data.get());
  // Fail this test instead of crashing the whole binary if no data was produced.
  ASSERT_TRUE(p != nullptr);
  ASSERT_EQ(*p, 125);
}
353 
// Loads single_or.mindir, checks that it takes two float32 {2} inputs and returns a
// single float32 value, runs it on {0, 1} and {0, 0}, and checks the result is 1.
// Also checks that the output signature reported after Predict matches the one
// reported before it (same dtype, size and rank).
TEST_F(TestControl, InferSingleOr) {
  auto context = ContextAutoSet();

  Graph graph;
  ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph));
  Model control_model;
  ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);

  const std::vector<int64_t> pair_shape = {2};

  // assert inputs
  std::vector<MSTensor> inputs_before = control_model.GetInputs();
  ASSERT_EQ(inputs_before.size(), 2u);
  EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2);
  ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2);
  EXPECT_EQ(inputs_before[0].Shape(), pair_shape);
  EXPECT_EQ(inputs_before[1].Shape(), pair_shape);

  // assert outputs
  std::vector<MSTensor> outputs_before = control_model.GetOutputs();
  ASSERT_EQ(outputs_before.size(), 1u);
  EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_before[0].DataSize(), sizeof(float));

  // prepare input; the vectors live for the whole test
  const std::vector<float> input_data1 = {0, 1};
  const std::vector<float> input_data2 = {0, 0};
  std::vector<MSTensor> inputs;
  inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
                      input_data1.data(), sizeof(float) * input_data1.size());
  // Byte length is computed from the element type actually stored (float); the original
  // used sizeof(int32_t) here, which only matched by coincidence of type sizes.
  inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(),
                      input_data2.data(), sizeof(float) * input_data2.size());

  // infer
  std::vector<MSTensor> outputs;
  ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);

  // assert outputs reported after inference
  std::vector<MSTensor> outputs_after = control_model.GetOutputs();
  ASSERT_EQ(outputs_after.size(), 1u);
  EXPECT_EQ(outputs_after[0].DataType(), DataType::kNumberTypeFloat32);
  ASSERT_EQ(outputs_after[0].DataSize(), sizeof(float));
  EXPECT_EQ(outputs_after[0].Shape().size(), outputs_before[0].Shape().size());

  // assert output
  ASSERT_EQ(outputs.size(), 1u);
  const auto &out = outputs[0];
  ASSERT_EQ(out.DataSize(), sizeof(float));
  auto out_data = out.Data();
  const auto *p = static_cast<const float *>(out_data.get());
  // Fail this test instead of crashing the whole binary if no data was produced.
  ASSERT_TRUE(p != nullptr);
  ASSERT_FLOAT_EQ(*p, 1.0f);
}
410 
TEST_F(TestControl,InferSingleSwitch)411 TEST_F(TestControl, InferSingleSwitch) {
412   auto context = ContextAutoSet();
413 
414   Graph graph;
415   ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph));
416   Model control_model;
417   ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess);
418 
419   // assert inputs
420   std::vector<MSTensor> inputs_before = control_model.GetInputs();
421   ASSERT_EQ(inputs_before.size(), 3);
422   EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32);
423   EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32);
424   EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32);
425   ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 224 * 224);
426   ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t));
427   ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t));
428   ASSERT_EQ(inputs_before[0].Shape().size(), 4);
429   EXPECT_EQ(inputs_before[0].Shape()[0], 1);
430   EXPECT_EQ(inputs_before[0].Shape()[1], 1);
431   EXPECT_EQ(inputs_before[0].Shape()[2], 224);
432   EXPECT_EQ(inputs_before[0].Shape()[3], 224);
433   ASSERT_EQ(inputs_before[1].Shape().size(), 1);
434   EXPECT_EQ(inputs_before[1].Shape()[0], 1);
435   ASSERT_EQ(inputs_before[2].Shape().size(), 1);
436   EXPECT_EQ(inputs_before[2].Shape()[0], 1);
437 
438   // assert outputs
439   std::vector<MSTensor> outputs_before = control_model.GetOutputs();
440   ASSERT_EQ(1, outputs_before.size());
441   EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
442   ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * 224 * 224);
443   ASSERT_EQ(outputs_before[0].Shape().size(), 4);
444   EXPECT_EQ(outputs_before[0].Shape()[0], 1);
445   EXPECT_EQ(outputs_before[0].Shape()[1], 1);
446   EXPECT_EQ(outputs_before[0].Shape()[2], 224);
447   EXPECT_EQ(outputs_before[0].Shape()[3], 224);
448 
449   // prepare input
450   std::vector<MSTensor> outputs;
451   std::vector<MSTensor> inputs;
452   {
453     static const std::vector<float> input_data1(1 * 1 * 224 * 224, 1);
454     int32_t index1 = 0;
455     int32_t index2 = -1;
456     inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(),
457                         input_data1.data(), sizeof(float) * input_data1.size());
458     inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1,
459                         sizeof(int32_t));
460     inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2,
461                         sizeof(int32_t));
462   }
463 
464   // infer
465   ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
466 
467   // assert output
468   ASSERT_TRUE(outputs.size() == 1);
469   auto out = outputs[0];
470   ASSERT_TRUE(out.DataSize() == sizeof(float) * 224 * 224);
471   auto out_data = out.Data();
472   auto p = reinterpret_cast<const float *>(out_data.get());
473   for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) {
474     ASSERT_EQ(p[i], 1);
475   }
476 }
477