• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
#include <armnn/Optional.hpp>
#include <armnn/backends/IBackendInternal.hpp>

#include <backendsCommon/WorkloadFactoryBase.hpp>
#include <aclCommon/BaseMemoryManager.hpp>
#include <armnn/utility/IgnoreUnused.hpp>

#include <memory>
#include <string>
13 
14 namespace armnn
15 {
16 
17 // Neon workload factory.
18 class NeonWorkloadFactory : public WorkloadFactoryBase
19 {
20 public:
21     NeonWorkloadFactory(const std::shared_ptr<NeonMemoryManager>& memoryManager);
22 
23     NeonWorkloadFactory(const std::shared_ptr<NeonMemoryManager>& memoryManager,
24                         const IBackendInternal::IBackendSpecificModelContextPtr& modelContextPtr);
25 
26     const BackendId& GetBackendId() const override;
27 
28     static bool IsLayerSupported(const Layer& layer,
29                                  Optional<DataType> dataType,
30                                  std::string& outReasonIfUnsupported);
31 
32     static bool IsLayerSupported(const IConnectableLayer& layer,
33                                  Optional<DataType> dataType,
34                                  std::string& outReasonIfUnsupported,
35                                  const ModelOptions& modelOptions);
36 
SupportsSubTensors() const37     bool SupportsSubTensors() const override { return true; }
38 
39     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
40     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
41                                                          TensorShape const& subTensorShape,
42                                                          unsigned int const* subTensorOrigin) const override;
43 
44     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
45     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
46                                                       const bool IsMemoryManaged = true) const override;
47 
48     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
49     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
50                                                       DataLayout dataLayout,
51                                                       const bool IsMemoryManaged = true) const override;
52 
53     ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead")
54     std::unique_ptr<IWorkload> CreateAbs(const AbsQueueDescriptor& descriptor,
55                                          const WorkloadInfo& info) const override;
56 
57     std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor,
58                                                 const WorkloadInfo& info) const override;
59 
60     std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
61                                               const WorkloadInfo& info) const override;
62 
63     std::unique_ptr<IWorkload> CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
64                                                const WorkloadInfo& info) const override;
65 
66     std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor,
67                                                         const WorkloadInfo& info) const override;
68 
69     std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
70                                                     const WorkloadInfo& Info) const override;
71 
72     std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
73                                                 const WorkloadInfo& Info) const override;
74 
75     std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
76                                             const WorkloadInfo& info) const override;
77 
78     std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor,
79                                               const WorkloadInfo& info) const override;
80 
81     std::unique_ptr<IWorkload> CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& descriptor,
82                                                        const WorkloadInfo& info) const override;
83 
84     std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
85                                                        const WorkloadInfo& info) const override;
86 
87     std::unique_ptr<IWorkload> CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& descriptor,
88                                                        const WorkloadInfo& info) const override;
89 
90     std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
91                                                        const WorkloadInfo& info) const override;
92 
93     std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
94                                                    const WorkloadInfo& info) const override;
95 
96     std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor,
97                                            const WorkloadInfo& info) const override;
98 
99     std::unique_ptr<IWorkload> CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
100                                                   const WorkloadInfo& info) const override;
101 
102     std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor& descriptor,
103                                                             const WorkloadInfo& info) const override;
104 
105     std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor,
106                                                 const WorkloadInfo& info) const override;
107 
108     std::unique_ptr<IWorkload> CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor,
109                                                           const WorkloadInfo& info) const override;
110 
111     std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
112                                               const WorkloadInfo& info) const override;
113 
114     std::unique_ptr<IWorkload> CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
115                                                       const WorkloadInfo& Info) const override;
116 
117     ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
118     std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
119                                            const WorkloadInfo& info) const override;
120 
121     std::unique_ptr<IWorkload> CreateFill(const FillQueueDescriptor& descriptor,
122                                           const WorkloadInfo& info) const override;
123 
124     std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
125                                            const WorkloadInfo& info) const override;
126 
127     std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
128                                                     const WorkloadInfo& info) const override;
129 
130     std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
131                                             const WorkloadInfo& info) const override;
132 
133     ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
134     std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
135                                              const WorkloadInfo& info) const override;
136 
137     std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
138                                            const WorkloadInfo& info) const override;
139 
140     std::unique_ptr<IWorkload> CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor,
141                                                            const WorkloadInfo& info) const override;
142 
143     std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
144                                                      const WorkloadInfo& info) const override;
145 
146     std::unique_ptr<IWorkload> CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
147                                                    const WorkloadInfo& info) const override;
148 
149     std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
150                                                 const WorkloadInfo& info) const override;
151 
152     std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
153                                           const WorkloadInfo& info) const override;
154 
155     std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor,
156                                              const WorkloadInfo& info) const override;
157 
158     std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor,
159                                           const WorkloadInfo& Info) const override;
160 
161     std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
162                                              const WorkloadInfo& info) const override;
163 
164     std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor,
165                                                const WorkloadInfo& info) const override;
166 
167     ARMNN_DEPRECATED_MSG("Use CreateConcat instead")
168     std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
169                                             const WorkloadInfo& info) const override;
170 
171     std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor,
172                                              const WorkloadInfo& info) const override;
173 
174     std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
175                                                     const WorkloadInfo& info) const override;
176 
177     std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor,
178                                                    const WorkloadInfo& info) const override;
179 
180     std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
181                                             const WorkloadInfo& info) const override;
182 
183     std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
184                                          const WorkloadInfo& info) const override;
185 
186     std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor,
187                                              const WorkloadInfo& info) const override;
188 
189     std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
190                                                const WorkloadInfo& info) const override;
191 
192     std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
193                                                  const WorkloadInfo& info) const override;
194 
195     std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
196                                            const WorkloadInfo& info) const override;
197 
198     std::unique_ptr<IWorkload> CreateQLstm(const QLstmQueueDescriptor& descriptor,
199                                            const WorkloadInfo& info) const override;
200 
201     std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
202                                               const WorkloadInfo& info) const override;
203 
204     std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
205                                                    const WorkloadInfo& info) const override;
206 
207     std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor,
208                                              const WorkloadInfo& info) const override;
209 
210     std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor,
211                                             const WorkloadInfo& info) const override;
212 
213     ARMNN_DEPRECATED_MSG("Use CreateResize instead")
214     std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
215                                                     const WorkloadInfo& info) const override;
216 
217     ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead")
218     std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
219                                            const WorkloadInfo& info) const override;
220 
221     std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor,
222                                            const WorkloadInfo& info) const override;
223 
224     std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
225                                              const WorkloadInfo& info) const override;
226 
227     std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
228                                                     const WorkloadInfo& info) const override;
229 
230     std::unique_ptr<IWorkload> CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
231                                                   const WorkloadInfo& info) const override;
232 
233     std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor,
234                                               const WorkloadInfo& info) const override;
235 
236     std::unique_ptr<IWorkload> CreateStack(const StackQueueDescriptor& descriptor,
237                                            const WorkloadInfo& info) const override;
238 
239     std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
240                                                   const WorkloadInfo& info) const override;
241 
242     std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
243                                                  const WorkloadInfo& info) const override;
244 
245     std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor,
246                                                const WorkloadInfo& info) const override;
247 
248     std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
249                                                             const WorkloadInfo& info) const override;
250 
251 private:
252     mutable std::shared_ptr<NeonMemoryManager> m_MemoryManager;
253     const IBackendInternal::IBackendSpecificModelContextPtr m_ModelContextPtr;
254 };
255 
256 } // namespace armnn
257