1 // 2 // Copyright © 2020 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <TestUtils.hpp> 7 8 #include <Optimizer.hpp> 9 10 #include <doctest/doctest.h> 11 12 using namespace armnn; 13 14 TEST_SUITE("Optimizer") 15 { 16 using namespace armnn::optimizations; 17 18 TEST_CASE("TransposeAsReshapeTest") 19 { 20 armnn::Graph graph; 21 22 std::string transposeLayerName = "transpose"; 23 24 const armnn::TensorInfo infoIn({ 1, 2, 3, 1 }, armnn::DataType::Float32); 25 const armnn::TensorInfo infoOut({ 1, 1, 2, 3 }, armnn::DataType::Float32); 26 27 auto output = graph.AddLayer<armnn::OutputLayer>(0, "output"); 28 29 graph.InsertNewLayer<armnn::InputLayer>(output->GetInputSlot(0), 0, "input") 30 ->GetOutputHandler() 31 .SetTensorInfo(infoIn); 32 33 // Inserts transpose. 34 graph 35 .InsertNewLayer<armnn::TransposeLayer>(output->GetInputSlot(0), armnn::TransposeDescriptor({ 0, 3, 1, 2 }), 36 transposeLayerName.c_str()) 37 ->GetOutputHandler() 38 .SetTensorInfo(infoOut); 39 40 CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, 41 &IsLayerOfType<armnn::TransposeLayer>, &IsLayerOfType<armnn::OutputLayer>)); 42 43 armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(TransposeAsReshape())); 44 45 // The transpose is replaced by an equivalent reshape. 46 __anonaa734b490102(const armnn::Layer* const layer) 47 auto checkReshape = [&infoOut](const armnn::Layer* const layer) -> bool { 48 const auto reshapeLayer = static_cast<const armnn::ReshapeLayer*>(layer); 49 return IsLayerOfType<armnn::ReshapeLayer>(layer) && 50 (reshapeLayer->GetParameters().m_TargetShape == infoOut.GetShape()) && 51 (reshapeLayer->GetOutputHandler().GetTensorInfo().GetShape() == infoOut.GetShape()); 52 }; 53 54 CHECK(CheckSequence(graph.cbegin(), graph.cend(), &IsLayerOfType<armnn::InputLayer>, checkReshape, 55 &IsLayerOfType<armnn::OutputLayer>)); 56 57 std::list<std::string> testRelatedLayers = { transposeLayerName }; 58 CHECK(CheckRelatedLayers<armnn::ReshapeLayer>(graph, testRelatedLayers)); 59 } 60 61 }