• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ArgMinMax.hpp"
7 
8 #include <armnnUtils/TensorUtils.hpp>
9 
10 #include <armnn/utility/NumericCast.hpp>
11 
12 namespace armnn
13 {
14 
15 template <typename OUT>
ArgMinMax(Decoder<float> & in,OUT * out,const TensorInfo & inputTensorInfo,const TensorInfo & outputTensorInfo,ArgMinMaxFunction function,int axis)16 void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
17                const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
18 {
19     IgnoreUnused(outputTensorInfo);
20 
21     unsigned int uAxis = armnnUtils::GetUnsignedAxis(inputTensorInfo.GetNumDimensions(), axis);
22 
23     const unsigned int outerElements = armnnUtils::GetNumElementsBetween(inputTensorInfo.GetShape(), 0, uAxis);
24     const unsigned int axisSize = inputTensorInfo.GetShape()[uAxis];
25     const unsigned int innerElements = armnnUtils::GetNumElementsBetween(inputTensorInfo.GetShape(),
26                                                                          uAxis + 1,
27                                                                          inputTensorInfo.GetNumDimensions());
28 
29     for (unsigned int outer = 0; outer < outerElements; ++outer) {
30         for (unsigned int inner = 0; inner < innerElements; ++inner) {
31             in[outer * axisSize * innerElements + inner];
32             auto tmpValue = in.Get();
33             unsigned int tmpIndex = 0;
34             for (unsigned int i = 1; i < axisSize; ++i) {
35                 in[(outer * axisSize * innerElements) + (i * innerElements) + inner];
36                 const auto& value = in.Get();
37                 if ((function == armnn::ArgMinMaxFunction::Min && value < tmpValue) ||
38                     (function == armnn::ArgMinMaxFunction::Max &&  value > tmpValue)) {
39                     tmpValue = value;
40                     tmpIndex = i;
41                 }
42             }
43 
44             out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
45         }
46     }
47 }
48 
49 template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
50                const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
51 
52 template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
53                const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
54 
55 } //namespace armnn
56