• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/Deprecated.hpp>
8 #include <armnn/DescriptorsFwd.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/Optional.hpp>
11 #include <armnn/QuantizedLstmParams.hpp>
12 
13 #include <cctype>
14 #include <functional>
15 #include <memory>
16 #include <vector>
17 
18 namespace armnn
19 {
20 
21 class TensorInfo;
22 
23 class ILayerSupport
24 {
25 protected:
ILayerSupport()26     ILayerSupport() {}
~ILayerSupport()27     virtual ~ILayerSupport() {}
28 
29 public:
30     ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead")
31     virtual bool IsAbsSupported(const TensorInfo& input,
32                                 const TensorInfo& output,
33                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
34 
35     virtual bool IsActivationSupported(const TensorInfo& input,
36                                        const TensorInfo& output,
37                                        const ActivationDescriptor& descriptor,
38                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
39 
40     virtual bool IsAdditionSupported(const TensorInfo& input0,
41                                      const TensorInfo& input1,
42                                      const TensorInfo& output,
43                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
44 
45     virtual bool IsArgMinMaxSupported(const TensorInfo& input,
46                                       const TensorInfo& output,
47                                       const ArgMinMaxDescriptor& descriptor,
48                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
49 
50     virtual bool IsBatchNormalizationSupported(const TensorInfo& input,
51                                                const TensorInfo& output,
52                                                const TensorInfo& mean,
53                                                const TensorInfo& var,
54                                                const TensorInfo& beta,
55                                                const TensorInfo& gamma,
56                                                const BatchNormalizationDescriptor& descriptor,
57                                                Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
58 
59     virtual bool IsBatchToSpaceNdSupported(const TensorInfo& input,
60                                            const TensorInfo& output,
61                                            const BatchToSpaceNdDescriptor& descriptor,
62                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
63 
64     virtual bool IsComparisonSupported(const TensorInfo& input0,
65                                        const TensorInfo& input1,
66                                        const TensorInfo& output,
67                                        const ComparisonDescriptor& descriptor,
68                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
69 
70     virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
71                                    const TensorInfo& output,
72                                    const OriginsDescriptor& descriptor,
73                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
74 
75     virtual bool IsConstantSupported(const TensorInfo& output,
76                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
77 
78     virtual bool IsConvertBf16ToFp32Supported(const TensorInfo& input,
79                                               const TensorInfo& output,
80                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
81 
82     virtual bool IsConvertFp32ToBf16Supported(const TensorInfo& input,
83                                               const TensorInfo& output,
84                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
85 
86     virtual bool IsConvertFp16ToFp32Supported(const TensorInfo& input,
87                                               const TensorInfo& output,
88                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
89 
90     virtual bool IsConvertFp32ToFp16Supported(const TensorInfo& input,
91                                               const TensorInfo& output,
92                                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
93 
94     virtual bool IsConvolution2dSupported(const TensorInfo& input,
95                                           const TensorInfo& output,
96                                           const Convolution2dDescriptor& descriptor,
97                                           const TensorInfo& weights,
98                                           const Optional<TensorInfo>& biases,
99                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
100 
101     virtual bool IsDebugSupported(const TensorInfo& input,
102                                   const TensorInfo& output,
103                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
104 
105     virtual bool IsDepthToSpaceSupported(const TensorInfo& input,
106                                          const TensorInfo& output,
107                                          const DepthToSpaceDescriptor& descriptor,
108                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
109 
110     virtual bool IsDepthwiseConvolutionSupported(
111                      const TensorInfo& input,
112                      const TensorInfo& output,
113                      const DepthwiseConvolution2dDescriptor& descriptor,
114                      const TensorInfo& weights,
115                      const Optional<TensorInfo>& biases,
116                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
117 
118     virtual bool IsDequantizeSupported(const TensorInfo& input,
119                                        const TensorInfo& output,
120                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
121 
122     virtual bool IsDetectionPostProcessSupported(const TensorInfo& boxEncodings,
123                                                  const TensorInfo& scores,
124                                                  const TensorInfo& anchors,
125                                                  const TensorInfo& detectionBoxes,
126                                                  const TensorInfo& detectionClasses,
127                                                  const TensorInfo& detectionScores,
128                                                  const TensorInfo& numDetections,
129                                                  const DetectionPostProcessDescriptor& descriptor,
130                                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const =0;
131 
132     virtual bool IsDilatedDepthwiseConvolutionSupported(
133                     const TensorInfo& input,
134                     const TensorInfo& output,
135                     const DepthwiseConvolution2dDescriptor& descriptor,
136                     const TensorInfo& weights,
137                     const Optional<TensorInfo>& biases,
138                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
139 
140     virtual bool IsDivisionSupported(const TensorInfo& input0,
141                                      const TensorInfo& input1,
142                                      const TensorInfo& output,
143                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
144 
145     virtual bool IsElementwiseUnarySupported(const TensorInfo& input,
146                                              const TensorInfo& output,
147                                              const ElementwiseUnaryDescriptor& descriptor,
148                                              Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
149 
150     ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
151     virtual bool IsEqualSupported(const TensorInfo& input0,
152                                   const TensorInfo& input1,
153                                   const TensorInfo& output,
154                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
155 
156     virtual bool IsFakeQuantizationSupported(const TensorInfo& input,
157                                              const FakeQuantizationDescriptor& descriptor,
158                                              Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
159 
160     virtual bool IsFillSupported(const TensorInfo& input,
161                                  const TensorInfo& output,
162                                  const FillDescriptor& descriptor,
163                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
164 
165     virtual bool IsFloorSupported(const TensorInfo& input,
166                                   const TensorInfo& output,
167                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
168 
169     virtual bool IsFullyConnectedSupported(const TensorInfo& input,
170                                            const TensorInfo& output,
171                                            const TensorInfo& weights,
172                                            const TensorInfo& biases,
173                                            const FullyConnectedDescriptor& descriptor,
174                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
175 
176     ARMNN_DEPRECATED_MSG("Use IsGatherSupported with descriptor instead")
177     virtual bool IsGatherSupported(const TensorInfo& input0,
178                                    const TensorInfo& input1,
179                                    const TensorInfo& output,
180                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
181 
182     virtual bool IsGatherSupported(const TensorInfo& input0,
183                                    const TensorInfo& input1,
184                                    const TensorInfo& output,
185                                    const GatherDescriptor& descriptor,
186                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
187 
188     ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
189     virtual bool IsGreaterSupported(const TensorInfo& input0,
190                                     const TensorInfo& input1,
191                                     const TensorInfo& ouput,
192                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
193 
194     virtual bool IsInputSupported(const TensorInfo& input,
195                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
196 
197     virtual bool IsInstanceNormalizationSupported(
198         const TensorInfo& input,
199         const TensorInfo& output,
200         const InstanceNormalizationDescriptor& descriptor,
201         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
202 
203     virtual bool IsL2NormalizationSupported(const TensorInfo& input,
204                                             const TensorInfo& output,
205                                             const L2NormalizationDescriptor& descriptor,
206                                             Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
207 
208     virtual bool IsLogicalBinarySupported(const TensorInfo& input0,
209                                           const TensorInfo& input1,
210                                           const TensorInfo& output,
211                                           const LogicalBinaryDescriptor& descriptor,
212                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
213 
214     virtual bool IsLogicalUnarySupported(const TensorInfo& input,
215                                          const TensorInfo& output,
216                                          const ElementwiseUnaryDescriptor& descriptor,
217                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
218 
219     virtual bool IsLogSoftmaxSupported(const TensorInfo& input,
220                                        const TensorInfo& output,
221                                        const LogSoftmaxDescriptor& descriptor,
222                                        Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
223 
224     virtual bool IsLstmSupported(const TensorInfo& input,
225                                  const TensorInfo& outputStateIn,
226                                  const TensorInfo& cellStateIn,
227                                  const TensorInfo& scratchBuffer,
228                                  const TensorInfo& outputStateOut,
229                                  const TensorInfo& cellStateOut,
230                                  const TensorInfo& output,
231                                  const LstmDescriptor& descriptor,
232                                  const LstmInputParamsInfo& paramsInfo,
233                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
234 
235     virtual bool IsMaximumSupported(const TensorInfo& input0,
236                                     const TensorInfo& input1,
237                                     const TensorInfo& output,
238                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
239 
240     virtual bool IsMeanSupported(const TensorInfo& input,
241                                  const TensorInfo& output,
242                                  const MeanDescriptor& descriptor,
243                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
244 
245     virtual bool IsMemCopySupported(const TensorInfo& input,
246                                     const TensorInfo& output,
247                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
248 
249     virtual bool IsMemImportSupported(const TensorInfo& input,
250                                       const TensorInfo& output,
251                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
252 
253     virtual bool IsMergeSupported(const TensorInfo& input0,
254                                   const TensorInfo& input1,
255                                   const TensorInfo& output,
256                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
257 
258     ARMNN_DEPRECATED_MSG("Use IsConcatSupported instead")
259     virtual bool IsMergerSupported(const std::vector<const TensorInfo*> inputs,
260                                    const TensorInfo& output,
261                                    const OriginsDescriptor& descriptor,
262                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
263 
264     virtual bool IsMinimumSupported(const TensorInfo& input0,
265                                     const TensorInfo& input1,
266                                     const TensorInfo& ouput,
267                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
268 
269     virtual bool IsMultiplicationSupported(const TensorInfo& input0,
270                                            const TensorInfo& input1,
271                                            const TensorInfo& output,
272                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
273 
274     virtual bool IsNormalizationSupported(const TensorInfo& input,
275                                           const TensorInfo& output,
276                                           const NormalizationDescriptor& descriptor,
277                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
278 
279     virtual bool IsOutputSupported(const TensorInfo& output,
280                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
281 
282     virtual bool IsPadSupported(const TensorInfo& input,
283                                 const TensorInfo& output,
284                                 const PadDescriptor& descriptor,
285                                 Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
286 
287     virtual bool IsPermuteSupported(const TensorInfo& input,
288                                     const TensorInfo& output,
289                                     const PermuteDescriptor& descriptor,
290                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
291 
292     virtual bool IsPooling2dSupported(const TensorInfo& input,
293                                       const TensorInfo& output,
294                                       const Pooling2dDescriptor& descriptor,
295                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
296 
297     virtual bool IsPreCompiledSupported(const TensorInfo& input,
298                                         const PreCompiledDescriptor& descriptor,
299                                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
300 
301     virtual bool IsPreluSupported(const TensorInfo& input,
302                                   const TensorInfo& alpha,
303                                   const TensorInfo& output,
304                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
305 
306     virtual bool IsQuantizeSupported(const TensorInfo& input,
307                                      const TensorInfo& output,
308                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
309 
310     virtual bool IsQLstmSupported(const TensorInfo& input,
311                                   const TensorInfo& previousOutputIn,
312                                   const TensorInfo& previousCellStateIn,
313                                   const TensorInfo& outputStateOut,
314                                   const TensorInfo& cellStateOut,
315                                   const TensorInfo& output,
316                                   const QLstmDescriptor& descriptor,
317                                   const LstmInputParamsInfo& paramsInfo,
318                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
319 
320     virtual bool IsQuantizedLstmSupported(const TensorInfo& input,
321                                           const TensorInfo& previousCellStateIn,
322                                           const TensorInfo& previousOutputIn,
323                                           const TensorInfo& cellStateOut,
324                                           const TensorInfo& output,
325                                           const QuantizedLstmInputParamsInfo& paramsInfo,
326                                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
327 
328     virtual bool IsRankSupported(const TensorInfo& input,
329                                  const TensorInfo& output,
330                                  Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
331 
332     virtual bool IsReshapeSupported(const TensorInfo& input,
333                                     const TensorInfo& output,
334                                     const ReshapeDescriptor& descriptor,
335                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
336 
337     ARMNN_DEPRECATED_MSG("Use IsResizeSupported instead")
338     virtual bool IsResizeBilinearSupported(const TensorInfo& input,
339                                            const TensorInfo& output,
340                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
341 
342     virtual bool IsResizeSupported(const TensorInfo& input,
343                                    const TensorInfo& output,
344                                    const ResizeDescriptor& descriptor,
345                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
346 
347     ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead")
348     virtual bool IsRsqrtSupported(const TensorInfo& input,
349                                   const TensorInfo& output,
350                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
351 
352     virtual bool IsSliceSupported(const TensorInfo& input,
353                                   const TensorInfo& output,
354                                   const SliceDescriptor& descriptor,
355                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
356 
357     virtual bool IsSoftmaxSupported(const TensorInfo& input,
358                                     const TensorInfo& output,
359                                     const SoftmaxDescriptor& descriptor,
360                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
361 
362     virtual bool IsSpaceToBatchNdSupported(const TensorInfo& input,
363                                            const TensorInfo& output,
364                                            const SpaceToBatchNdDescriptor& descriptor,
365                                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
366 
367     virtual bool IsSpaceToDepthSupported(const TensorInfo& input,
368                                          const TensorInfo& output,
369                                          const SpaceToDepthDescriptor& descriptor,
370                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
371 
372     ARMNN_DEPRECATED_MSG("Use IsSplitterSupported with outputs instead")
373     virtual bool IsSplitterSupported(const TensorInfo& input,
374                                      const ViewsDescriptor& descriptor,
375                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
376 
377     virtual bool IsSplitterSupported(const TensorInfo& input,
378                                      const std::vector<std::reference_wrapper<TensorInfo>>& outputs,
379                                      const ViewsDescriptor& descriptor,
380                                      Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
381 
382     virtual bool IsStackSupported(const std::vector<const TensorInfo*>& inputs,
383                                   const TensorInfo& output,
384                                   const StackDescriptor& descriptor,
385                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
386 
387     virtual bool IsStandInSupported(const std::vector<const TensorInfo*>& inputs,
388                                     const std::vector<const TensorInfo*>& outputs,
389                                     const StandInDescriptor& descriptor,
390                                     Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
391 
392 
393     virtual bool IsStridedSliceSupported(const TensorInfo& input,
394                                          const TensorInfo& output,
395                                          const StridedSliceDescriptor& descriptor,
396                                          Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
397 
398     virtual bool IsSubtractionSupported(const TensorInfo& input0,
399                                         const TensorInfo& input1,
400                                         const TensorInfo& output,
401                                         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
402 
403     virtual bool IsSwitchSupported(const TensorInfo& input0,
404                                    const TensorInfo& input1,
405                                    const TensorInfo& output0,
406                                    const TensorInfo& output1,
407                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
408 
409     virtual bool IsTransposeConvolution2dSupported(
410         const TensorInfo& input,
411         const TensorInfo& output,
412         const TransposeConvolution2dDescriptor& descriptor,
413         const TensorInfo& weights,
414         const Optional<TensorInfo>& biases,
415         Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
416 
417     virtual bool IsTransposeSupported(const TensorInfo& input,
418                                       const TensorInfo& output,
419                                       const TransposeDescriptor& descriptor,
420                                       Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
421 
422 }; // class ILayerSupport
423 
424 using ILayerSupportSharedPtr = std::shared_ptr<ILayerSupport>;
425 
426 } // namespace armnn
427