//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once

#include "IWorkload.hpp"
#include "WorkloadData.hpp"
#include "WorkloadInfo.hpp"
#include "WorkingMemDescriptor.hpp"
#include "ExecutionData.hpp"

#include <armnn/Logging.hpp>

#include <Profiling.hpp>

#include <client/include/IProfilingService.hpp>

#include <algorithm>

namespace armnn
{

// NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
// in the various workload factories.
// There should never be an instantiation of a NullWorkload.
class NullWorkload : public IWorkload
{
    NullWorkload() = delete;
};

// BaseWorkload: behaviour common to all workloads. Validates and stores the
// queue descriptor and assigns the workload a profiling GUID.
template <typename QueueDescriptor>
class BaseWorkload : public IWorkload
{
public:

    BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : m_Data(descriptor),
          m_Guid(arm::pipe::IProfilingService::GetNextGuid())
    {
        m_Data.Validate(info);
    }

    void ExecuteAsync(ExecutionData& executionData) override
    {
        ARMNN_LOG(info) << "Using default async workload execution, this will affect network performance";
#if !defined(ARMNN_DISABLE_THREADS)
        std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
#endif
        WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
        m_Data.m_Inputs = workingMemDescriptor->m_Inputs;
        m_Data.m_Outputs = workingMemDescriptor->m_Outputs;

        Execute();
    }

    void PostAllocationConfigure() override {}

    const QueueDescriptor& GetData() const { return m_Data; }

    arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }

    virtual bool SupportsTensorHandleReplacement() const override
    {
        return false;
    }

    // Replace input tensor handle with the given TensorHandle
    void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
    }

    // Replace output tensor handle with the given TensorHandle
    void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
    {
        armnn::IgnoreUnused(tensorHandle, slot);
        throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
    }

protected:
    QueueDescriptor m_Data;
    const arm::pipe::ProfilingGuid m_Guid;

private:
#if !defined(ARMNN_DISABLE_THREADS)
    std::mutex m_AsyncWorkloadMutex;
#endif
};
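// A minimal usage sketch (not part of this header; the class name below is
// hypothetical) of how a concrete workload derives from BaseWorkload<>:
// the queue descriptor carries the layer's parameters and tensor handles,
// and Execute() reads them through the protected m_Data member.
//
//     class SampleAdditionWorkload : public BaseWorkload<AdditionQueueDescriptor>
//     {
//     public:
//         using BaseWorkload<AdditionQueueDescriptor>::BaseWorkload;
//
//         void Execute() const override
//         {
//             // Access input/output tensor handles through
//             // m_Data.m_Inputs and m_Data.m_Outputs.
//         }
//     };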
// TypedWorkload: checks that every input and output tensor uses one of the
// DataTypes given in the template parameter pack.
template <typename QueueDescriptor, armnn::DataType... DataTypes>
class TypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        std::vector<armnn::DataType> dataTypes = {DataTypes...};
        armnn::DataType expectedInputType;

        if (!info.m_InputTensorInfos.empty())
        {
            expectedInputType = info.m_InputTensorInfos.front().GetDataType();

            if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
                                         info.m_InputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedInputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }

        armnn::DataType expectedOutputType;

        if (!info.m_OutputTensorInfos.empty())
        {
            expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();

            if (!info.m_InputTensorInfos.empty())
            {
                expectedInputType = info.m_InputTensorInfos.front().GetDataType();

                if (expectedOutputType != expectedInputType)
                {
                    ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
                }
            }
            else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
            {
                ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
            }
            ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
                                         info.m_OutputTensorInfos.end(),
                                         [&](auto it){
                                             return it.GetDataType() == expectedOutputType;
                                         }),
                             "Trying to create workload with incorrect type");
        }
    }
};

// MultiTypedWorkload: checks that all inputs are of InputDataType and all
// outputs are of OutputDataType.
template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
                                     info.m_InputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == InputDataType;
                                     }),
                         "Trying to create workload with incorrect type");

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == OutputDataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};
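// Illustrative sketch (hypothetical class name): a workload restricted to a
// single data type is expressed by instantiating TypedWorkload with that one
// DataType; constructing it with tensors of any other type trips the asserts
// in the TypedWorkload constructor above.
//
//     class SampleFloat32ActivationWorkload
//         : public TypedWorkload<ActivationQueueDescriptor, armnn::DataType::Float32>
//     {
//         // Execute() etc.
//     };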
// FirstInputTypedWorkload used to check type of the first input
template <typename QueueDescriptor, armnn::DataType DataType>
class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
{
public:

    FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
        : BaseWorkload<QueueDescriptor>(descriptor, info)
    {
        if (!info.m_InputTensorInfos.empty())
        {
            ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
                             "Trying to create workload with incorrect type");
        }

        ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
                                     info.m_OutputTensorInfos.end(),
                                     [&](auto it){
                                         return it.GetDataType() == DataType;
                                     }),
                         "Trying to create workload with incorrect type");
    }
};

template <typename QueueDescriptor>
using FloatWorkload = TypedWorkload<QueueDescriptor,
                                    armnn::DataType::Float16,
                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;

template <typename QueueDescriptor>
using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;

template <typename QueueDescriptor>
using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                         armnn::DataType::Float32,
                                                         armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
                                                       armnn::DataType::QAsymmU8,
                                                       armnn::DataType::Boolean>;

template <typename QueueDescriptor>
using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::BFloat16,
                                                     armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                     armnn::DataType::Float32,
                                                     armnn::DataType::BFloat16>;

template <typename QueueDescriptor>
using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float16,
                                                    armnn::DataType::Float32>;

template <typename QueueDescriptor>
using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
                                                    armnn::DataType::Float32,
                                                    armnn::DataType::Float16>;

template <typename QueueDescriptor>
using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
                                                  armnn::DataType::QAsymmU8,
                                                  armnn::DataType::Float32>;

} //namespace armnn
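// Usage sketch (hypothetical class name): conversion workloads typically derive
// from one of the MultiTypedWorkload aliases above, so that the input and output
// data types are validated at construction time:
//
//     class SampleConvertFp32ToFp16Workload
//         : public armnn::Float32ToFloat16Workload<armnn::ConvertFp32ToFp16QueueDescriptor>
//     {
//         // Execute() converts each Float32 input tensor to Float16.
//     };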