• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include <Layer.hpp>
6 #include <backendsCommon/CpuTensorHandle.hpp>
7 #include <backendsCommon/MemCopyWorkload.hpp>
8 #include <backendsCommon/MemImportWorkload.hpp>
9 #include <backendsCommon/MakeWorkloadHelper.hpp>
10 #include <reference/workloads/RefFillWorkload.hpp>
11 #include "RefWorkloadFactory.hpp"
12 #include "RefBackendId.hpp"
13 #include "workloads/RefWorkloads.hpp"
14 #include "RefTensorHandle.hpp"
15 
16 
17 namespace armnn
18 {
19 
20 namespace
21 {
22 static const BackendId s_Id{RefBackendId()};
23 }
24 template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
MakeWorkload(const QueueDescriptorType & descriptor,const WorkloadInfo & info) const25 std::unique_ptr<IWorkload> RefWorkloadFactory::MakeWorkload(const QueueDescriptorType& descriptor,
26                                                             const WorkloadInfo& info) const
27 {
28     return MakeWorkloadHelper<NullWorkload, F32Workload, U8Workload, NullWorkload, NullWorkload, NullWorkload>
29            (descriptor, info);
30 }
31 
32 template <DataType ArmnnType>
IsDataType(const WorkloadInfo & info)33 bool IsDataType(const WorkloadInfo& info)
34 {
35     auto checkType = [](const TensorInfo& tensorInfo) {return tensorInfo.GetDataType() == ArmnnType;};
36     auto it = std::find_if(std::begin(info.m_InputTensorInfos), std::end(info.m_InputTensorInfos), checkType);
37     if (it != std::end(info.m_InputTensorInfos))
38     {
39         return true;
40     }
41     it = std::find_if(std::begin(info.m_OutputTensorInfos), std::end(info.m_OutputTensorInfos), checkType);
42     if (it != std::end(info.m_OutputTensorInfos))
43     {
44         return true;
45     }
46     return false;
47 }
48 
IsSigned32(const WorkloadInfo & info)49 bool IsSigned32(const WorkloadInfo& info)
50 {
51     return IsDataType<DataType::Signed32>(info);
52 }
53 
IsBFloat16(const WorkloadInfo & info)54 bool IsBFloat16(const WorkloadInfo& info)
55 {
56     return IsDataType<DataType::BFloat16>(info);
57 }
58 
IsFloat16(const WorkloadInfo & info)59 bool IsFloat16(const WorkloadInfo& info)
60 {
61     return IsDataType<DataType::Float16>(info);
62 }
63 
IsQSymmS16(const WorkloadInfo & info)64 bool IsQSymmS16(const WorkloadInfo& info)
65 {
66     return IsDataType<DataType::QSymmS16>(info);
67 }
68 
IsQSymmS8(const WorkloadInfo & info)69 bool IsQSymmS8(const WorkloadInfo& info)
70 {
71     return IsDataType<DataType::QSymmS8>(info);
72 }
73 
IsQAsymmS8(const WorkloadInfo & info)74 bool IsQAsymmS8(const WorkloadInfo& info)
75 {
76     return IsDataType<DataType::QAsymmS8>(info);
77 }
78 
IsQAsymmU8(const WorkloadInfo & info)79 bool IsQAsymmU8(const WorkloadInfo& info)
80 {
81     return IsDataType<DataType::QAsymmU8>(info);
82 }
83 
RefWorkloadFactory(const std::shared_ptr<RefMemoryManager> & memoryManager)84 RefWorkloadFactory::RefWorkloadFactory(const std::shared_ptr<RefMemoryManager>& memoryManager)
85     : m_MemoryManager(memoryManager)
86 {
87 }
88 
RefWorkloadFactory()89 RefWorkloadFactory::RefWorkloadFactory()
90     : m_MemoryManager(new RefMemoryManager())
91 {
92 }
93 
GetBackendId() const94 const BackendId& RefWorkloadFactory::GetBackendId() const
95 {
96     return s_Id;
97 }
98 
IsLayerSupported(const Layer & layer,Optional<DataType> dataType,std::string & outReasonIfUnsupported)99 bool RefWorkloadFactory::IsLayerSupported(const Layer& layer,
100                                           Optional<DataType> dataType,
101                                           std::string& outReasonIfUnsupported)
102 {
103     return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported);
104 }
105 
IsLayerSupported(const IConnectableLayer & layer,Optional<DataType> dataType,std::string & outReasonIfUnsupported,const ModelOptions & modelOptions)106 bool RefWorkloadFactory::IsLayerSupported(const IConnectableLayer& layer,
107                                           Optional<DataType> dataType,
108                                           std::string& outReasonIfUnsupported,
109                                           const ModelOptions& modelOptions)
110 {
111     return IWorkloadFactory::IsLayerSupported(s_Id, layer, dataType, outReasonIfUnsupported, modelOptions);
112 }
113 
CreateTensorHandle(const TensorInfo & tensorInfo,const bool isMemoryManaged) const114 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
115                                                                       const bool isMemoryManaged) const
116 {
117     // For Ref it is okay to make the TensorHandle memory managed as it can also store a pointer
118     // to unmanaged memory. This also ensures memory alignment.
119     IgnoreUnused(isMemoryManaged);
120     return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
121 }
122 
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool isMemoryManaged) const123 std::unique_ptr<ITensorHandle> RefWorkloadFactory::CreateTensorHandle(const TensorInfo& tensorInfo,
124                                                                       DataLayout dataLayout,
125                                                                       const bool isMemoryManaged) const
126 {
127     // For Ref it is okay to make the TensorHandle memory managed as it can also store a pointer
128     // to unmanaged memory. This also ensures memory alignment.
129     IgnoreUnused(isMemoryManaged, dataLayout);
130     return std::make_unique<RefTensorHandle>(tensorInfo, m_MemoryManager);
131 }
132 
CreateAbs(const AbsQueueDescriptor & descriptor,const WorkloadInfo & info) const133 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
134                                                          const WorkloadInfo& info) const
135 {
136     IgnoreUnused(descriptor);
137     ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
138     elementwiseUnaryDescriptor.m_Parameters.m_Operation = UnaryOperation::Abs;
139 
140     return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
141 }
142 
CreateActivation(const ActivationQueueDescriptor & descriptor,const WorkloadInfo & info) const143 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
144                                                                 const WorkloadInfo& info) const
145 {
146     return std::make_unique<RefActivationWorkload>(descriptor, info);
147 }
148 
CreateAddition(const AdditionQueueDescriptor & descriptor,const WorkloadInfo & info) const149 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
150                                                               const WorkloadInfo& info) const
151 {
152     if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
153     {
154         return std::make_unique<RefAdditionWorkload<int32_t>>(descriptor, info);
155     }
156     else
157     {
158         return std::make_unique<RefAdditionWorkload<float>>(descriptor, info);
159     }
160 }
161 
CreateArgMinMax(const ArgMinMaxQueueDescriptor & descriptor,const WorkloadInfo & info) const162 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
163                                                                const WorkloadInfo& info) const
164 {
165     return std::make_unique<RefArgMinMaxWorkload>(descriptor, info);
166 }
167 
CreateBatchNormalization(const BatchNormalizationQueueDescriptor & descriptor,const WorkloadInfo & info) const168 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchNormalization(
169     const BatchNormalizationQueueDescriptor& descriptor,
170     const WorkloadInfo& info) const
171 {
172     return std::make_unique<RefBatchNormalizationWorkload>(descriptor, info);
173 }
174 
CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor & descriptor,const WorkloadInfo & info) const175 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
176                                                                     const WorkloadInfo& info) const
177 {
178     return std::make_unique<RefBatchToSpaceNdWorkload>(descriptor, info);
179 }
180 
CreateComparison(const ComparisonQueueDescriptor & descriptor,const WorkloadInfo & info) const181 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
182                                                                 const WorkloadInfo& info) const
183 {
184     return std::make_unique<RefComparisonWorkload>(descriptor, info);
185 }
186 
CreateConcat(const ConcatQueueDescriptor & descriptor,const WorkloadInfo & info) const187 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
188                                                             const WorkloadInfo& info) const
189 {
190     return std::make_unique<RefConcatWorkload>(descriptor, info);
191 }
192 
CreateConstant(const ConstantQueueDescriptor & descriptor,const WorkloadInfo & info) const193 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
194                                                               const WorkloadInfo& info) const
195 {
196     return std::make_unique<RefConstantWorkload>(descriptor, info);
197 }
198 
CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor & descriptor,const WorkloadInfo & info) const199 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertBf16ToFp32(
200     const ConvertBf16ToFp32QueueDescriptor& descriptor,
201     const WorkloadInfo& info) const
202 {
203     return std::make_unique<RefConvertBf16ToFp32Workload>(descriptor, info);
204 }
205 
CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor & descriptor,const WorkloadInfo & info) const206 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp16ToFp32(
207     const ConvertFp16ToFp32QueueDescriptor& descriptor,
208     const WorkloadInfo& info) const
209 {
210     return std::make_unique<RefConvertFp16ToFp32Workload>(descriptor, info);
211 }
212 
CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor & descriptor,const WorkloadInfo & info) const213 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToBf16(
214     const ConvertFp32ToBf16QueueDescriptor& descriptor,
215     const WorkloadInfo& info) const
216 {
217     return std::make_unique<RefConvertFp32ToBf16Workload>(descriptor, info);
218 }
219 
CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor & descriptor,const WorkloadInfo & info) const220 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16(
221     const ConvertFp32ToFp16QueueDescriptor& descriptor,
222     const WorkloadInfo& info) const
223 {
224     return std::make_unique<RefConvertFp32ToFp16Workload>(descriptor, info);
225 }
226 
CreateConvolution2d(const Convolution2dQueueDescriptor & descriptor,const WorkloadInfo & info) const227 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
228                                                                    const WorkloadInfo& info) const
229 {
230     return std::make_unique<RefConvolution2dWorkload>(descriptor, info);
231 }
232 
CreateDebug(const DebugQueueDescriptor & descriptor,const WorkloadInfo & info) const233 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
234                                                            const WorkloadInfo& info) const
235 {
236     if (IsBFloat16(info))
237     {
238         return std::make_unique<RefDebugBFloat16Workload>(descriptor, info);
239     }
240     if (IsFloat16(info))
241     {
242         return std::make_unique<RefDebugFloat16Workload>(descriptor, info);
243     }
244     if (IsQSymmS16(info))
245     {
246         return std::make_unique<RefDebugQSymmS16Workload>(descriptor, info);
247     }
248     if (IsQSymmS8(info))
249     {
250         return std::make_unique<RefDebugQSymmS8Workload>(descriptor, info);
251     }
252     if (IsQAsymmU8(info))
253     {
254         return std::make_unique<RefDebugQAsymmU8Workload>(descriptor, info);
255     }
256     if (IsQAsymmS8(info))
257     {
258         return std::make_unique<RefDebugQAsymmS8Workload>(descriptor, info);
259     }
260     if (IsSigned32(info))
261     {
262         return std::make_unique<RefDebugSigned32Workload>(descriptor, info);
263     }
264 
265     return MakeWorkload<RefDebugFloat32Workload, RefDebugQAsymmU8Workload>(descriptor, info);
266 }
267 
CreateDepthToSpace(const DepthToSpaceQueueDescriptor & descriptor,const WorkloadInfo & info) const268 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
269                                                                   const WorkloadInfo& info) const
270 {
271     return std::make_unique<RefDepthToSpaceWorkload>(descriptor, info);
272 }
273 
CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor & descriptor,const WorkloadInfo & info) const274 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDepthwiseConvolution2d(
275     const DepthwiseConvolution2dQueueDescriptor& descriptor,
276     const WorkloadInfo& info) const
277 {
278     return std::make_unique<RefDepthwiseConvolution2dWorkload>(descriptor, info);
279 }
280 
CreateDequantize(const DequantizeQueueDescriptor & descriptor,const WorkloadInfo & info) const281 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDequantize(const DequantizeQueueDescriptor& descriptor,
282                                                                 const WorkloadInfo& info) const
283 {
284     return std::make_unique<RefDequantizeWorkload>(descriptor, info);
285 }
286 
CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor & descriptor,const WorkloadInfo & info) const287 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDetectionPostProcess(
288     const DetectionPostProcessQueueDescriptor& descriptor,
289     const WorkloadInfo& info) const
290 {
291     return std::make_unique<RefDetectionPostProcessWorkload>(descriptor, info);
292 }
293 
CreateDivision(const DivisionQueueDescriptor & descriptor,const WorkloadInfo & info) const294 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
295                                                               const WorkloadInfo& info) const
296 {
297     if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
298     {
299         return std::make_unique<RefDivisionWorkload<int32_t>>(descriptor, info);
300     }
301     else
302     {
303         return std::make_unique<RefDivisionWorkload<float>>(descriptor, info);
304     }
305 }
306 
CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor & descriptor,const WorkloadInfo & info) const307 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
308                                                                       const WorkloadInfo& info) const
309 {
310     if (descriptor.m_Parameters.m_Operation == UnaryOperation::LogicalNot)
311     {
312         return std::make_unique<RefLogicalUnaryWorkload>(descriptor, info);
313     }
314     return std::make_unique<RefElementwiseUnaryWorkload>(descriptor, info);
315 }
316 
CreateEqual(const EqualQueueDescriptor & descriptor,const WorkloadInfo & info) const317 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
318                                                            const WorkloadInfo& info) const
319 {
320     IgnoreUnused(descriptor);
321     ComparisonQueueDescriptor comparisonDescriptor;
322     comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Equal;
323 
324     return CreateComparison(comparisonDescriptor, info);
325 }
326 
CreateFakeQuantization(const FakeQuantizationQueueDescriptor & descriptor,const WorkloadInfo & info) const327 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
328                                                                       const WorkloadInfo& info) const
329 {
330     return MakeWorkload<RefFakeQuantizationFloat32Workload, NullWorkload>(descriptor, info);
331 }
332 
CreateFill(const FillQueueDescriptor & descriptor,const WorkloadInfo & info) const333 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFill(const FillQueueDescriptor& descriptor,
334                                                           const WorkloadInfo& info) const
335 {
336     return std::make_unique<RefFillWorkload>(descriptor, info);
337 }
338 
CreateFloor(const FloorQueueDescriptor & descriptor,const WorkloadInfo & info) const339 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
340                                                            const WorkloadInfo& info) const
341 {
342     if(IsQuantizedType(info.m_InputTensorInfos[0].GetDataType()))
343     {
344         return nullptr;
345     }
346     else
347     {
348         return std::make_unique<RefFloorWorkload>(descriptor, info);
349     }
350 }
351 
CreateFullyConnected(const FullyConnectedQueueDescriptor & descriptor,const WorkloadInfo & info) const352 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFullyConnected(
353     const FullyConnectedQueueDescriptor& descriptor,
354     const WorkloadInfo& info) const
355 {
356     return std::make_unique<RefFullyConnectedWorkload>(descriptor, info);
357 }
358 
CreateGather(const GatherQueueDescriptor & descriptor,const WorkloadInfo & info) const359 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
360                                                             const WorkloadInfo& info) const
361 {
362     return std::make_unique<RefGatherWorkload>(descriptor, info);
363 }
364 
CreateGreater(const GreaterQueueDescriptor & descriptor,const WorkloadInfo & info) const365 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
366                                                              const WorkloadInfo& info) const
367 {
368     IgnoreUnused(descriptor);
369     ComparisonQueueDescriptor comparisonDescriptor;
370     comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Greater;
371 
372     return CreateComparison(comparisonDescriptor, info);
373 }
374 
CreateInput(const InputQueueDescriptor & descriptor,const WorkloadInfo & info) const375 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
376                                                            const WorkloadInfo& info) const
377 {
378     if (info.m_InputTensorInfos.empty() )
379     {
380         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Input cannot be zero length");
381     }
382     if (info.m_OutputTensorInfos.empty())
383     {
384         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: Output cannot be zero length");
385     }
386 
387     if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
388     {
389         throw InvalidArgumentException("RefWorkloadFactory::CreateInput: data input and output differ in byte count.");
390     }
391 
392     return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
393 }
394 
CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor & descriptor,const WorkloadInfo & info) const395 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInstanceNormalization(
396     const InstanceNormalizationQueueDescriptor& descriptor,
397     const WorkloadInfo& info) const
398 {
399     return std::make_unique<RefInstanceNormalizationWorkload>(descriptor, info);
400 }
401 
CreateL2Normalization(const L2NormalizationQueueDescriptor & descriptor,const WorkloadInfo & info) const402 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
403                                                                      const WorkloadInfo& info) const
404 {
405     return std::make_unique<RefL2NormalizationWorkload>(descriptor, info);
406 }
407 
CreateLogicalBinary(const LogicalBinaryQueueDescriptor & descriptor,const WorkloadInfo & info) const408 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
409                                                                    const WorkloadInfo& info) const
410 {
411     return std::make_unique<RefLogicalBinaryWorkload>(descriptor, info);
412 }
413 
CreateLogSoftmax(const LogSoftmaxQueueDescriptor & descriptor,const WorkloadInfo & info) const414 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
415                                                                 const WorkloadInfo& info) const
416 {
417     return std::make_unique<RefLogSoftmaxWorkload>(descriptor, info);
418 }
419 
CreateLstm(const LstmQueueDescriptor & descriptor,const WorkloadInfo & info) const420 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
421                                                           const WorkloadInfo& info) const
422 {
423     return std::make_unique<RefLstmWorkload>(descriptor, info);
424 }
425 
CreateMaximum(const MaximumQueueDescriptor & descriptor,const WorkloadInfo & info) const426 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
427                                                              const WorkloadInfo& info) const
428 {
429     if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
430     {
431         return std::make_unique<RefMaximumWorkload<int32_t>>(descriptor, info);
432     }
433     else
434     {
435         return std::make_unique<RefMaximumWorkload<float>>(descriptor, info);
436     }
437 }
438 
CreateMean(const MeanQueueDescriptor & descriptor,const WorkloadInfo & info) const439 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
440                                                           const WorkloadInfo& info) const
441 {
442     return  std::make_unique<RefMeanWorkload>(descriptor, info);
443 }
444 
CreateMemCopy(const MemCopyQueueDescriptor & descriptor,const WorkloadInfo & info) const445 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
446                                                              const WorkloadInfo& info) const
447 {
448     if (descriptor.m_Inputs.empty())
449     {
450         throw InvalidArgumentException("RefWorkloadFactory: CreateMemCopy() expected an input tensor.");
451     }
452     return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
453 }
454 
CreateMemImport(const MemImportQueueDescriptor & descriptor,const WorkloadInfo & info) const455 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
456                                                                const WorkloadInfo& info) const
457 {
458     if (descriptor.m_Inputs.empty())
459     {
460         throw InvalidArgumentException("RefWorkloadFactory: CreateMemImport() expected an input tensor.");
461     }
462     return std::make_unique<ImportMemGenericWorkload>(descriptor, info);
463 }
464 
CreateMerger(const MergerQueueDescriptor & descriptor,const WorkloadInfo & info) const465 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
466                                                             const WorkloadInfo& info) const
467 {
468     return CreateConcat(descriptor, info);
469 }
470 
CreateMinimum(const MinimumQueueDescriptor & descriptor,const WorkloadInfo & info) const471 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
472                                                              const WorkloadInfo& info) const
473 {
474     if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
475     {
476         return std::make_unique<RefMinimumWorkload<int32_t>>(descriptor, info);
477     }
478     else
479     {
480         return std::make_unique<RefMinimumWorkload<float>>(descriptor, info);
481     }
482 }
483 
CreateMultiplication(const MultiplicationQueueDescriptor & descriptor,const WorkloadInfo & info) const484 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
485                                                                     const WorkloadInfo& info) const
486 {
487     if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
488     {
489         return std::make_unique<RefMultiplicationWorkload<int32_t>>(descriptor, info);
490     }
491     else
492     {
493         return std::make_unique<RefMultiplicationWorkload<float>>(descriptor, info);
494     }
495 }
496 
CreateNormalization(const NormalizationQueueDescriptor & descriptor,const WorkloadInfo & info) const497 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
498                                                                    const WorkloadInfo& info) const
499 {
500     return std::make_unique<RefNormalizationWorkload>(descriptor, info);
501 }
502 
CreateOutput(const OutputQueueDescriptor & descriptor,const WorkloadInfo & info) const503 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
504                                                             const WorkloadInfo& info) const
505 {
506     if (info.m_InputTensorInfos.empty() )
507     {
508         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Input cannot be zero length");
509     }
510     if (info.m_OutputTensorInfos.empty())
511     {
512         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: Output cannot be zero length");
513     }
514     if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
515     {
516         throw InvalidArgumentException("RefWorkloadFactory::CreateOutput: data input and output differ in byte count.");
517     }
518 
519     return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
520 }
521 
CreatePad(const PadQueueDescriptor & descriptor,const WorkloadInfo & info) const522 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
523                                                          const WorkloadInfo& info) const
524 {
525     return std::make_unique<RefPadWorkload>(descriptor, info);
526 }
527 
CreatePermute(const PermuteQueueDescriptor & descriptor,const WorkloadInfo & info) const528 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
529                                                              const WorkloadInfo& info) const
530 {
531     if (IsQSymmS16(info))
532     {
533         return std::make_unique<RefPermuteQSymm16Workload>(descriptor, info);
534     }
535     else if (IsBFloat16(info))
536     {
537         return std::make_unique<RefPermuteBFloat16Workload>(descriptor, info);
538     }
539     else if (IsQAsymmS8(info))
540     {
541         return std::make_unique<RefPermuteQAsymmS8Workload>(descriptor, info);
542     }
543     return MakeWorkloadHelper<RefPermuteFloat16Workload, RefPermuteFloat32Workload, RefPermuteQAsymm8Workload,
544         NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
545 }
546 
CreatePooling2d(const Pooling2dQueueDescriptor & descriptor,const WorkloadInfo & info) const547 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
548                                                                const WorkloadInfo& info) const
549 {
550     return std::make_unique<RefPooling2dWorkload>(descriptor, info);
551 }
552 
CreatePreCompiled(const PreCompiledQueueDescriptor &,const WorkloadInfo &) const553 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
554                                                                  const WorkloadInfo& /*info*/) const
555 {
556     return nullptr;
557 }
558 
CreatePrelu(const PreluQueueDescriptor & descriptor,const WorkloadInfo & info) const559 std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePrelu(const PreluQueueDescriptor& descriptor,
560                                                            const WorkloadInfo& info) const
561 {
562     return std::make_unique<RefPreluWorkload>(descriptor, info);
563 }
564 
CreateQLstm(const QLstmQueueDescriptor & descriptor,const WorkloadInfo & info) const565 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& descriptor,
566                                                            const WorkloadInfo& info) const
567 {
568     return std::make_unique<RefQLstmWorkload>(descriptor, info);
569 }
570 
CreateQuantize(const QuantizeQueueDescriptor & descriptor,const WorkloadInfo & info) const571 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
572                                                               const WorkloadInfo& info) const
573 {
574     return std::make_unique<RefQuantizeWorkload>(descriptor, info);
575 }
576 
CreateRank(const RankQueueDescriptor & descriptor,const WorkloadInfo & info) const577 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor,
578                                                           const WorkloadInfo& info) const
579 {
580     return std::make_unique<RefRankWorkload>(descriptor, info);
581 }
582 
CreateReshape(const ReshapeQueueDescriptor & descriptor,const WorkloadInfo & info) const583 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
584                                                              const WorkloadInfo& info) const
585 {
586     return std::make_unique<RefReshapeWorkload>(descriptor, info);
587 }
588 
CreateResize(const ResizeQueueDescriptor & descriptor,const WorkloadInfo & info) const589 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
590                                                             const WorkloadInfo& info) const
591 {
592     return std::make_unique<RefResizeWorkload>(descriptor, info);
593 }
594 
CreateResizeBilinear(const ResizeBilinearQueueDescriptor & descriptor,const WorkloadInfo & info) const595 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
596                                                                     const WorkloadInfo& info) const
597 {
598     ResizeQueueDescriptor resizeDescriptor;
599     resizeDescriptor.m_Parameters.m_Method       = ResizeMethod::Bilinear;
600     resizeDescriptor.m_Parameters.m_DataLayout   = descriptor.m_Parameters.m_DataLayout;
601     resizeDescriptor.m_Parameters.m_TargetWidth  = descriptor.m_Parameters.m_TargetWidth;
602     resizeDescriptor.m_Parameters.m_TargetHeight = descriptor.m_Parameters.m_TargetHeight;
603 
604     return CreateResize(resizeDescriptor, info);
605 }
606 
CreateRsqrt(const RsqrtQueueDescriptor & descriptor,const WorkloadInfo & info) const607 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
608                                                            const WorkloadInfo& info) const
609 {
610     IgnoreUnused(descriptor);
611     ElementwiseUnaryQueueDescriptor elementwiseUnaryDescriptor;
612     elementwiseUnaryDescriptor.m_Parameters.m_Operation = UnaryOperation::Rsqrt;
613 
614     return CreateElementwiseUnary(elementwiseUnaryDescriptor, info);
615 }
616 
CreateSlice(const SliceQueueDescriptor & descriptor,const WorkloadInfo & info) const617 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSlice(const SliceQueueDescriptor& descriptor,
618                                                            const WorkloadInfo& info) const
619 {
620     return std::make_unique<RefSliceWorkload>(descriptor, info);
621 }
622 
CreateSoftmax(const SoftmaxQueueDescriptor & descriptor,const WorkloadInfo & info) const623 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
624                                                              const WorkloadInfo& info) const
625 {
626     return std::make_unique<RefSoftmaxWorkload>(descriptor, info);
627 }
628 
CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor & descriptor,const WorkloadInfo & info) const629 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
630                                                                     const WorkloadInfo& info) const
631 {
632     return std::make_unique<RefSpaceToBatchNdWorkload>(descriptor, info);
633 }
634 
CreateSpaceToDepth(const SpaceToDepthQueueDescriptor & descriptor,const WorkloadInfo & info) const635 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
636                                                                   const WorkloadInfo& info) const
637 {
638     return std::make_unique<RefSpaceToDepthWorkload>(descriptor, info);
639 }
640 
CreateSplitter(const SplitterQueueDescriptor & descriptor,const WorkloadInfo & info) const641 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
642                                                               const WorkloadInfo& info) const
643 {
644     return std::make_unique<RefSplitterWorkload>(descriptor, info);
645 }
646 
CreateStack(const StackQueueDescriptor & descriptor,const WorkloadInfo & info) const647 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
648                                                            const WorkloadInfo& info) const
649 {
650     return std::make_unique<RefStackWorkload>(descriptor, info);
651 }
652 
CreateStridedSlice(const StridedSliceQueueDescriptor & descriptor,const WorkloadInfo & info) const653 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
654                                                                   const WorkloadInfo& info) const
655 {
656     return std::make_unique<RefStridedSliceWorkload>(descriptor, info);
657 }
658 
CreateSubtraction(const SubtractionQueueDescriptor & descriptor,const WorkloadInfo & info) const659 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
660                                                                  const WorkloadInfo& info) const
661 {
662     if (info.m_InputTensorInfos[0].GetDataType() == armnn::DataType::Signed32)
663     {
664         return std::make_unique<RefSubtractionWorkload<int32_t>>(descriptor, info);
665     }
666     else
667     {
668         return std::make_unique<RefSubtractionWorkload<float>>(descriptor, info);
669     }
670 }
671 
CreateTranspose(const TransposeQueueDescriptor & descriptor,const WorkloadInfo & info) const672 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& descriptor,
673                                                                const WorkloadInfo& info) const
674 {
675     if (IsQSymmS16(info))
676     {
677         return std::make_unique<RefTransposeQSymm16Workload>(descriptor, info);
678     }
679     else if (IsBFloat16(info))
680     {
681         return std::make_unique<RefTransposeBFloat16Workload>(descriptor, info);
682     }
683     else if (IsQAsymmS8(info))
684     {
685         return std::make_unique<RefTransposeQAsymmS8Workload>(descriptor, info);
686     }
687     return MakeWorkloadHelper<RefTransposeFloat16Workload, RefTransposeFloat32Workload, RefTransposeQAsymm8Workload,
688             NullWorkload, NullWorkload, NullWorkload>(descriptor, info);
689 }
690 
CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor & descriptor,const WorkloadInfo & info) const691 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateTransposeConvolution2d(
692     const TransposeConvolution2dQueueDescriptor& descriptor,
693     const WorkloadInfo& info) const
694 {
695     return std::make_unique<RefTransposeConvolution2dWorkload>(descriptor, info);
696 }
697 
698 } // namespace armnn
699