• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/ConstantMemoryStrategy.hpp>
7 #include <backendsCommon/memoryOptimizerStrategyLibrary/strategies/StrategyValidator.hpp>
8 
9 #include <doctest/doctest.h>
10 #include <vector>
11 
12 using namespace armnn;
13 
14 TEST_SUITE("ConstMemoryStrategyTestSuite")
15 {
16 
17 TEST_CASE("ConstMemoryStrategyTest")
18 {
19     // create a few memory blocks
20     MemBlock memBlock0(0, 2, 20, 0, 0);
21     MemBlock memBlock1(2, 3, 10, 20, 1);
22     MemBlock memBlock2(3, 5, 15, 30, 2);
23     MemBlock memBlock3(5, 6, 20, 50, 3);
24     MemBlock memBlock4(7, 8, 5, 70, 4);
25 
26     std::vector<MemBlock> memBlocks;
27     memBlocks.reserve(5);
28     memBlocks.push_back(memBlock0);
29     memBlocks.push_back(memBlock1);
30     memBlocks.push_back(memBlock2);
31     memBlocks.push_back(memBlock3);
32     memBlocks.push_back(memBlock4);
33 
34     // Optimize the memory blocks with ConstantMemoryStrategy
35     ConstantMemoryStrategy constLayerMemoryOptimizerStrategy;
36     CHECK_EQ(constLayerMemoryOptimizerStrategy.GetName(), std::string("ConstantMemoryStrategy"));
37     CHECK_EQ(constLayerMemoryOptimizerStrategy.GetMemBlockStrategyType(), MemBlockStrategyType::SingleAxisPacking);
38     auto memBins = constLayerMemoryOptimizerStrategy.Optimize(memBlocks);
39     CHECK(memBins.size() == 5);
40 
41     CHECK(memBins[1].m_MemBlocks.size() == 1);
42     CHECK(memBins[1].m_MemBlocks[0].m_Offset == 0);
43     CHECK(memBins[1].m_MemBlocks[0].m_MemSize == 10);
44     CHECK(memBins[1].m_MemBlocks[0].m_Index == 1);
45 
46     CHECK(memBins[4].m_MemBlocks.size() == 1);
47     CHECK(memBins[4].m_MemBlocks[0].m_Offset == 0);
48     CHECK(memBins[4].m_MemBlocks[0].m_MemSize == 5);
49     CHECK(memBins[4].m_MemBlocks[0].m_Index == 4);
50 }
51 
52 TEST_CASE("ConstLayerMemoryOptimizerStrategyValidatorTest")
53 {
54     // create a few memory blocks
55     MemBlock memBlock0(0, 2, 20, 0, 0);
56     MemBlock memBlock1(2, 3, 10, 20, 1);
57     MemBlock memBlock2(3, 5, 15, 30, 2);
58     MemBlock memBlock3(5, 6, 20, 50, 3);
59     MemBlock memBlock4(7, 8, 5, 70, 4);
60 
61     std::vector<MemBlock> memBlocks;
62     memBlocks.reserve(5);
63     memBlocks.push_back(memBlock0);
64     memBlocks.push_back(memBlock1);
65     memBlocks.push_back(memBlock2);
66     memBlocks.push_back(memBlock3);
67     memBlocks.push_back(memBlock4);
68 
69     // Optimize the memory blocks with ConstLayerMemoryOptimizerStrategy
70     auto ptr = std::make_shared<ConstantMemoryStrategy>();
71     StrategyValidator validator;
72     validator.SetStrategy(ptr);
73     // Ensure ConstLayerMemoryOptimizerStrategy is valid
74     CHECK_NOTHROW(validator.Optimize(memBlocks));
75 }
76 
77 }
78