• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserHelper.hpp"
7 
8 #include <armnn/Descriptors.hpp>
9 #include <armnnUtils/Permute.hpp>
10 
11 #include <fmt/format.h>
12 
13 namespace armnnUtils
14 {
15 
// Permutation vectors for 4-D tensors. The two mappings are inverses of each
// other: applying one and then the other restores the original dimension order.
// NOTE(review): the names suggest a conversion between NHWC and ArmNN's own
// layout (presumably NCHW) -- confirm against the callers.
const armnn::PermutationVector NHWCToArmNN = { 0, 2, 3, 1 };
const armnn::PermutationVector ArmNNToNHWC = { 0, 3, 1, 2 };
18 
ProcessConcatInputTensorInfo(armnn::TensorInfo & inputTensorInfo,armnn::OriginsDescriptor & concatDescriptor,const unsigned int & concatAxis,unsigned int inputIndex,unsigned int & mergeDimOrigin)19 void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo,
20                                   armnn::OriginsDescriptor& concatDescriptor,
21                                   const unsigned int& concatAxis,
22                                   unsigned int inputIndex,
23                                   unsigned int& mergeDimOrigin)
24 {
25     const uint32_t inputRank = concatDescriptor.GetNumDimensions();
26 
27     // double check dimensions of the tensors
28     if (inputTensorInfo.GetNumDimensions() != inputRank)
29     {
30         throw armnn::ParseException(fmt::format(
31                                     "The number of dimensions: {0} for input tensors of the "
32                                     "concatenation op should be {1} {2}",
33                                     inputTensorInfo.GetNumDimensions(),
34                                     inputRank,
35                                     CHECK_LOCATION().AsString()));
36     }
37 
38     for (unsigned int j = 0; j < concatAxis; ++j)
39     {
40         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
41     }
42 
43     concatDescriptor.SetViewOriginCoord(inputIndex, concatAxis, mergeDimOrigin);
44     mergeDimOrigin += inputTensorInfo.GetShape()[concatAxis];
45 
46     for (unsigned int j = concatAxis + 1; j < inputRank; ++j)
47     {
48         concatDescriptor.SetViewOriginCoord(inputIndex, j, 0);
49     }
50 }
51 
CalculateReducedOutputTensoInfo(const armnn::TensorInfo & inputTensorInfo,const std::set<unsigned int> & axisSet,bool keepDims,armnn::TensorInfo & outputTensorInfo)52 void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo,
53                                      const std::set<unsigned int>& axisSet,
54                                      bool keepDims,
55                                      armnn::TensorInfo& outputTensorInfo)
56 {
57     std::vector<unsigned int> outputShapeVector;
58     bool dimensionFound = false;
59     unsigned int size = 1;
60 
61     for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
62     {
63         dimensionFound = false;
64         for (unsigned int axis: axisSet)
65         {
66             if (axis == i)
67             {
68                 dimensionFound = true;
69                 break;
70             }
71         }
72 
73         if (!dimensionFound)
74         {
75             size *= inputTensorInfo.GetShape()[i];
76 
77             if (keepDims)
78             {
79                 outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
80             }
81         }
82         else
83         {
84             if (keepDims)
85             {
86                 outputShapeVector.push_back(1);
87             }
88         }
89     }
90 
91     if (keepDims)
92     {
93         armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
94         outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
95     }
96     else
97     {
98         outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
99     }
100 }
101 
102 
CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo & inputTensorInfo,const armnn::StridedSliceDescriptor & desc,armnn::TensorInfo & outputTensorInfo)103 void CalculateStridedSliceOutputTensorInfo(const armnn::TensorInfo& inputTensorInfo,
104                                            const armnn::StridedSliceDescriptor& desc,
105                                            armnn::TensorInfo& outputTensorInfo)
106 {
107     const armnn::TensorShape& inputShape = inputTensorInfo.GetShape();
108 
109     std::vector<unsigned int> outputShapeVector;
110     for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
111     {
112         if (desc.m_ShrinkAxisMask & (1 << i))
113         {
114             continue;
115         }
116 
117         int stride = desc.m_Stride[i];
118         int start = desc.GetStartForAxis(inputShape, i);
119         int stop = desc.GetStopForAxis(inputShape, i, start);
120 
121         int newSize = stride > 0 ? ((stop - start) + stride - 1) / stride :
122                       ((start - stop) - stride - 1) / -stride;
123 
124         newSize = std::max(0, newSize);
125 
126         outputShapeVector.push_back(static_cast<unsigned int>(newSize));
127     }
128 
129     armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
130     outputTensorInfo = armnn::TensorInfo(armnn::TensorShape(outputTensorShape), inputTensorInfo.GetDataType());
131 }
132 } // namespace armnnUtils
133