1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 #include <backendsCommon/Workload.hpp> 10 #include <backendsCommon/WorkloadData.hpp> 11 #include "BaseIterator.hpp" 12 #include "ElementwiseFunction.hpp" 13 #include "Maximum.hpp" 14 #include "Minimum.hpp" 15 #include "StringMapping.hpp" 16 17 namespace armnn 18 { 19 20 template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> 21 class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor> 22 { 23 public: 24 using InType = typename ElementwiseBinaryFunction<Functor>::InType; 25 using OutType = typename ElementwiseBinaryFunction<Functor>::OutType; 26 using BaseWorkload<ParentDescriptor>::m_Data; 27 28 RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info); 29 void PostAllocationConfigure() override; 30 void Execute() const override; 31 32 private: 33 std::unique_ptr<Decoder<InType>> m_Input0; 34 std::unique_ptr<Decoder<InType>> m_Input1; 35 std::unique_ptr<Encoder<OutType>> m_Output; 36 }; 37 38 template <typename DataType = float> 39 using RefAdditionWorkload = 40 RefElementwiseWorkload<std::plus<DataType>, 41 AdditionQueueDescriptor, 42 StringMapping::RefAdditionWorkload_Execute>; 43 44 template <typename DataType = float> 45 using RefSubtractionWorkload = 46 RefElementwiseWorkload<std::minus<DataType>, 47 SubtractionQueueDescriptor, 48 StringMapping::RefSubtractionWorkload_Execute>; 49 50 template <typename DataType = float> 51 using RefMultiplicationWorkload = 52 RefElementwiseWorkload<std::multiplies<DataType>, 53 MultiplicationQueueDescriptor, 54 StringMapping::RefMultiplicationWorkload_Execute>; 55 56 template <typename DataType = float> 57 using RefDivisionWorkload = 58 RefElementwiseWorkload<std::divides<DataType>, 59 DivisionQueueDescriptor, 60 StringMapping::RefDivisionWorkload_Execute>; 61 62 template <typename DataType = float> 63 using RefMaximumWorkload = 64 RefElementwiseWorkload<armnn::maximum<DataType>, 65 MaximumQueueDescriptor, 66 StringMapping::RefMaximumWorkload_Execute>; 67 68 template <typename DataType = float> 69 using RefMinimumWorkload = 70 RefElementwiseWorkload<armnn::minimum<DataType>, 71 MinimumQueueDescriptor, 72 StringMapping::RefMinimumWorkload_Execute>; 73 74 } // armnn 75