• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "RefElementwiseUnaryWorkload.hpp"
7 
8 #include "Decoders.hpp"
9 #include "ElementwiseFunction.hpp"
10 #include "Encoders.hpp"
11 #include "RefWorkloadUtils.hpp"
12 #include "Abs.hpp"
13 #include "Exp.hpp"
14 #include "Rsqrt.hpp"
15 #include "Sqrt.hpp"
16 
17 #include <Profiling.hpp>
18 
19 #include <armnn/TypesUtils.hpp>
20 
21 #include <functional>
22 
23 namespace armnn
24 {
25 
RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor & desc,const WorkloadInfo & info)26 RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
27                                                          const WorkloadInfo& info)
28     : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
29 {}
30 
PostAllocationConfigure()31 void RefElementwiseUnaryWorkload::PostAllocationConfigure()
32 {
33     const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
34     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
35 
36     m_Input = MakeDecoder<InType>(inputInfo);
37 
38     m_Output = MakeEncoder<OutType>(outputInfo);
39 }
40 
Execute() const41 void RefElementwiseUnaryWorkload::Execute() const
42 {
43     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefElementwiseUnaryWorkload_Execute");
44 
45     const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
46     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
47 
48     const TensorShape& inShape = inputInfo.GetShape();
49     const TensorShape& outShape = outputInfo.GetShape();
50 
51     m_Input->Reset(m_Data.m_Inputs[0]->Map());
52     m_Output->Reset(m_Data.m_Outputs[0]->Map());
53 
54     using AbsFunction   = ElementwiseUnaryFunction<abs<InType>>;
55     using ExpFunction   = ElementwiseUnaryFunction<exp<InType>>;
56     using NegFunction   = ElementwiseUnaryFunction<std::negate<InType>>;
57     using RsqrtFunction = ElementwiseUnaryFunction<rsqrt<InType>>;
58     using SqrtFunction  = ElementwiseUnaryFunction<sqrt<InType>>;
59 
60     switch (m_Data.m_Parameters.m_Operation)
61     {
62         case UnaryOperation::Abs:
63         {
64             AbsFunction(inShape, outShape, *m_Input, *m_Output);
65             break;
66         }
67         case UnaryOperation::Exp:
68         {
69             ExpFunction(inShape, outShape, *m_Input, *m_Output);
70             break;
71         }
72         case UnaryOperation::Neg:
73         {
74             NegFunction(inShape, outShape, *m_Input, *m_Output);
75             break;
76         }
77         case UnaryOperation::Rsqrt:
78         {
79             RsqrtFunction(inShape, outShape, *m_Input, *m_Output);
80             break;
81         }
82         case UnaryOperation::Sqrt:
83         {
84             SqrtFunction(inShape, outShape, *m_Input, *m_Output);
85             break;
86         }
87         default:
88         {
89             throw InvalidArgumentException(std::string("Unsupported unary operation ") +
90                 GetUnaryOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
91         }
92     }
93 }
94 
95 } // namespace armnn
96