1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "Gather.hpp"
7
8 #include "RefWorkloadUtils.hpp"
9
10 #include <backendsCommon/WorkloadData.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12 #include <armnn/utility/NumericCast.hpp>
13
14 namespace armnn
15 {
16
Gather(const TensorInfo & paramsInfo,const TensorInfo & indicesInfo,const TensorInfo & outputInfo,Decoder<float> & params,const int32_t * indices,Encoder<float> & output,const int32_t axis)17 void Gather(const TensorInfo& paramsInfo,
18 const TensorInfo& indicesInfo,
19 const TensorInfo& outputInfo,
20 Decoder<float>& params,
21 const int32_t* indices,
22 Encoder<float>& output,
23 const int32_t axis)
24 {
25 IgnoreUnused(outputInfo);
26 IgnoreUnused(axis);
27
28 const TensorShape& paramsShape = paramsInfo.GetShape();
29
30 unsigned int paramsProduct = 1;
31 for (unsigned int i = 1; i < paramsInfo.GetNumDimensions(); ++i)
32 {
33 paramsProduct = paramsProduct * paramsShape[i];
34 }
35
36 unsigned int outIndex = 0;
37 for (unsigned int i = 0; i < indicesInfo.GetNumElements(); ++i)
38 {
39 unsigned int indx = armnn::numeric_cast<unsigned int>(indices[i]);
40
41 ARMNN_ASSERT(indices[i] >= 0 && indx < paramsShape[0]);
42
43 unsigned int startOffset = indx * paramsProduct;
44 unsigned int endOffset = startOffset + paramsProduct;
45
46 for (unsigned int j = startOffset; j < endOffset; ++j)
47 {
48 params[j];
49 float outputValue = params.Get();
50 output[outIndex];
51 output.Set(outputValue);
52 ++outIndex;
53 }
54 }
55
56 ARMNN_ASSERT(outIndex == outputInfo.GetNumElements());
57 }
58
59 } //namespace armnn
60