1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include "DelegateUtils.hpp"
9
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14
15 namespace armnnDelegate
16 {
17
VisitPoolingOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLitePoolingOperatorCode)18 TfLiteStatus VisitPoolingOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 TfLiteNode* tfLiteNode,
21 int nodeIndex,
22 int32_t tfLitePoolingOperatorCode)
23 {
24 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26
27 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
28 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
29 if (IsDynamicTensor(tfLiteInputTensor))
30 {
31 TF_LITE_MAYBE_KERNEL_LOG(
32 tfLiteContext,
33 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
34 tfLitePoolingOperatorCode, nodeIndex);
35 return kTfLiteError;
36 }
37
38 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
39 if (IsDynamicTensor(tfLiteOutputTensor))
40 {
41 TF_LITE_MAYBE_KERNEL_LOG(
42 tfLiteContext,
43 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
44 tfLitePoolingOperatorCode, nodeIndex);
45 return kTfLiteError;
46 }
47
48 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
49 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
50
51 armnn::PoolingAlgorithm poolingAlgorithm;
52 switch(tfLitePoolingOperatorCode)
53 {
54 case kTfLiteBuiltinAveragePool2d:
55 poolingAlgorithm = armnn::PoolingAlgorithm::Average;
56 break;
57 case kTfLiteBuiltinL2Pool2d:
58 poolingAlgorithm = armnn::PoolingAlgorithm::L2;
59 break;
60 case kTfLiteBuiltinMaxPool2d:
61 poolingAlgorithm = armnn::PoolingAlgorithm::Max;
62 break;
63 default:
64 return kTfLiteError;
65 }
66
67 armnn::Pooling2dDescriptor descriptor;
68 descriptor.m_PoolType = poolingAlgorithm;
69
70 auto* params = reinterpret_cast<TfLitePoolParams*>(tfLiteNode->builtin_data);
71 descriptor.m_PoolWidth = params->filter_width;
72 descriptor.m_PoolHeight = params->filter_height;
73 descriptor.m_StrideX = params->stride_width;
74 descriptor.m_StrideY = params->stride_height;
75 descriptor.m_DataLayout = armnn::DataLayout::NHWC;
76
77 unsigned int inputHeight = inputTensorInfo.GetShape()[1];
78 unsigned int inputWidth = inputTensorInfo.GetShape()[2];
79
80 CalcPadding(inputHeight, descriptor.m_PoolHeight, descriptor.m_StrideY, 1u,
81 descriptor.m_PadTop, descriptor.m_PadBottom, params->padding);
82 CalcPadding(inputWidth, descriptor.m_PoolWidth, descriptor.m_StrideX, 1u,
83 descriptor.m_PadLeft, descriptor.m_PadRight, params->padding);
84
85 bool isSupported = false;
86 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
87 {
88 FORWARD_LAYER_SUPPORT_FUNC(__func__,
89 tfLiteContext,
90 IsPooling2dSupported,
91 delegateData.m_Backends,
92 isSupported,
93 inputTensorInfo,
94 outputTensorInfo,
95 descriptor);
96 };
97
98 if (!delegateData.m_Network)
99 {
100 validateFunc(outputTensorInfo, isSupported);
101 return isSupported ? kTfLiteOk : kTfLiteError;
102 }
103
104 armnn::IConnectableLayer* poolingLayer = delegateData.m_Network->AddPooling2dLayer(descriptor);
105 ARMNN_ASSERT(poolingLayer != nullptr);
106
107 armnn::IOutputSlot& outputSlot = poolingLayer->GetOutputSlot(0);
108 outputSlot.SetTensorInfo(outputTensorInfo);
109 Connect(poolingLayer, tfLiteNode, delegateData);
110
111 // Check activation
112 TfLiteFusedActivation activationType = params->activation;
113 return FusedActivation(tfLiteContext, tfLiteNode, activationType, poolingLayer, 0, delegateData);
114 }
115
116 } // namespace armnnDelegate
117