1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "RefElementwiseUnaryWorkload.hpp"
7
8 #include "Decoders.hpp"
9 #include "ElementwiseFunction.hpp"
10 #include "Encoders.hpp"
11 #include "RefWorkloadUtils.hpp"
12 #include "Abs.hpp"
13 #include "Exp.hpp"
14 #include "Rsqrt.hpp"
15 #include "Sqrt.hpp"
16
17 #include <Profiling.hpp>
18
19 #include <armnn/TypesUtils.hpp>
20
21 #include <functional>
22
23 namespace armnn
24 {
25
RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor & desc,const WorkloadInfo & info)26 RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
27 const WorkloadInfo& info)
28 : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
29 {}
30
PostAllocationConfigure()31 void RefElementwiseUnaryWorkload::PostAllocationConfigure()
32 {
33 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
34 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
35
36 m_Input = MakeDecoder<InType>(inputInfo);
37
38 m_Output = MakeEncoder<OutType>(outputInfo);
39 }
40
Execute() const41 void RefElementwiseUnaryWorkload::Execute() const
42 {
43 ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseUnaryWorkload_Execute");
44
45 const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
46 const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
47
48 const TensorShape& inShape = inputInfo.GetShape();
49 const TensorShape& outShape = outputInfo.GetShape();
50
51 m_Input->Reset(m_Data.m_Inputs[0]->Map());
52 m_Output->Reset(m_Data.m_Outputs[0]->Map());
53
54 using AbsFunction = ElementwiseUnaryFunction<abs<InType>>;
55 using ExpFunction = ElementwiseUnaryFunction<exp<InType>>;
56 using NegFunction = ElementwiseUnaryFunction<std::negate<InType>>;
57 using RsqrtFunction = ElementwiseUnaryFunction<rsqrt<InType>>;
58 using SqrtFunction = ElementwiseUnaryFunction<sqrt<InType>>;
59
60 switch (m_Data.m_Parameters.m_Operation)
61 {
62 case UnaryOperation::Abs:
63 {
64 AbsFunction(inShape, outShape, *m_Input, *m_Output);
65 break;
66 }
67 case UnaryOperation::Exp:
68 {
69 ExpFunction(inShape, outShape, *m_Input, *m_Output);
70 break;
71 }
72 case UnaryOperation::Neg:
73 {
74 NegFunction(inShape, outShape, *m_Input, *m_Output);
75 break;
76 }
77 case UnaryOperation::Rsqrt:
78 {
79 RsqrtFunction(inShape, outShape, *m_Input, *m_Output);
80 break;
81 }
82 case UnaryOperation::Sqrt:
83 {
84 SqrtFunction(inShape, outShape, *m_Input, *m_Output);
85 break;
86 }
87 default:
88 {
89 throw InvalidArgumentException(std::string("Unsupported unary operation ") +
90 GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
91 }
92 }
93 }
94
95 } // namespace armnn
96