• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "RefMemoryManager.hpp"
8 
9 #include <armnn/Optional.hpp>
10 #include <backendsCommon/WorkloadFactory.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12 
13 
14 namespace armnn
15 {
16 
17 template <typename QueueDescriptorType>
IsOperationQueueDescriptor(const QueueDescriptorType &)18 constexpr bool IsOperationQueueDescriptor(const QueueDescriptorType&) { return true; }
19 
20 template <>
IsOperationQueueDescriptor(const MemCopyQueueDescriptor &)21 constexpr bool IsOperationQueueDescriptor(const MemCopyQueueDescriptor&) { return false; }
22 
23 template <>
IsOperationQueueDescriptor(const ConstantQueueDescriptor &)24 constexpr bool IsOperationQueueDescriptor(const ConstantQueueDescriptor&) { return false; }
25 
26 template <>
IsOperationQueueDescriptor(const PermuteQueueDescriptor &)27 constexpr bool IsOperationQueueDescriptor(const PermuteQueueDescriptor&) { return false; }
28 
29 // Reference workload factory.
30 class RefWorkloadFactory : public IWorkloadFactory
31 {
32 public:
33     explicit RefWorkloadFactory(const std::shared_ptr<RefMemoryManager>& memoryManager);
34     RefWorkloadFactory();
35 
~RefWorkloadFactory()36     ~RefWorkloadFactory() {}
37 
38     const BackendId& GetBackendId() const override;
39 
40     static bool IsLayerSupported(const Layer& layer,
41                                  Optional<DataType> dataType,
42                                  std::string& outReasonIfUnsupported);
43 
44     static bool IsLayerSupported(const IConnectableLayer& layer,
45                                  Optional<DataType> dataType,
46                                  std::string& outReasonIfUnsupported,
47                                  const ModelOptions& modelOptions);
48 
SupportsSubTensors() const49     bool SupportsSubTensors() const override { return false; }
50 
51     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
CreateSubTensorHandle(ITensorHandle & parent,TensorShape const & subTensorShape,unsigned int const * subTensorOrigin) const52     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
53                                                          TensorShape const& subTensorShape,
54                                                          unsigned int const* subTensorOrigin) const override
55     {
56         IgnoreUnused(parent, subTensorShape, subTensorOrigin);
57         return nullptr;
58     }
59 
60     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
61     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
62                                                       const bool IsMemoryManaged = true) const override;
63 
64     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
65     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
66                                                       DataLayout dataLayout,
67                                                       const bool IsMemoryManaged = true) const override;
68 
69     ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead")
70     std::unique_ptr<IWorkload> CreateAbs(const AbsQueueDescriptor& descriptor,
71                                          const WorkloadInfo& info) const override;
72 
73     std::unique_ptr<IWorkload> CreateActivation(const ActivationQueueDescriptor& descriptor,
74                                                 const WorkloadInfo& info) const override;
75 
76     std::unique_ptr<IWorkload> CreateAddition(const AdditionQueueDescriptor& descriptor,
77                                               const WorkloadInfo& info) const override;
78 
79     std::unique_ptr<IWorkload> CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
80                                                const WorkloadInfo& info) const override;
81 
82     std::unique_ptr<IWorkload> CreateBatchNormalization(const BatchNormalizationQueueDescriptor& descriptor,
83                                                         const WorkloadInfo& info) const override;
84 
85     std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
86                                                     const WorkloadInfo& info) const override;
87 
88     std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
89                                                 const WorkloadInfo& info) const override;
90 
91     std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
92                                             const WorkloadInfo& info) const override;
93 
94     std::unique_ptr<IWorkload> CreateConstant(const ConstantQueueDescriptor& descriptor,
95                                               const WorkloadInfo& info) const override;
96 
97     std::unique_ptr<IWorkload> CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& descriptor,
98                                                        const WorkloadInfo& info) const override;
99 
100     std::unique_ptr<IWorkload> CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
101                                                        const WorkloadInfo& info) const override;
102 
103     std::unique_ptr<IWorkload> CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& descriptor,
104                                                        const WorkloadInfo& info) const override;
105 
106     std::unique_ptr<IWorkload> CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
107                                                        const WorkloadInfo& info) const override;
108 
109     std::unique_ptr<IWorkload> CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
110                                                    const WorkloadInfo& info) const override;
111 
112     std::unique_ptr<IWorkload> CreateDebug(const DebugQueueDescriptor& descriptor,
113                                            const WorkloadInfo& info) const override;
114 
115     std::unique_ptr<IWorkload> CreateDepthToSpace(const DepthToSpaceQueueDescriptor& descriptor,
116                                                   const WorkloadInfo& info) const override;
117 
118     std::unique_ptr<IWorkload> CreateDepthwiseConvolution2d(const DepthwiseConvolution2dQueueDescriptor& descriptor,
119                                                             const WorkloadInfo& info) const override;
120 
121     std::unique_ptr<IWorkload> CreateDequantize(const DequantizeQueueDescriptor& descriptor,
122                                                 const WorkloadInfo& info) const override;
123 
124     std::unique_ptr<IWorkload> CreateDetectionPostProcess(const DetectionPostProcessQueueDescriptor& descriptor,
125                                                           const WorkloadInfo& info) const override;
126 
127     std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
128                                               const WorkloadInfo& info) const override;
129 
130     std::unique_ptr<IWorkload> CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& descriptor,
131                                                       const WorkloadInfo& info) const override;
132 
133     ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
134     std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
135                                            const WorkloadInfo& info) const override;
136 
137     std::unique_ptr<IWorkload> CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
138                                                       const WorkloadInfo& info) const override;
139 
140     std::unique_ptr<IWorkload> CreateFill(const FillQueueDescriptor& descriptor,
141                                           const WorkloadInfo& info) const override;
142 
143     std::unique_ptr<IWorkload> CreateFloor(const FloorQueueDescriptor& descriptor,
144                                            const WorkloadInfo& info) const override;
145 
146     std::unique_ptr<IWorkload> CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
147                                                     const WorkloadInfo& info) const override;
148 
149     std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
150                                             const WorkloadInfo& info) const override;
151 
152     ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
153     std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
154                                              const WorkloadInfo& info) const override;
155 
156     std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
157                                            const WorkloadInfo& info) const override;
158 
159     std::unique_ptr<IWorkload> CreateInstanceNormalization(const InstanceNormalizationQueueDescriptor& descriptor,
160                                                            const WorkloadInfo& info) const override;
161 
162     std::unique_ptr<IWorkload> CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
163                                                      const WorkloadInfo& info) const override;
164 
165     std::unique_ptr<IWorkload> CreateLogicalBinary(const LogicalBinaryQueueDescriptor& descriptor,
166                                                    const WorkloadInfo& info) const override;
167 
168     std::unique_ptr<IWorkload> CreateLogSoftmax(const LogSoftmaxQueueDescriptor& descriptor,
169                                                 const WorkloadInfo& info) const override;
170 
171     std::unique_ptr<IWorkload> CreateLstm(const LstmQueueDescriptor& descriptor,
172                                           const WorkloadInfo& info) const override;
173 
174     std::unique_ptr<IWorkload> CreateMaximum(const MaximumQueueDescriptor& descriptor,
175                                              const WorkloadInfo& info) const override;
176 
177     std::unique_ptr<IWorkload> CreateMean(const MeanQueueDescriptor& descriptor,
178                                           const WorkloadInfo& Info) const override;
179 
180     std::unique_ptr<IWorkload> CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
181                                              const WorkloadInfo& info) const override;
182 
183     std::unique_ptr<IWorkload> CreateMemImport(const MemImportQueueDescriptor& descriptor,
184                                                 const WorkloadInfo& info) const override;
185 
186     ARMNN_DEPRECATED_MSG("Use CreateConcat instead")
187     std::unique_ptr<IWorkload> CreateMerger(const MergerQueueDescriptor& descriptor,
188                                             const WorkloadInfo& info) const override;
189 
190     std::unique_ptr<IWorkload> CreateMinimum(const MinimumQueueDescriptor& descriptor,
191                                              const WorkloadInfo& info) const override;
192 
193     std::unique_ptr<IWorkload> CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
194                                                     const WorkloadInfo& info) const override;
195 
196     std::unique_ptr<IWorkload> CreateNormalization(const NormalizationQueueDescriptor& descriptor,
197                                                    const WorkloadInfo& info) const override;
198 
199     std::unique_ptr<IWorkload> CreateOutput(const OutputQueueDescriptor& descriptor,
200                                             const WorkloadInfo& info) const override;
201 
202     std::unique_ptr<IWorkload> CreatePad(const PadQueueDescriptor& descriptor,
203                                          const WorkloadInfo& info) const override;
204 
205     std::unique_ptr<IWorkload> CreatePermute(const PermuteQueueDescriptor& descriptor,
206                                              const WorkloadInfo& info) const override;
207 
208     std::unique_ptr<IWorkload> CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
209                                                const WorkloadInfo& info) const override;
210 
211     std::unique_ptr<IWorkload> CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
212                                                  const WorkloadInfo& info) const override;
213 
214     std::unique_ptr<IWorkload> CreatePrelu(const PreluQueueDescriptor& descriptor,
215                                            const WorkloadInfo& info) const override;
216 
217     std::unique_ptr<IWorkload> CreateQLstm(const QLstmQueueDescriptor& descriptor,
218                                            const WorkloadInfo& info) const override;
219 
220     std::unique_ptr<IWorkload> CreateQuantize(const QuantizeQueueDescriptor& descriptor,
221                                               const WorkloadInfo& info) const override;
222 
223     std::unique_ptr<IWorkload> CreateRank(const RankQueueDescriptor& descriptor,
224                                           const WorkloadInfo& info) const override;
225 
226     std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor,
227                                              const WorkloadInfo& info) const override;
228 
229     std::unique_ptr<IWorkload> CreateResize(const ResizeQueueDescriptor& descriptor,
230                                             const WorkloadInfo& info) const override;
231 
232     ARMNN_DEPRECATED_MSG("Use CreateResize instead")
233     std::unique_ptr<IWorkload> CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
234                                                     const WorkloadInfo& info) const override;
235 
236     ARMNN_DEPRECATED_MSG("Use CreateElementwiseUnary instead")
237     std::unique_ptr<IWorkload> CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
238                                            const WorkloadInfo& info) const override;
239 
240     std::unique_ptr<IWorkload> CreateSlice(const SliceQueueDescriptor& descriptor,
241                                            const WorkloadInfo& info) const override;
242 
243     std::unique_ptr<IWorkload> CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
244                                              const WorkloadInfo& info) const override;
245 
246     std::unique_ptr<IWorkload> CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
247                                                     const WorkloadInfo& info) const override;
248 
249     std::unique_ptr<IWorkload> CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
250                                                   const WorkloadInfo& info) const override;
251 
252     std::unique_ptr<IWorkload> CreateSplitter(const SplitterQueueDescriptor& descriptor,
253                                               const WorkloadInfo& info) const override;
254 
255     std::unique_ptr<IWorkload> CreateStack(const StackQueueDescriptor& descriptor,
256                                            const WorkloadInfo& info) const override;
257 
258     std::unique_ptr<IWorkload> CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
259                                                   const WorkloadInfo& info) const override;
260 
261     std::unique_ptr<IWorkload> CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
262                                                  const WorkloadInfo& info) const override;
263 
264     std::unique_ptr<IWorkload> CreateTranspose(const TransposeQueueDescriptor& descriptor,
265                                                const WorkloadInfo& info) const override;
266 
267     std::unique_ptr<IWorkload> CreateTransposeConvolution2d(const TransposeConvolution2dQueueDescriptor& descriptor,
268                                                             const WorkloadInfo& info) const override;
269 
270 private:
271     template <typename F32Workload, typename U8Workload, typename QueueDescriptorType>
272     std::unique_ptr<IWorkload> MakeWorkload(const QueueDescriptorType& descriptor, const WorkloadInfo& info) const;
273 
274     mutable std::shared_ptr<RefMemoryManager> m_MemoryManager;
275 };
276 
277 } // namespace armnn
278