• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnn/utility/Assert.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 /// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
17 class DataLayoutIndexed
18 {
19 public:
20     DataLayoutIndexed(armnn::DataLayout dataLayout);
21 
GetDataLayout() const22     armnn::DataLayout GetDataLayout()    const { return m_DataLayout; }
GetChannelsIndex() const23     unsigned int      GetChannelsIndex() const { return m_ChannelsIndex; }
GetHeightIndex() const24     unsigned int      GetHeightIndex()   const { return m_HeightIndex; }
GetWidthIndex() const25     unsigned int      GetWidthIndex()    const { return m_WidthIndex; }
26 
GetIndex(const armnn::TensorShape & shape,unsigned int batchIndex,unsigned int channelIndex,unsigned int heightIndex,unsigned int widthIndex) const27     inline unsigned int GetIndex(const armnn::TensorShape& shape,
28                                  unsigned int batchIndex, unsigned int channelIndex,
29                                  unsigned int heightIndex, unsigned int widthIndex) const
30     {
31         ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
32         ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
33                     ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
34         ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] ||
35                     ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
36         ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] ||
37                     ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
38 
39         /// Offset the given indices appropriately depending on the data layout
40         switch (m_DataLayout)
41         {
42         case armnn::DataLayout::NHWC:
43             batchIndex  *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
44             heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
45             widthIndex  *= shape[m_ChannelsIndex];
46             /// channelIndex stays unchanged
47             break;
48         case armnn::DataLayout::NCHW:
49         default:
50             batchIndex   *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
51             channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
52             heightIndex  *= shape[m_WidthIndex];
53             /// widthIndex stays unchanged
54             break;
55         }
56 
57         /// Get the value using the correct offset
58         return batchIndex + channelIndex + heightIndex + widthIndex;
59     }
60 
61 private:
62     armnn::DataLayout m_DataLayout;
63     unsigned int      m_ChannelsIndex;
64     unsigned int      m_HeightIndex;
65     unsigned int      m_WidthIndex;
66 };
67 
68 /// Equality methods
69 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
70 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
71 
72 } // namespace armnnUtils
73