• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <armnn/Types.hpp>
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnn/utility/Assert.hpp>
12 
13 namespace armnnUtils
14 {
15 
16 /// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout
17 class DataLayoutIndexed
18 {
19 public:
20     DataLayoutIndexed(armnn::DataLayout dataLayout);
21 
GetDataLayout() const22     armnn::DataLayout GetDataLayout()    const { return m_DataLayout; }
GetChannelsIndex() const23     unsigned int      GetChannelsIndex() const { return m_ChannelsIndex; }
GetHeightIndex() const24     unsigned int      GetHeightIndex()   const { return m_HeightIndex; }
GetWidthIndex() const25     unsigned int      GetWidthIndex()    const { return m_WidthIndex; }
GetDepthIndex() const26     unsigned int      GetDepthIndex()    const { return m_DepthIndex; }
27 
GetIndex(const armnn::TensorShape & shape,unsigned int batchIndex,unsigned int channelIndex,unsigned int heightIndex,unsigned int widthIndex) const28     inline unsigned int GetIndex(const armnn::TensorShape& shape,
29                                  unsigned int batchIndex, unsigned int channelIndex,
30                                  unsigned int heightIndex, unsigned int widthIndex) const
31     {
32         ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
33         ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
34                     ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
35         ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] ||
36                     ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
37         ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] ||
38                     ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
39 
40         /// Offset the given indices appropriately depending on the data layout
41         switch (m_DataLayout)
42         {
43         case armnn::DataLayout::NHWC:
44             batchIndex  *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
45             heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
46             widthIndex  *= shape[m_ChannelsIndex];
47             /// channelIndex stays unchanged
48             break;
49         case armnn::DataLayout::NCHW:
50         default:
51             batchIndex   *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
52             channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
53             heightIndex  *= shape[m_WidthIndex];
54             /// widthIndex stays unchanged
55             break;
56         }
57 
58         /// Get the value using the correct offset
59         return batchIndex + channelIndex + heightIndex + widthIndex;
60     }
61 
62 private:
63     armnn::DataLayout m_DataLayout;
64     unsigned int      m_ChannelsIndex;
65     unsigned int      m_HeightIndex;
66     unsigned int      m_WidthIndex;
67     unsigned int      m_DepthIndex;
68 };
69 
70 /// Equality methods
71 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed);
72 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout);
73 
74 } // namespace armnnUtils
75