1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/ConstantMemoryStrategy.hpp> 7 #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.hpp> 8 9 #include <doctest/doctest.h> 10 #include <vector> 11 12 using namespace armnn; 13 14 TEST_SUITE("ConstMemoryStrategyTestSuite") 15 { 16 17 TEST_CASE("ConstMemoryStrategyTest") 18 { 19 // create a few memory blocks 20 MemBlock memBlock0(0, 2, 20, 0, 0); 21 MemBlock memBlock1(2, 3, 10, 20, 1); 22 MemBlock memBlock2(3, 5, 15, 30, 2); 23 MemBlock memBlock3(5, 6, 20, 50, 3); 24 MemBlock memBlock4(7, 8, 5, 70, 4); 25 26 std::vector<MemBlock> memBlocks; 27 memBlocks.reserve(5); 28 memBlocks.push_back(memBlock0); 29 memBlocks.push_back(memBlock1); 30 memBlocks.push_back(memBlock2); 31 memBlocks.push_back(memBlock3); 32 memBlocks.push_back(memBlock4); 33 34 // Optimize the memory blocks with ConstantMemoryStrategy 35 ConstantMemoryStrategy constLayerMemoryOptimizerStrategy; 36 CHECK_EQ(constLayerMemoryOptimizerStrategy.GetName(), std::string("ConstantMemoryStrategy")); 37 CHECK_EQ(constLayerMemoryOptimizerStrategy.GetMemBlockStrategyType(), MemBlockStrategyType::SingleAxisPacking); 38 auto memBins = constLayerMemoryOptimizerStrategy.Optimize(memBlocks); 39 CHECK(memBins.size() == 5); 40 41 CHECK(memBins[1].m_MemBlocks.size() == 1); 42 CHECK(memBins[1].m_MemBlocks[0].m_Offset == 0); 43 CHECK(memBins[1].m_MemBlocks[0].m_MemSize == 10); 44 CHECK(memBins[1].m_MemBlocks[0].m_Index == 1); 45 46 CHECK(memBins[4].m_MemBlocks.size() == 1); 47 CHECK(memBins[4].m_MemBlocks[0].m_Offset == 0); 48 CHECK(memBins[4].m_MemBlocks[0].m_MemSize == 5); 49 CHECK(memBins[4].m_MemBlocks[0].m_Index == 4); 50 } 51 52 TEST_CASE("ConstLayerMemoryOptimizerStrategyValidatorTest") 53 { 54 // create a few memory blocks 55 MemBlock memBlock0(0, 2, 20, 0, 0); 56 MemBlock memBlock1(2, 3, 10, 20, 1); 57 MemBlock memBlock2(3, 5, 15, 30, 2); 58 MemBlock memBlock3(5, 6, 20, 50, 3); 59 MemBlock memBlock4(7, 8, 5, 70, 4); 60 61 std::vector<MemBlock> memBlocks; 62 memBlocks.reserve(5); 63 memBlocks.push_back(memBlock0); 64 memBlocks.push_back(memBlock1); 65 memBlocks.push_back(memBlock2); 66 memBlocks.push_back(memBlock3); 67 memBlocks.push_back(memBlock4); 68 69 // Optimize the memory blocks with ConstLayerMemoryOptimizerStrategy 70 auto ptr = std::make_shared<ConstantMemoryStrategy>(); 71 StrategyValidator validator; 72 validator.SetStrategy(ptr); 73 // Ensure ConstLayerMemoryOptimizerStrategy is valid 74 CHECK_NOTHROW(validator.Optimize(memBlocks)); 75 } 76 77 } 78