1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <string> 17 #include <vector> 18 #include "common/common_test.h" 19 #include "include/api/model.h" 20 #include "include/api/serialization.h" 21 #include "include/api/context.h" 22 23 using namespace mindspore; 24 25 static constexpr char kIfbyIfFile[] = "/home/workspace/mindspore_dataset/mindir/control/ifbyif.mindir"; 26 static constexpr char kSimpleWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/simple_while.mindir"; 27 static constexpr char kMixIfWhileFile[] = "/home/workspace/mindspore_dataset/mindir/control/mix_while_if.mindir"; 28 static constexpr char kRecursiveFile[] = "/home/workspace/mindspore_dataset/mindir/control/fibonacci.mindir"; 29 static constexpr char kSingleForFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_for.mindir"; 30 static constexpr char kSingleOrFile[] = "/home/workspace/mindspore_dataset/mindir/control/single_or.mindir"; 31 static constexpr char kSingleSwitchFile[] = "/home/workspace/mindspore_dataset/mindir/control/switch_layer_net.mindir"; 32 static constexpr float kConstValue = 0.1234; 33 static const std::vector<float> input_data(2 * 3 * 4 * 5, kConstValue); 34 35 class TestControl : public ST::Common { 36 public: 37 TestControl() {} 38 }; 39 40 TEST_F(TestControl, InferIfbyIf) { 41 auto context = ContextAutoSet(); 42 43 Graph graph; 44 ASSERT_TRUE(Serialization::Load(kIfbyIfFile, ModelType::kMindIR, &graph)); 45 Model control_model; 46 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 47 48 // assert inputs 49 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 50 ASSERT_EQ(5, inputs_before.size()); 51 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32); 52 EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32); 53 EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeBool); 54 EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeBool); 55 EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeFloat32); 56 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float)); 57 ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float)); 58 ASSERT_EQ(inputs_before[2].DataSize(), sizeof(bool)); 59 ASSERT_EQ(inputs_before[3].DataSize(), sizeof(bool)); 60 ASSERT_EQ(inputs_before[4].DataSize(), sizeof(float) * input_data.size()); 61 ASSERT_EQ(inputs_before[0].Shape().size(), 1); 62 EXPECT_EQ(inputs_before[0].Shape()[0], 1); 63 ASSERT_EQ(inputs_before[1].Shape().size(), 1); 64 EXPECT_EQ(inputs_before[1].Shape()[0], 1); 65 ASSERT_EQ(inputs_before[2].Shape().size(), 1); 66 EXPECT_EQ(inputs_before[2].Shape()[0], 1); 67 ASSERT_EQ(inputs_before[3].Shape().size(), 1); 68 EXPECT_EQ(inputs_before[3].Shape()[0], 1); 69 ASSERT_EQ(inputs_before[4].Shape().size(), 4); 70 EXPECT_EQ(inputs_before[4].Shape()[0], 2); 71 EXPECT_EQ(inputs_before[4].Shape()[1], 3); 72 EXPECT_EQ(inputs_before[4].Shape()[2], 4); 73 EXPECT_EQ(inputs_before[4].Shape()[3], 5); 74 75 // assert outputs 76 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 77 ASSERT_EQ(1, outputs_before.size()); 78 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); 79 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size()); 80 ASSERT_EQ(outputs_before[0].Shape().size(), 4); 81 EXPECT_EQ(outputs_before[0].Shape()[0], 2); 82 EXPECT_EQ(outputs_before[0].Shape()[1], 3); 83 EXPECT_EQ(outputs_before[0].Shape()[2], 4); 84 EXPECT_EQ(outputs_before[0].Shape()[3], 5); 85 86 // prepare input 87 std::vector<MSTensor> outputs; 88 std::vector<MSTensor> inputs; 89 90 float x = 2.345678, y = 1.234567; 91 bool cond1 = true, cond2 = false; 92 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, 93 sizeof(float)); 94 inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, 95 sizeof(float)); 96 inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &cond1, 97 sizeof(bool)); 98 inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &cond2, 99 sizeof(bool)); 100 inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), input_data.data(), 101 sizeof(float) * input_data.size()); 102 103 // infer 104 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 105 106 // assert output 107 ASSERT_TRUE(outputs.size() == 1); 108 auto out = outputs[0]; 109 ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size()); 110 auto out_data = out.Data(); 111 auto p = reinterpret_cast<const float *>(out_data.get()); 112 for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) { 113 ASSERT_LE(std::abs(p[i] - kConstValue * 24), 1e-3); 114 } 115 } 116 117 TEST_F(TestControl, InferSimpleWhile) { 118 auto context = ContextAutoSet(); 119 120 Graph graph; 121 ASSERT_TRUE(Serialization::Load(kSimpleWhileFile, ModelType::kMindIR, &graph)); 122 Model control_model; 123 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 124 125 // assert inputs 126 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 127 ASSERT_EQ(3, inputs_before.size()); 128 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeBool); 129 EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeBool); 130 EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeFloat32); 131 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(bool)); 132 ASSERT_EQ(inputs_before[1].DataSize(), sizeof(bool)); 133 ASSERT_EQ(inputs_before[2].DataSize(), sizeof(float) * input_data.size()); 134 ASSERT_EQ(inputs_before[0].Shape().size(), 1); 135 EXPECT_EQ(inputs_before[0].Shape()[0], 1); 136 ASSERT_EQ(inputs_before[1].Shape().size(), 1); 137 EXPECT_EQ(inputs_before[1].Shape()[0], 1); 138 ASSERT_EQ(inputs_before[2].Shape().size(), 4); 139 EXPECT_EQ(inputs_before[2].Shape()[0], 2); 140 EXPECT_EQ(inputs_before[2].Shape()[1], 3); 141 EXPECT_EQ(inputs_before[2].Shape()[2], 4); 142 EXPECT_EQ(inputs_before[2].Shape()[3], 5); 143 144 // assert outputs 145 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 146 ASSERT_EQ(1, outputs_before.size()); 147 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); 148 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size()); 149 ASSERT_EQ(outputs_before[0].Shape().size(), 4); 150 EXPECT_EQ(outputs_before[0].Shape()[0], 2); 151 EXPECT_EQ(outputs_before[0].Shape()[1], 3); 152 EXPECT_EQ(outputs_before[0].Shape()[2], 4); 153 EXPECT_EQ(outputs_before[0].Shape()[3], 5); 154 155 // prepare input 156 std::vector<MSTensor> outputs; 157 std::vector<MSTensor> inputs; 158 { 159 bool x = true, y = false; 160 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, 161 sizeof(bool)); 162 inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, 163 sizeof(bool)); 164 inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), 165 input_data.data(), sizeof(float) * input_data.size()); 166 } 167 168 // infer 169 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 170 171 // assert output 172 ASSERT_TRUE(outputs.size() == 1); 173 auto out = outputs[0]; 174 ASSERT_TRUE(out.DataSize() == sizeof(float) * input_data.size()); 175 auto out_data = out.Data(); 176 auto p = reinterpret_cast<const float *>(out_data.get()); 177 for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) { 178 ASSERT_LE(std::abs(p[i] - kConstValue * 3), 1e-3); 179 } 180 } 181 182 TEST_F(TestControl, InferRecursive) { 183 auto context = ContextAutoSet(); 184 185 Graph graph; 186 ASSERT_TRUE(Serialization::Load(kRecursiveFile, ModelType::kMindIR, &graph)); 187 Model control_model; 188 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 189 190 // assert inputs 191 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 192 ASSERT_EQ(1, inputs_before.size()); 193 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32); 194 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t)); 195 ASSERT_EQ(inputs_before[0].Shape().size(), 1); 196 EXPECT_EQ(inputs_before[0].Shape()[0], 1); 197 198 // assert outputs 199 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 200 ASSERT_EQ(1, outputs_before.size()); 201 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32); 202 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t)); 203 ASSERT_EQ(outputs_before[0].Shape().size(), 1); 204 EXPECT_EQ(outputs_before[0].Shape()[0], 1); 205 206 207 // prepare input 208 std::vector<MSTensor> outputs; 209 std::vector<MSTensor> inputs; 210 { 211 int32_t x = 7; 212 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, 213 sizeof(int32_t)); 214 } 215 216 // infer 217 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 218 219 // assert output 220 ASSERT_TRUE(outputs.size() == 1); 221 auto out = outputs[0]; 222 ASSERT_TRUE(out.DataSize() == sizeof(int32_t)); 223 auto out_data = out.Data(); 224 auto p = reinterpret_cast<const int32_t *>(out_data.get()); 225 ASSERT_EQ(*p, 21); 226 } 227 228 TEST_F(TestControl, InferMixedWhileIf) { 229 auto context = ContextAutoSet(); 230 231 Graph graph; 232 ASSERT_TRUE(Serialization::Load(kMixIfWhileFile, ModelType::kMindIR, &graph)); 233 Model control_model; 234 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 235 236 // assert inputs 237 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 238 ASSERT_EQ(inputs_before.size(), 5); 239 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32); 240 EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32); 241 EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32); 242 EXPECT_EQ(inputs_before[3].DataType(), DataType::kNumberTypeInt32); 243 EXPECT_EQ(inputs_before[4].DataType(), DataType::kNumberTypeInt32); 244 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t)); 245 ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t)); 246 ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t)); 247 ASSERT_EQ(inputs_before[3].DataSize(), sizeof(int32_t)); 248 ASSERT_EQ(inputs_before[4].DataSize(), sizeof(int32_t)); 249 ASSERT_EQ(inputs_before[0].Shape().size(), 1); 250 EXPECT_EQ(inputs_before[0].Shape()[0], 1); 251 ASSERT_EQ(inputs_before[1].Shape().size(), 1); 252 EXPECT_EQ(inputs_before[1].Shape()[0], 1); 253 ASSERT_EQ(inputs_before[2].Shape().size(), 1); 254 EXPECT_EQ(inputs_before[2].Shape()[0], 1); 255 ASSERT_EQ(inputs_before[3].Shape().size(), 1); 256 EXPECT_EQ(inputs_before[3].Shape()[0], 1); 257 ASSERT_EQ(inputs_before[4].Shape().size(), 1); 258 EXPECT_EQ(inputs_before[4].Shape()[0], 1); 259 260 // assert outputs 261 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 262 ASSERT_EQ(1, outputs_before.size()); 263 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32); 264 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t)); 265 ASSERT_EQ(outputs_before[0].Shape().size(), 1); 266 EXPECT_EQ(outputs_before[0].Shape()[0], 1); 267 268 // prepare input 269 std::vector<MSTensor> outputs; 270 std::vector<MSTensor> inputs; 271 { 272 int32_t x = 2, y = 14, z = 1, c2 = 14, c4 = 0; 273 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, 274 sizeof(int32_t)); 275 inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, 276 sizeof(int32_t)); 277 inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z, 278 sizeof(int32_t)); 279 inputs.emplace_back(inputs_before[3].Name(), inputs_before[3].DataType(), inputs_before[3].Shape(), &c2, 280 sizeof(int32_t)); 281 inputs.emplace_back(inputs_before[4].Name(), inputs_before[4].DataType(), inputs_before[4].Shape(), &c4, 282 sizeof(int32_t)); 283 } 284 285 // infer 286 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 287 288 // assert output 289 ASSERT_TRUE(outputs.size() == 1); 290 auto out = outputs[0]; 291 ASSERT_TRUE(out.DataSize() == sizeof(int32_t)); 292 auto out_data = out.Data(); 293 auto p = reinterpret_cast<const int32_t *>(out_data.get()); 294 ASSERT_EQ(*p, 350); 295 } 296 297 TEST_F(TestControl, InferSingleFor) { 298 auto context = ContextAutoSet(); 299 300 Graph graph; 301 ASSERT_TRUE(Serialization::Load(kSingleForFile, ModelType::kMindIR, &graph)); 302 Model control_model; 303 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 304 305 // assert inputs 306 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 307 ASSERT_EQ(inputs_before.size(), 3); 308 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeInt32); 309 EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32); 310 EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32); 311 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(int32_t)); 312 ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t)); 313 ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t)); 314 ASSERT_EQ(inputs_before[0].Shape().size(), 1); 315 EXPECT_EQ(inputs_before[0].Shape()[0], 1); 316 ASSERT_EQ(inputs_before[1].Shape().size(), 1); 317 EXPECT_EQ(inputs_before[1].Shape()[0], 1); 318 ASSERT_EQ(inputs_before[2].Shape().size(), 1); 319 EXPECT_EQ(inputs_before[2].Shape()[0], 1); 320 321 // assert outputs 322 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 323 ASSERT_EQ(1, outputs_before.size()); 324 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32); 325 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t)); 326 ASSERT_EQ(outputs_before[0].Shape().size(), 1); 327 EXPECT_EQ(outputs_before[0].Shape()[0], 1); 328 329 // prepare input 330 std::vector<MSTensor> outputs; 331 std::vector<MSTensor> inputs; 332 { 333 int32_t x = 2, y = 5, z = 4; 334 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), &x, 335 sizeof(int32_t)); 336 inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &y, 337 sizeof(int32_t)); 338 inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &z, 339 sizeof(int32_t)); 340 } 341 342 // infer 343 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 344 345 // assert output 346 ASSERT_TRUE(outputs.size() == 1); 347 auto out = outputs[0]; 348 ASSERT_TRUE(out.DataSize() == sizeof(int32_t)); 349 auto out_data = out.Data(); 350 auto p = reinterpret_cast<const int32_t *>(out_data.get()); 351 ASSERT_EQ(*p, 125); 352 } 353 354 TEST_F(TestControl, InferSingleOr) { 355 auto context = ContextAutoSet(); 356 357 Graph graph; 358 ASSERT_TRUE(Serialization::Load(kSingleOrFile, ModelType::kMindIR, &graph)); 359 Model control_model; 360 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 361 362 // assert inputs 363 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 364 ASSERT_EQ(inputs_before.size(), 2); 365 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32); 366 EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeFloat32); 367 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 2); 368 ASSERT_EQ(inputs_before[1].DataSize(), sizeof(float) * 2); 369 ASSERT_EQ(inputs_before[0].Shape().size(), 1); 370 EXPECT_EQ(inputs_before[0].Shape()[0], 2); 371 ASSERT_EQ(inputs_before[1].Shape().size(), 1); 372 EXPECT_EQ(inputs_before[1].Shape()[0], 2); 373 374 // assert outputs 375 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 376 ASSERT_EQ(1, outputs_before.size()); 377 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); 378 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float)); 379 380 // prepare input 381 std::vector<MSTensor> outputs; 382 std::vector<MSTensor> inputs; 383 { 384 static const std::vector<float> input_data1 = {0, 1}; 385 static const std::vector<float> input_data2 = {0, 0}; 386 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), 387 input_data1.data(), sizeof(float) * input_data1.size()); 388 inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), 389 input_data2.data(), sizeof(int32_t) * input_data2.size()); 390 } 391 392 // infer 393 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 394 395 // assert outputs 396 std::vector<MSTensor> outputs_after = control_model.GetOutputs(); 397 ASSERT_EQ(1, outputs_after.size()); 398 EXPECT_EQ(outputs_after[0].DataType(), DataType::kNumberTypeFloat32); 399 ASSERT_TRUE(outputs_after[0].DataSize() == sizeof(float)); 400 EXPECT_EQ(outputs_after[0].Shape().size(), outputs_before[0].Shape().size()); 401 402 // assert output 403 ASSERT_TRUE(outputs.size() == 1); 404 auto out = outputs[0]; 405 ASSERT_TRUE(out.DataSize() == sizeof(float)); 406 auto out_data = out.Data(); 407 auto p = reinterpret_cast<const float *>(out_data.get()); 408 ASSERT_EQ(*p, 1); 409 } 410 411 TEST_F(TestControl, InferSingleSwitch) { 412 auto context = ContextAutoSet(); 413 414 Graph graph; 415 ASSERT_TRUE(Serialization::Load(kSingleSwitchFile, ModelType::kMindIR, &graph)); 416 Model control_model; 417 ASSERT_TRUE(control_model.Build(GraphCell(graph), context) == kSuccess); 418 419 // assert inputs 420 std::vector<MSTensor> inputs_before = control_model.GetInputs(); 421 ASSERT_EQ(inputs_before.size(), 3); 422 EXPECT_EQ(inputs_before[0].DataType(), DataType::kNumberTypeFloat32); 423 EXPECT_EQ(inputs_before[1].DataType(), DataType::kNumberTypeInt32); 424 EXPECT_EQ(inputs_before[2].DataType(), DataType::kNumberTypeInt32); 425 ASSERT_EQ(inputs_before[0].DataSize(), sizeof(float) * 224 * 224); 426 ASSERT_EQ(inputs_before[1].DataSize(), sizeof(int32_t)); 427 ASSERT_EQ(inputs_before[2].DataSize(), sizeof(int32_t)); 428 ASSERT_EQ(inputs_before[0].Shape().size(), 4); 429 EXPECT_EQ(inputs_before[0].Shape()[0], 1); 430 EXPECT_EQ(inputs_before[0].Shape()[1], 1); 431 EXPECT_EQ(inputs_before[0].Shape()[2], 224); 432 EXPECT_EQ(inputs_before[0].Shape()[3], 224); 433 ASSERT_EQ(inputs_before[1].Shape().size(), 1); 434 EXPECT_EQ(inputs_before[1].Shape()[0], 1); 435 ASSERT_EQ(inputs_before[2].Shape().size(), 1); 436 EXPECT_EQ(inputs_before[2].Shape()[0], 1); 437 438 // assert outputs 439 std::vector<MSTensor> outputs_before = control_model.GetOutputs(); 440 ASSERT_EQ(1, outputs_before.size()); 441 EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); 442 ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * 224 * 224); 443 ASSERT_EQ(outputs_before[0].Shape().size(), 4); 444 EXPECT_EQ(outputs_before[0].Shape()[0], 1); 445 EXPECT_EQ(outputs_before[0].Shape()[1], 1); 446 EXPECT_EQ(outputs_before[0].Shape()[2], 224); 447 EXPECT_EQ(outputs_before[0].Shape()[3], 224); 448 449 // prepare input 450 std::vector<MSTensor> outputs; 451 std::vector<MSTensor> inputs; 452 { 453 static const std::vector<float> input_data1(1 * 1 * 224 * 224, 1); 454 int32_t index1 = 0; 455 int32_t index2 = -1; 456 inputs.emplace_back(inputs_before[0].Name(), inputs_before[0].DataType(), inputs_before[0].Shape(), 457 input_data1.data(), sizeof(float) * input_data1.size()); 458 inputs.emplace_back(inputs_before[1].Name(), inputs_before[1].DataType(), inputs_before[1].Shape(), &index1, 459 sizeof(int32_t)); 460 inputs.emplace_back(inputs_before[2].Name(), inputs_before[2].DataType(), inputs_before[2].Shape(), &index2, 461 sizeof(int32_t)); 462 } 463 464 // infer 465 ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); 466 467 // assert output 468 ASSERT_TRUE(outputs.size() == 1); 469 auto out = outputs[0]; 470 ASSERT_TRUE(out.DataSize() == sizeof(float) * 224 * 224); 471 auto out_data = out.Data(); 472 auto p = reinterpret_cast<const float *>(out_data.get()); 473 for (size_t i = 0; i < out.DataSize() / sizeof(float); ++i) { 474 ASSERT_EQ(p[i], 1); 475 } 476 } 477