1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "LayersFwd.hpp" 7 #include <Network.hpp> 8 #include <doctest/doctest.h> 9 #include <Optimizer.hpp> 10 #include <TestUtils.hpp> 11 12 TEST_SUITE("Optimizer") 13 { 14 using namespace armnn; 15 using namespace armnn::optimizations; 16 17 TEST_CASE("ConvertConstPermuteToConst") 18 { 19 Graph graph; 20 const unsigned int shape[] = {1, 2, 2, 3}; 21 22 const TensorInfo constTensorInfo(4, shape, DataType::Float32, 1.0, 0, true); 23 24 ConstantLayer* constant = graph.AddLayer<ConstantLayer>("constant"); 25 std::vector<float> constantValues(constTensorInfo.GetNumElements(), 4.5f); 26 ConstTensor constTensor(constTensorInfo, constantValues.data()); 27 constant->m_LayerOutput = std::make_shared<ScopedTensorHandle>(constTensor); 28 constant->GetOutputSlot().SetTensorInfo(constTensorInfo); 29 30 PermuteDescriptor desc({ 0, 2, 3, 1 }); 31 PermuteLayer* permuteLayer = graph.AddLayer<PermuteLayer>(desc, "permute"); 32 TensorInfo infoPermuted = armnnUtils::Permuted(constTensorInfo, { 0, 2, 3, 1 }); 33 permuteLayer->GetOutputSlot().SetTensorInfo(infoPermuted); 34 35 OutputLayer* output = graph.AddLayer<OutputLayer>(0, "output"); 36 37 // Connect up constant -> permute -> output 38 constant->GetOutputSlot().Connect(permuteLayer->GetInputSlot(0)); 39 permuteLayer->GetOutputSlot().Connect(output->GetInputSlot(0)); 40 41 CHECK(CheckSequence(graph.cbegin(), graph.cend(), 42 &IsLayerOfType<ConstantLayer>, 43 &IsLayerOfType<PermuteLayer>, 44 &IsLayerOfType<OutputLayer>)); 45 46 armnn::Optimizer::Pass(graph, MakeOptimizations(FusePermuteIntoConstLayer())); 47 48 CHECK(CheckSequence(graph.cbegin(), graph.cend(), 49 &IsLayerOfType<ConstantLayer>, 50 &IsLayerOfType<OutputLayer>)); 51 52 TensorShape tensorShape = constant->GetOutputSlot(0).GetTensorInfo().GetShape(); 53 CHECK(tensorShape[0] == shape[0]); 54 CHECK(tensorShape[1] == shape[3]); 55 CHECK(tensorShape[2] == shape[1]); 56 CHECK(tensorShape[3] == shape[2]); 57 58 } 59 60 } 61