1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "ArgMinMax.hpp"
7
8 #include <armnnUtils/TensorUtils.hpp>
9
10 #include <armnn/utility/NumericCast.hpp>
11
12 namespace armnn
13 {
14
15 template <typename OUT>
ArgMinMax(Decoder<float> & in,OUT * out,const TensorInfo & inputTensorInfo,const TensorInfo & outputTensorInfo,ArgMinMaxFunction function,int axis)16 void ArgMinMax(Decoder<float>& in, OUT* out, const TensorInfo& inputTensorInfo,
17 const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis)
18 {
19 IgnoreUnused(outputTensorInfo);
20
21 unsigned int uAxis = armnnUtils::GetUnsignedAxis(inputTensorInfo.GetNumDimensions(), axis);
22
23 const unsigned int outerElements = armnnUtils::GetNumElementsBetween(inputTensorInfo.GetShape(), 0, uAxis);
24 const unsigned int axisSize = inputTensorInfo.GetShape()[uAxis];
25 const unsigned int innerElements = armnnUtils::GetNumElementsBetween(inputTensorInfo.GetShape(),
26 uAxis + 1,
27 inputTensorInfo.GetNumDimensions());
28
29 for (unsigned int outer = 0; outer < outerElements; ++outer) {
30 for (unsigned int inner = 0; inner < innerElements; ++inner) {
31 in[outer * axisSize * innerElements + inner];
32 auto tmpValue = in.Get();
33 unsigned int tmpIndex = 0;
34 for (unsigned int i = 1; i < axisSize; ++i) {
35 in[(outer * axisSize * innerElements) + (i * innerElements) + inner];
36 const auto& value = in.Get();
37 if ((function == armnn::ArgMinMaxFunction::Min && value < tmpValue) ||
38 (function == armnn::ArgMinMaxFunction::Max && value > tmpValue)) {
39 tmpValue = value;
40 tmpIndex = i;
41 }
42 }
43
44 out[outer * innerElements + inner] = armnn::numeric_cast<OUT>(tmpIndex);
45 }
46 }
47 }
48
49 template void ArgMinMax(Decoder<float>& in, int32_t* out, const TensorInfo& inputTensorInfo,
50 const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
51
52 template void ArgMinMax(Decoder<float>& in, int64_t* out, const TensorInfo& inputTensorInfo,
53 const TensorInfo& outputTensorInfo, ArgMinMaxFunction function, int axis);
54
55 } //namespace armnn
56