• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "DelegateUtils.hpp"
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
VisitPoolingOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLitePoolingOperatorCode)18 TfLiteStatus VisitPoolingOperator(DelegateData& delegateData,
19                                   TfLiteContext* tfLiteContext,
20                                   TfLiteNode* tfLiteNode,
21                                   int nodeIndex,
22                                   int32_t tfLitePoolingOperatorCode)
23 {
24     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26 
27     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29     if (IsDynamicTensor(tfLiteInputTensor))
30     {
31         TF_LITE_MAYBE_KERNEL_LOG(
32             tfLiteContext,
33             "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
34             tfLitePoolingOperatorCode, nodeIndex);
35         return kTfLiteError;
36     }
37 
38     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
39     if (IsDynamicTensor(tfLiteOutputTensor))
40     {
41         TF_LITE_MAYBE_KERNEL_LOG(
42             tfLiteContext,
43             "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
44             tfLitePoolingOperatorCode, nodeIndex);
45         return kTfLiteError;
46     }
47 
48     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
49     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
50 
51     armnn::PoolingAlgorithm poolingAlgorithm;
52     switch(tfLitePoolingOperatorCode)
53     {
54         case kTfLiteBuiltinAveragePool2d:
55             poolingAlgorithm = armnn::PoolingAlgorithm::Average;
56             break;
57         case kTfLiteBuiltinL2Pool2d:
58             poolingAlgorithm = armnn::PoolingAlgorithm::L2;
59             break;
60         case kTfLiteBuiltinMaxPool2d:
61             poolingAlgorithm = armnn::PoolingAlgorithm::Max;
62             break;
63         default:
64             return kTfLiteError;
65     }
66 
67     armnn::Pooling2dDescriptor descriptor;
68     descriptor.m_PoolType = poolingAlgorithm;
69 
70     auto* params = reinterpret_cast<TfLitePoolParams*>(tfLiteNode->builtin_data);
71     descriptor.m_PoolWidth = params->filter_width;
72     descriptor.m_PoolHeight = params->filter_height;
73     descriptor.m_StrideX = params->stride_width;
74     descriptor.m_StrideY = params->stride_height;
75     descriptor.m_DataLayout = armnn::DataLayout::NHWC;
76 
77     unsigned int inputHeight = inputTensorInfo.GetShape()[1];
78     unsigned int inputWidth  = inputTensorInfo.GetShape()[2];
79 
80     CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
81                 descriptor.m_PadTop, descriptor.m_PadBottom, params->padding);
82     CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
83                 descriptor.m_PadLeft, descriptor.m_PadRight, params->padding);
84 
85     bool isSupported = false;
86     auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
87     {
88         FORWARD_LAYER_SUPPORT_FUNC(__func__,
89                                    tfLiteContext,
90                                    IsPooling2dSupported,
91                                    delegateData.m_Backends,
92                                    isSupported,
93                                    inputTensorInfo,
94                                    outputTensorInfo,
95                                    descriptor);
96     };
97 
98     if (!delegateData.m_Network)
99     {
100         validateFunc(outputTensorInfo, isSupported);
101         return isSupported ? kTfLiteOk : kTfLiteError;
102     }
103 
104     armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling2dLayer(descriptor);
105     ARMNN_ASSERT(poolingLayer != nullptr);
106 
107     armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
108     outputSlot.SetTensorInfo(outputTensorInfo);
109     Connect(poolingLayer, tfLiteNode, delegateData);
110 
111     // Check activation
112     TfLiteFusedActivation activationType = params->activation;
113     return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
114 }
115 
116 } // namespace armnnDelegate
117