1 // 2 // Copyright © 2019 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 #include <armnn/Tensor.hpp> 10 11 #include <armnn/utility/Assert.hpp> 12 13 namespace armnnUtils 14 { 15 16 /// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout 17 class DataLayoutIndexed 18 { 19 public: 20 DataLayoutIndexed(armnn::DataLayout dataLayout); 21 GetDataLayout() const22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; } GetChannelsIndex() const23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; } GetHeightIndex() const24 unsigned int GetHeightIndex() const { return m_HeightIndex; } GetWidthIndex() const25 unsigned int GetWidthIndex() const { return m_WidthIndex; } GetDepthIndex() const26 unsigned int GetDepthIndex() const { return m_DepthIndex; } 27 GetIndex(const armnn::TensorShape & shape,unsigned int batchIndex,unsigned int channelIndex,unsigned int heightIndex,unsigned int widthIndex) const28 inline unsigned int GetIndex(const armnn::TensorShape& shape, 29 unsigned int batchIndex, unsigned int channelIndex, 30 unsigned int heightIndex, unsigned int widthIndex) const 31 { 32 ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) ); 33 ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] || 34 ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) ); 35 ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] || 36 ( shape[m_HeightIndex] == 0 && heightIndex == 0) ); 37 ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] || 38 ( shape[m_WidthIndex] == 0 && widthIndex == 0) ); 39 40 /// Offset the given indices appropriately depending on the data layout 41 switch (m_DataLayout) 42 { 43 case armnn::DataLayout::NHWC: 44 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex 45 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex]; 46 widthIndex *= shape[m_ChannelsIndex]; 47 /// channelIndex stays unchanged 48 break; 49 case armnn::DataLayout::NCHW: 50 default: 51 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex 52 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex]; 53 heightIndex *= shape[m_WidthIndex]; 54 /// widthIndex stays unchanged 55 break; 56 } 57 58 /// Get the value using the correct offset 59 return batchIndex + channelIndex + heightIndex + widthIndex; 60 } 61 62 private: 63 armnn::DataLayout m_DataLayout; 64 unsigned int m_ChannelsIndex; 65 unsigned int m_HeightIndex; 66 unsigned int m_WidthIndex; 67 unsigned int m_DepthIndex; 68 }; 69 70 /// Equality methods 71 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed); 72 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout); 73 74 } // namespace armnnUtils 75