• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefElementwiseWorkload.hpp"
7 
8 #include "Decoders.hpp"
9 #include "ElementwiseFunction.hpp"
10 #include "Encoders.hpp"
11 #include "Profiling.hpp"
12 #include "RefWorkloadUtils.hpp"
13 #include "StringMapping.hpp"
14 #include <ResolveType.hpp>
15 #include <vector>
16 
17 namespace armnn
18 {
19 
20 template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
RefElementwiseWorkload(const ParentDescriptor & desc,const WorkloadInfo & info)21 RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::RefElementwiseWorkload(
22     const ParentDescriptor& desc,
23     const WorkloadInfo& info)
24     : BaseWorkload<ParentDescriptor>(desc, info)
25 {
26 }
27 
28 template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
PostAllocationConfigure()29 void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::PostAllocationConfigure()
30 {
31     const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
32     const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
33     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
34 
35     m_Input0 = MakeDecoder<InType>(inputInfo0);
36     m_Input1 = MakeDecoder<InType>(inputInfo1);
37     m_Output = MakeEncoder<OutType>(outputInfo);
38 }
39 
40 template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
Execute() const41 void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
42 {
43     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
44     const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
45     const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
46     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
47 
48     const TensorShape& inShape0 = inputInfo0.GetShape();
49     const TensorShape& inShape1 = inputInfo1.GetShape();
50     const TensorShape& outShape = outputInfo.GetShape();
51 
52     m_Input0->Reset(m_Data.m_Inputs[0]->Map());
53     m_Input1->Reset(m_Data.m_Inputs[1]->Map());
54     m_Output->Reset(m_Data.m_Outputs[0]->Map());
55 
56     ElementwiseBinaryFunction<Functor>(inShape0,
57                                        inShape1,
58                                        outShape,
59                                        *m_Input0,
60                                        *m_Input1,
61                                        *m_Output);
62 }
63 
64 } //namespace armnn
65 
// Explicit template instantiations for every element-wise operation supported
// by the reference backend. Each instantiation pairs an element functor with
// its queue descriptor and a profiling debug-string id, covering both float
// and int32 element types.

// Addition.
template class armnn::RefElementwiseWorkload<std::plus<float>,
                                            armnn::AdditionQueueDescriptor,
                                            armnn::StringMapping::RefAdditionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::plus<int32_t>,
                                            armnn::AdditionQueueDescriptor,
                                            armnn::StringMapping::RefAdditionWorkload_Execute>;

// Subtraction.
template class armnn::RefElementwiseWorkload<std::minus<float>,
                                            armnn::SubtractionQueueDescriptor,
                                            armnn::StringMapping::RefSubtractionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::minus<int32_t>,
                                            armnn::SubtractionQueueDescriptor,
                                            armnn::StringMapping::RefSubtractionWorkload_Execute>;

// Multiplication.
template class armnn::RefElementwiseWorkload<std::multiplies<float>,
                                            armnn::MultiplicationQueueDescriptor,
                                            armnn::StringMapping::RefMultiplicationWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::multiplies<int32_t>,
                                            armnn::MultiplicationQueueDescriptor,
                                            armnn::StringMapping::RefMultiplicationWorkload_Execute>;

// Division. NOTE(review): std::divides<int32_t> on a zero divisor is
// undefined behaviour; presumably guarded upstream — confirm with validation.
template class armnn::RefElementwiseWorkload<std::divides<float>,
                                            armnn::DivisionQueueDescriptor,
                                            armnn::StringMapping::RefDivisionWorkload_Execute>;

template class armnn::RefElementwiseWorkload<std::divides<int32_t>,
                                            armnn::DivisionQueueDescriptor,
                                            armnn::StringMapping::RefDivisionWorkload_Execute>;

// Maximum (ArmNN-provided functor, see Maximum.hpp).
template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
                                            armnn::MaximumQueueDescriptor,
                                            armnn::StringMapping::RefMaximumWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::maximum<int32_t>,
                                            armnn::MaximumQueueDescriptor,
                                            armnn::StringMapping::RefMaximumWorkload_Execute>;

// Minimum (ArmNN-provided functor, see Minimum.hpp).
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
                                            armnn::MinimumQueueDescriptor,
                                            armnn::StringMapping::RefMinimumWorkload_Execute>;

template class armnn::RefElementwiseWorkload<armnn::minimum<int32_t>,
                                            armnn::MinimumQueueDescriptor,
                                            armnn::StringMapping::RefMinimumWorkload_Execute>;
113