1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "../ParserHelper.hpp" 7 8 #include <armnn/Tensor.hpp> 9 #include <armnn/Types.hpp> 10 11 #include <doctest/doctest.h> 12 13 14 using namespace armnn; 15 using namespace armnnUtils; 16 17 TEST_SUITE("ParserHelperSuite") 18 { 19 TEST_CASE("CalculateReducedOutputTensoInfoTest") 20 { 21 bool keepDims = false; 22 23 unsigned int inputShape[] = { 2, 3, 4 }; 24 TensorInfo inputTensorInfo(3, &inputShape[0], DataType::Float32); 25 26 // Reducing all dimensions results in one single output value (one dimension) 27 std::set<unsigned int> axisData1 = { 0, 1, 2 }; 28 TensorInfo outputTensorInfo1; 29 30 CalculateReducedOutputTensoInfo(inputTensorInfo, axisData1, keepDims, outputTensorInfo1); 31 32 CHECK(outputTensorInfo1.GetNumDimensions() == 1); 33 CHECK(outputTensorInfo1.GetShape()[0] == 1); 34 35 // Reducing dimension 0 results in a 3x4 size tensor (one dimension) 36 std::set<unsigned int> axisData2 = { 0 }; 37 TensorInfo outputTensorInfo2; 38 39 CalculateReducedOutputTensoInfo(inputTensorInfo, axisData2, keepDims, outputTensorInfo2); 40 41 CHECK(outputTensorInfo2.GetNumDimensions() == 1); 42 CHECK(outputTensorInfo2.GetShape()[0] == 12); 43 44 // Reducing dimensions 0,1 results in a 4 size tensor (one dimension) 45 std::set<unsigned int> axisData3 = { 0, 1 }; 46 TensorInfo outputTensorInfo3; 47 48 CalculateReducedOutputTensoInfo(inputTensorInfo, axisData3, keepDims, outputTensorInfo3); 49 50 CHECK(outputTensorInfo3.GetNumDimensions() == 1); 51 CHECK(outputTensorInfo3.GetShape()[0] == 4); 52 53 // Reducing dimension 0 results in a { 1, 3, 4 } dimension tensor 54 keepDims = true; 55 std::set<unsigned int> axisData4 = { 0 }; 56 57 TensorInfo outputTensorInfo4; 58 59 CalculateReducedOutputTensoInfo(inputTensorInfo, axisData4, keepDims, outputTensorInfo4); 60 61 CHECK(outputTensorInfo4.GetNumDimensions() == 3); 62 CHECK(outputTensorInfo4.GetShape()[0] == 1); 63 CHECK(outputTensorInfo4.GetShape()[1] == 3); 64 CHECK(outputTensorInfo4.GetShape()[2] == 4); 65 66 // Reducing dimension 1, 2 results in a { 2, 1, 1 } dimension tensor 67 keepDims = true; 68 std::set<unsigned int> axisData5 = { 1, 2 }; 69 70 TensorInfo outputTensorInfo5; 71 72 CalculateReducedOutputTensoInfo(inputTensorInfo, axisData5, keepDims, outputTensorInfo5); 73 74 CHECK(outputTensorInfo5.GetNumDimensions() == 3); 75 CHECK(outputTensorInfo5.GetShape()[0] == 2); 76 CHECK(outputTensorInfo5.GetShape()[1] == 1); 77 CHECK(outputTensorInfo5.GetShape()[2] == 1); 78 79 } 80 81 } 82 83