1 // 2 // Copyright © 2019 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 #include <armnn/Tensor.hpp> 10 11 #include <armnn/utility/Assert.hpp> 12 13 namespace armnnUtils 14 { 15 16 /// Provides access to the appropriate indexes for Channels, Height and Width based on DataLayout 17 class DataLayoutIndexed 18 { 19 public: 20 DataLayoutIndexed(armnn::DataLayout dataLayout); 21 GetDataLayout() const22 armnn::DataLayout GetDataLayout() const { return m_DataLayout; } GetChannelsIndex() const23 unsigned int GetChannelsIndex() const { return m_ChannelsIndex; } GetHeightIndex() const24 unsigned int GetHeightIndex() const { return m_HeightIndex; } GetWidthIndex() const25 unsigned int GetWidthIndex() const { return m_WidthIndex; } 26 GetIndex(const armnn::TensorShape & shape,unsigned int batchIndex,unsigned int channelIndex,unsigned int heightIndex,unsigned int widthIndex) const27 inline unsigned int GetIndex(const armnn::TensorShape& shape, 28 unsigned int batchIndex, unsigned int channelIndex, 29 unsigned int heightIndex, unsigned int widthIndex) const 30 { 31 ARMNN_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) ); 32 ARMNN_ASSERT( channelIndex < shape[m_ChannelsIndex] || 33 ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) ); 34 ARMNN_ASSERT( heightIndex < shape[m_HeightIndex] || 35 ( shape[m_HeightIndex] == 0 && heightIndex == 0) ); 36 ARMNN_ASSERT( widthIndex < shape[m_WidthIndex] || 37 ( shape[m_WidthIndex] == 0 && widthIndex == 0) ); 38 39 /// Offset the given indices appropriately depending on the data layout 40 switch (m_DataLayout) 41 { 42 case armnn::DataLayout::NHWC: 43 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex 44 heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex]; 45 widthIndex *= shape[m_ChannelsIndex]; 46 /// channelIndex stays unchanged 47 break; 48 case armnn::DataLayout::NCHW: 49 default: 50 batchIndex *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex 51 channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex]; 52 heightIndex *= shape[m_WidthIndex]; 53 /// widthIndex stays unchanged 54 break; 55 } 56 57 /// Get the value using the correct offset 58 return batchIndex + channelIndex + heightIndex + widthIndex; 59 } 60 61 private: 62 armnn::DataLayout m_DataLayout; 63 unsigned int m_ChannelsIndex; 64 unsigned int m_HeightIndex; 65 unsigned int m_WidthIndex; 66 }; 67 68 /// Equality methods 69 bool operator==(const armnn::DataLayout& dataLayout, const DataLayoutIndexed& indexed); 70 bool operator==(const DataLayoutIndexed& indexed, const armnn::DataLayout& dataLayout); 71 72 } // namespace armnnUtils 73