• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "IWorkload.hpp"
8 #include "WorkloadData.hpp"
9 #include "WorkloadInfo.hpp"
10 #include "WorkingMemDescriptor.hpp"
11 #include "ExecutionData.hpp"
12 
13 #include <armnn/Logging.hpp>
14 
15 #include <Profiling.hpp>
16 
17 #include <client/include/IProfilingService.hpp>
18 
19 #include <algorithm>
20 
21 namespace armnn
22 {
23 
24 // NullWorkload used to denote an unsupported workload when used by the MakeWorkload<> template
25 // in the various workload factories.
26 // There should never be an instantiation of a NullWorkload.
27 class NullWorkload : public IWorkload
28 {
29     NullWorkload()=delete;
30 };
31 
32 template <typename QueueDescriptor>
33 class BaseWorkload : public IWorkload
34 {
35 public:
36 
BaseWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)37     BaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
38         : m_Data(descriptor),
39           m_Guid(arm::pipe::IProfilingService::GetNextGuid())
40     {
41         m_Data.Validate(info);
42     }
43 
ExecuteAsync(ExecutionData & executionData)44     void ExecuteAsync(ExecutionData& executionData) override
45     {
46         ARMNN_LOG(info) << "Using default async workload execution, this will network affect performance";
47 #if !defined(ARMNN_DISABLE_THREADS)
48         std::lock_guard<std::mutex> lockGuard(m_AsyncWorkloadMutex);
49 #endif
50         WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
51         m_Data.m_Inputs = workingMemDescriptor->m_Inputs;
52         m_Data.m_Outputs = workingMemDescriptor->m_Outputs;
53 
54         Execute();
55     };
56 
PostAllocationConfigure()57     void PostAllocationConfigure() override {}
58 
GetData() const59     const QueueDescriptor& GetData() const { return m_Data; }
60 
GetGuid() const61     arm::pipe::ProfilingGuid GetGuid() const final { return m_Guid; }
62 
SupportsTensorHandleReplacement() const63     virtual bool SupportsTensorHandleReplacement()  const override
64     {
65         return false;
66     }
67 
68     // Replace input tensor handle with the given TensorHandle
ReplaceInputTensorHandle(ITensorHandle * tensorHandle,unsigned int slot)69     void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
70     {
71         armnn::IgnoreUnused(tensorHandle, slot);
72         throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
73     }
74 
75     // Replace output tensor handle with the given TensorHandle
ReplaceOutputTensorHandle(ITensorHandle * tensorHandle,unsigned int slot)76     void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
77     {
78         armnn::IgnoreUnused(tensorHandle, slot);
79         throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
80     }
81 
82 protected:
83     QueueDescriptor m_Data;
84     const arm::pipe::ProfilingGuid m_Guid;
85 
86 private:
87 #if !defined(ARMNN_DISABLE_THREADS)
88     std::mutex m_AsyncWorkloadMutex;
89 #endif
90 };
91 
92 // TypedWorkload used
93 template <typename QueueDescriptor, armnn::DataType... DataTypes>
94 class TypedWorkload : public BaseWorkload<QueueDescriptor>
95 {
96 public:
97 
TypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)98     TypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
99         : BaseWorkload<QueueDescriptor>(descriptor, info)
100     {
101         std::vector<armnn::DataType> dataTypes = {DataTypes...};
102         armnn::DataType expectedInputType;
103 
104         if (!info.m_InputTensorInfos.empty())
105         {
106             expectedInputType = info.m_InputTensorInfos.front().GetDataType();
107 
108             if (std::find(dataTypes.begin(), dataTypes.end(), expectedInputType) == dataTypes.end())
109             {
110                 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
111             }
112             ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_InputTensorInfos.begin()),
113                                          info.m_InputTensorInfos.end(),
114                                          [&](auto it){
115                                              return it.GetDataType() == expectedInputType;
116                                          }),
117                              "Trying to create workload with incorrect type");
118         }
119         armnn::DataType expectedOutputType;
120 
121         if (!info.m_OutputTensorInfos.empty())
122         {
123             expectedOutputType = info.m_OutputTensorInfos.front().GetDataType();
124 
125             if (!info.m_InputTensorInfos.empty())
126             {
127                 expectedInputType = info.m_InputTensorInfos.front().GetDataType();
128 
129                 if (expectedOutputType != expectedInputType)
130                 {
131                     ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
132                 }
133             }
134             else if (std::find(dataTypes.begin(), dataTypes.end(), expectedOutputType) == dataTypes.end())
135             {
136                 ARMNN_ASSERT_MSG(false, "Trying to create workload with incorrect type");
137             }
138             ARMNN_ASSERT_MSG(std::all_of(std::next(info.m_OutputTensorInfos.begin()),
139                                          info.m_OutputTensorInfos.end(),
140                                          [&](auto it){
141                                              return it.GetDataType() == expectedOutputType;
142                                          }),
143                              "Trying to create workload with incorrect type");
144         }
145     }
146 };
147 
148 template <typename QueueDescriptor, armnn::DataType InputDataType, armnn::DataType OutputDataType>
149 class MultiTypedWorkload : public BaseWorkload<QueueDescriptor>
150 {
151 public:
152 
MultiTypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)153     MultiTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
154         : BaseWorkload<QueueDescriptor>(descriptor, info)
155     {
156         ARMNN_ASSERT_MSG(std::all_of(info.m_InputTensorInfos.begin(),
157                                      info.m_InputTensorInfos.end(),
158                                      [&](auto it){
159                                          return it.GetDataType() == InputDataType;
160                                      }),
161                          "Trying to create workload with incorrect type");
162 
163         ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
164                                      info.m_OutputTensorInfos.end(),
165                                      [&](auto it){
166                                          return it.GetDataType() == OutputDataType;
167                                      }),
168                          "Trying to create workload with incorrect type");
169     }
170 };
171 
172 // FirstInputTypedWorkload used to check type of the first input
173 template <typename QueueDescriptor, armnn::DataType DataType>
174 class FirstInputTypedWorkload : public BaseWorkload<QueueDescriptor>
175 {
176 public:
177 
FirstInputTypedWorkload(const QueueDescriptor & descriptor,const WorkloadInfo & info)178     FirstInputTypedWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
179         : BaseWorkload<QueueDescriptor>(descriptor, info)
180     {
181         if (!info.m_InputTensorInfos.empty())
182         {
183             ARMNN_ASSERT_MSG(info.m_InputTensorInfos.front().GetDataType() == DataType,
184                                  "Trying to create workload with incorrect type");
185         }
186 
187         ARMNN_ASSERT_MSG(std::all_of(info.m_OutputTensorInfos.begin(),
188                                      info.m_OutputTensorInfos.end(),
189                                      [&](auto it){
190                                          return it.GetDataType() == DataType;
191                                      }),
192                          "Trying to create workload with incorrect type");
193     }
194 };
195 
196 template <typename QueueDescriptor>
197 using FloatWorkload = TypedWorkload<QueueDescriptor,
198                                     armnn::DataType::Float16,
199                                     armnn::DataType::Float32>;
200 
201 template <typename QueueDescriptor>
202 using Float32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Float32>;
203 
204 template <typename QueueDescriptor>
205 using Uint8Workload = TypedWorkload<QueueDescriptor, armnn::DataType::QAsymmU8>;
206 
207 template <typename QueueDescriptor>
208 using Int32Workload = TypedWorkload<QueueDescriptor, armnn::DataType::Signed32>;
209 
210 template <typename QueueDescriptor>
211 using BooleanWorkload = TypedWorkload<QueueDescriptor, armnn::DataType::Boolean>;
212 
213 template <typename QueueDescriptor>
214 using BaseFloat32ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
215                                                          armnn::DataType::Float32,
216                                                          armnn::DataType::Boolean>;
217 
218 template <typename QueueDescriptor>
219 using BaseUint8ComparisonWorkload = MultiTypedWorkload<QueueDescriptor,
220                                                        armnn::DataType::QAsymmU8,
221                                                        armnn::DataType::Boolean>;
222 
223 template <typename QueueDescriptor>
224 using BFloat16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
225                                                      armnn::DataType::BFloat16,
226                                                      armnn::DataType::Float32>;
227 
228 template <typename QueueDescriptor>
229 using Float32ToBFloat16Workload = MultiTypedWorkload<QueueDescriptor,
230                                                      armnn::DataType::Float32,
231                                                      armnn::DataType::BFloat16>;
232 
233 template <typename QueueDescriptor>
234 using Float16ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
235                                                     armnn::DataType::Float16,
236                                                     armnn::DataType::Float32>;
237 
238 template <typename QueueDescriptor>
239 using Float32ToFloat16Workload = MultiTypedWorkload<QueueDescriptor,
240                                                     armnn::DataType::Float32,
241                                                     armnn::DataType::Float16>;
242 
243 template <typename QueueDescriptor>
244 using Uint8ToFloat32Workload = MultiTypedWorkload<QueueDescriptor,
245                                                   armnn::DataType::QAsymmU8,
246                                                   armnn::DataType::Float32>;
247 
248 } //namespace armnn
249