//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefElementwiseWorkload.hpp"

#include "Decoders.hpp"
#include "ElementwiseFunction.hpp"
#include "Encoders.hpp"
#include "Profiling.hpp"
#include "RefWorkloadUtils.hpp"
#include "StringMapping.hpp"
#include <ResolveType.hpp>
#include <vector>

namespace armnn
{
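// Constructor: forwards the queue descriptor and workload info to BaseWorkload.
// No tensor decoders/encoders are created here; that happens in PostAllocationConfigure().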
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::RefElementwiseWorkload(
    const ParentDescriptor& desc,
    const WorkloadInfo& info)
    : BaseWorkload<ParentDescriptor>(desc, info)
{
}
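// Builds a Decoder for each input tensor and an Encoder for the output tensor from the
// tensor infos attached to the queue descriptor. As the name suggests, this is expected to
// run after tensor memory has been allocated and before Execute().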
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::PostAllocationConfigure()
{
    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);

    m_Input0 = MakeDecoder<InType>(inputInfo0);
    m_Input1 = MakeDecoder<InType>(inputInfo1);
    m_Output = MakeEncoder<OutType>(outputInfo);
}
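// Maps the input and output tensor handles, then applies the binary Functor elementwise
// (broadcasting across the two input shapes is handled by ElementwiseBinaryFunction) and
// writes the result through the output encoder.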
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
    const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);

    const TensorShape& inShape0 = inputInfo0.GetShape();
    const TensorShape& inShape1 = inputInfo1.GetShape();
    const TensorShape& outShape = outputInfo.GetShape();

    m_Input0->Reset(m_Data.m_Inputs[0]->Map());
    m_Input1->Reset(m_Data.m_Inputs[1]->Map());
    m_Output->Reset(m_Data.m_Outputs[0]->Map());

    ElementwiseBinaryFunction<Functor>(inShape0,
                                       inShape1,
                                       outShape,
                                       *m_Input0,
                                       *m_Input1,
                                       *m_Output);
}
} // namespace armnn
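// Explicit instantiations for the addition, subtraction, multiplication, division,
// maximum and minimum reference workloads, each in float and int32 variants.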
template class armnn::RefElementwiseWorkload<std::plus<float>,
                                             armnn::AdditionQueueDescriptor,
                                             armnn::StringMapping::RefAdditionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::plus<int32_t>,
                                             armnn::AdditionQueueDescriptor,
                                             armnn::StringMapping::RefAdditionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::minus<float>,
                                             armnn::SubtractionQueueDescriptor,
                                             armnn::StringMapping::RefSubtractionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::minus<int32_t>,
                                             armnn::SubtractionQueueDescriptor,
                                             armnn::StringMapping::RefSubtractionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::multiplies<float>,
                                             armnn::MultiplicationQueueDescriptor,
                                             armnn::StringMapping::RefMultiplicationWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::multiplies<int32_t>,
                                             armnn::MultiplicationQueueDescriptor,
                                             armnn::StringMapping::RefMultiplicationWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::divides<float>,
                                             armnn::DivisionQueueDescriptor,
                                             armnn::StringMapping::RefDivisionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::divides<int32_t>,
                                             armnn::DivisionQueueDescriptor,
                                             armnn::StringMapping::RefDivisionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
                                             armnn::MaximumQueueDescriptor,
                                             armnn::StringMapping::RefMaximumWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::maximum<int32_t>,
                                             armnn::MaximumQueueDescriptor,
                                             armnn::StringMapping::RefMaximumWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
                                             armnn::MinimumQueueDescriptor,
                                             armnn::StringMapping::RefMinimumWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::minimum<int32_t>,
                                             armnn::MinimumQueueDescriptor,
                                             armnn::StringMapping::RefMinimumWorkload_Execute>;