1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "WorkloadData.hpp" 8 #include "WorkloadInfo.hpp" 9 10 #include <armnn/backends/IWorkload.hpp> 11 #include <Profiling.hpp> 12 #include <ProfilingService.hpp> 13 14 #include <algorithm> 15 16 namespace armnn 17 { 18 19 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template 20 // in the various workload factories. 21 // There should never be an instantiation of a NullWorkload. 22 class NullWorkload : public IWorkload 23 { 24 NullWorkload()=delete; 25 }; 26 27 template <typename QueueDescriptor> 28 class BaseWorkload : public IWorkload 29 { 30 public: 31 BaseWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)32 BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) 33 : m_Data(descriptor), 34 m_Guid(profiling::ProfilingService::GetNextGuid()) 35 { 36 m_Data.Validate(info); 37 } 38 PostAllocationConfigure()39 void PostAllocationConfigure() override {} 40 GetData() const41 const QueueDescriptor& GetData() const { return m_Data; } 42 GetGuid() const43 profiling::ProfilingGuid GetGuid() const final { return m_Guid; } 44 45 protected: 46 const QueueDescriptor m_Data; 47 const profiling::ProfilingGuid m_Guid; 48 }; 49 50 // TypedWorkload used 51 template <typename QueueDescriptor, armnn::DataType... DataTypes> 52 class TypedWorkload : public BaseWorkload<QueueDescriptor> 53 { 54 public: 55 TypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)56 TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) 57 : BaseWorkload<QueueDescriptor>(descriptor, info) 58 { 59 std::vector<armnn::DataType> dataTypes = {DataTypes...}; 60 armnn::DataType expectedInputType; 61 62 if (!info.m_InputTensorInfos.empty()) 63 { 64 expectedInputType = info.m_InputTensorInfos.front().GetDataType(); 65 66 if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end()) 67 { 68 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type"); 69 } 70 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()), 71 info.m_InputTensorInfos.end(), 72 [&](auto it){ 73 return it.GetDataType() == expectedInputType; 74 }), 75 "Trying to create workload with incorrect type"); 76 } 77 armnn::DataType expectedOutputType; 78 79 if (!info.m_OutputTensorInfos.empty()) 80 { 81 expectedOutputType = info.m_OutputTensorInfos.front().GetDataType(); 82 83 if (!info.m_InputTensorInfos.empty()) 84 { 85 if (expectedOutputType != expectedInputType) 86 { 87 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type"); 88 } 89 } 90 else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end()) 91 { 92 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type"); 93 } 94 ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()), 95 info.m_OutputTensorInfos.end(), 96 [&](auto it){ 97 return it.GetDataType() == expectedOutputType; 98 }), 99 "Trying to create workload with incorrect type"); 100 } 101 } 102 }; 103 104 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType> 105 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor> 106 { 107 public: 108 MultiTypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)109 MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) 110 : BaseWorkload<QueueDescriptor>(descriptor, info) 111 { 112 ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(), 113 info.m_InputTensorInfos.end(), 114 [&](auto it){ 115 return it.GetDataType() == InputDataType; 116 }), 117 "Trying to create workload with incorrect type"); 118 119 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), 120 info.m_OutputTensorInfos.end(), 121 [&](auto it){ 122 return it.GetDataType() == OutputDataType; 123 }), 124 "Trying to create workload with incorrect type"); 125 } 126 }; 127 128 // FirstInputTypedWorkload used to check type of the first input 129 template <typename QueueDescriptor, armnn::DataType DataType> 130 class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor> 131 { 132 public: 133 FirstInputTypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)134 FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) 135 : BaseWorkload<QueueDescriptor>(descriptor, info) 136 { 137 if (!info.m_InputTensorInfos.empty()) 138 { 139 ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType, 140 "Trying to create workload with incorrect type"); 141 } 142 143 ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(), 144 info.m_OutputTensorInfos.end(), 145 [&](auto it){ 146 return it.GetDataType() == DataType; 147 }), 148 "Trying to create workload with incorrect type"); 149 } 150 }; 151 152 template <typename QueueDescriptor> 153 using FloatWorkload = TypedWorkload<QueueDescriptor, 154 armnn::DataType::Float16, 155 armnn::DataType::Float32>; 156 157 template <typename QueueDescriptor> 158 using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>; 159 160 template <typename QueueDescriptor> 161 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>; 162 163 template <typename QueueDescriptor> 164 using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>; 165 166 template <typename QueueDescriptor> 167 using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>; 168 169 template <typename QueueDescriptor> 170 using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor, 171 armnn::DataType::Float32, 172 armnn::DataType::Boolean>; 173 174 template <typename QueueDescriptor> 175 using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor, 176 armnn::DataType::QAsymmU8, 177 armnn::DataType::Boolean>; 178 179 template <typename QueueDescriptor> 180 using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor, 181 armnn::DataType::BFloat16, 182 armnn::DataType::Float32>; 183 184 template <typename QueueDescriptor> 185 using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor, 186 armnn::DataType::Float32, 187 armnn::DataType::BFloat16>; 188 189 template <typename QueueDescriptor> 190 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor, 191 armnn::DataType::Float16, 192 armnn::DataType::Float32>; 193 194 template <typename QueueDescriptor> 195 using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor, 196 armnn::DataType::Float32, 197 armnn::DataType::Float16>; 198 199 template <typename QueueDescriptor> 200 using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor, 201 armnn::DataType::QAsymmU8, 202 armnn::DataType::Float32>; 203 204 } //namespace armnn 205