1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/Types.hpp> 9 #include "RefBaseWorkload.hpp" 10 #include <armnn/backends/WorkloadData.hpp> 11 #include "BaseIterator.hpp" 12 #include "ElementwiseFunction.hpp" 13 #include "Maximum.hpp" 14 #include "Minimum.hpp" 15 #include "StringMapping.hpp" 16 17 namespace armnn 18 { 19 20 template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> 21 class RefElementwiseWorkload : public RefBaseWorkload<ParentDescriptor> 22 { 23 public: 24 RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info); 25 void Execute() const override; 26 void ExecuteAsync(ExecutionData& executionData) override; 27 28 private: 29 using InType = typename ElementwiseBinaryFunction<Functor>::InType; 30 using OutType = typename ElementwiseBinaryFunction<Functor>::OutType; 31 using RefBaseWorkload<ParentDescriptor>::m_Data; 32 33 void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const; 34 }; 35 36 template <typename DataType = float> 37 using RefAdditionWorkload = 38 RefElementwiseWorkload<std::plus<DataType>, 39 AdditionQueueDescriptor, 40 StringMapping::RefAdditionWorkload_Execute>; 41 42 template <typename DataType = float> 43 using RefSubtractionWorkload = 44 RefElementwiseWorkload<std::minus<DataType>, 45 SubtractionQueueDescriptor, 46 StringMapping::RefSubtractionWorkload_Execute>; 47 48 template <typename DataType = float> 49 using RefMultiplicationWorkload = 50 RefElementwiseWorkload<std::multiplies<DataType>, 51 MultiplicationQueueDescriptor, 52 StringMapping::RefMultiplicationWorkload_Execute>; 53 54 template <typename DataType = float> 55 using RefDivisionWorkload = 56 RefElementwiseWorkload<std::divides<DataType>, 57 DivisionQueueDescriptor, 58 StringMapping::RefDivisionWorkload_Execute>; 59 60 template <typename DataType = float> 61 using RefMaximumWorkload = 62 RefElementwiseWorkload<armnn::maximum<DataType>, 63 MaximumQueueDescriptor, 64 StringMapping::RefMaximumWorkload_Execute>; 65 66 template <typename DataType = float> 67 using RefMinimumWorkload = 68 RefElementwiseWorkload<armnn::minimum<DataType>, 69 MinimumQueueDescriptor, 70 StringMapping::RefMinimumWorkload_Execute>; 71 72 } // armnn 73