1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include "Mean.hpp"
#include <backendsCommon/WorkloadData.hpp>

#include <armnn/utility/NumericCast.hpp>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <functional>
#include <limits>
15
16 namespace armnn
17 {
NextIndex(const unsigned int numDims,const armnn::TensorShape & dims,std::vector<unsigned int> & current)18 bool NextIndex(const unsigned int numDims, const armnn::TensorShape& dims, std::vector<unsigned int>& current)
19 {
20 unsigned int carry = 1;
21
22 for (unsigned int idx = numDims; idx-- > 0; )
23 {
24 unsigned int current_val = current[idx] + carry;
25 if (dims[idx] == current_val)
26 {
27 current[idx] = 0;
28 }
29 else
30 {
31 current[idx] = current_val;
32 carry = 0;
33 break;
34 }
35 }
36 return (carry == 0);
37 }
38
ReducedOutputOffset(const unsigned int numDims,const armnn::TensorShape & dims,std::vector<unsigned int> & index,const unsigned int numAxis,const std::vector<unsigned int> & axis)39 unsigned int ReducedOutputOffset(const unsigned int numDims,
40 const armnn::TensorShape& dims,
41 std::vector<unsigned int>& index,
42 const unsigned int numAxis,
43 const std::vector<unsigned int>& axis)
44 {
45 unsigned int offset = 0;
46 for (unsigned int idx = 0; idx < numDims; ++idx)
47 {
48 bool isAxis = false;
49 if (!axis.empty())
50 {
51 for (unsigned int axisIdx = 0; axisIdx < numAxis; ++axisIdx)
52 {
53 if (idx == axis[axisIdx])
54 {
55 isAxis = true;
56 break;
57 }
58 }
59 }
60 if (!isAxis)
61 {
62 offset = offset * dims[idx] + index[idx];
63 }
64 }
65 return offset;
66 }
67 } // namespace
68
69 namespace armnn
70 {
Mean(const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & outputInfo,const std::vector<unsigned int> & axis,Decoder<float> & input,Encoder<float> & output)71 void Mean(const armnn::TensorInfo& inputInfo,
72 const armnn::TensorInfo& outputInfo,
73 const std::vector<unsigned int>& axis,
74 Decoder<float>& input,
75 Encoder<float>& output)
76 {
77
78 unsigned int inputNumDims = inputInfo.GetNumDimensions();
79 unsigned int outputNumDims = outputInfo.GetNumDimensions();
80
81 armnn::TensorShape outputDims = outputInfo.GetShape();
82 armnn::TensorShape inputDims = inputInfo.GetShape();
83
84 // Initialise output data.
85 unsigned int numOutputs = 1;
86 for (unsigned int idx = 0; idx < outputNumDims; ++idx)
87 {
88 numOutputs *= outputDims[idx];
89 }
90
91 std::vector<float> tempSum(numOutputs);
92 for (unsigned int idx = 0; idx < numOutputs; ++idx)
93 {
94 output[idx];
95 output.Set(0.0f);
96 tempSum[idx] = 0.0f;
97 }
98
99 // Initialise temp index.
100 std::vector<unsigned int> tempIndex(inputNumDims);
101 for (unsigned int idx = 0; idx < inputNumDims; ++idx)
102 {
103 tempIndex[idx] = 0;
104 }
105
106 std::vector<unsigned int> resolvedAxis = axis;
107 if (resolvedAxis.empty())
108 {
109 for (unsigned int idx = 0; idx < inputNumDims; ++idx)
110 {
111 resolvedAxis.push_back(idx);
112 }
113 }
114 auto numResolvedAxis = armnn::numeric_cast<unsigned int>(resolvedAxis.size());
115
116 // Iterates through input_data and sum up the reduced axis.
117 for (bool hasNext = true; hasNext; hasNext = NextIndex(inputNumDims, inputDims, tempIndex))
118 {
119 unsigned int inputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex, 0, {});
120 unsigned int outputOffset = ReducedOutputOffset(inputNumDims, inputDims, tempIndex,
121 numResolvedAxis, resolvedAxis);
122 input[inputOffset];
123 tempSum[outputOffset] += input.Get();
124 }
125
126 // Takes average by num of elements added to get mean.
127 size_t numElementsInAxis = 1;
128 for (unsigned int idx = 0; idx < numResolvedAxis; ++idx)
129 {
130 unsigned int current = inputDims[resolvedAxis[idx]];
131 ARMNN_ASSERT(armnn::numeric_cast<float>(current) <
132 (std::numeric_limits<float>::max() / armnn::numeric_cast<float>(numElementsInAxis)));
133 numElementsInAxis *= current;
134 }
135 if (numElementsInAxis > 0) {
136 for (unsigned int idx = 0; idx < numOutputs; ++idx)
137 {
138 output[idx];
139 output.Set(tempSum[idx] / armnn::numeric_cast<float>(numElementsInAxis));
140 }
141 }
142 }
143 } //namespace armnn
144