1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <ClassicDelegateUtils.hpp>
9
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14
15 namespace armnnDelegate
16 {
17
ValidatePreluOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & alphaInfo,const armnn::TensorInfo & outputInfo)18 TfLiteStatus ValidatePreluOperator(DelegateData& delegateData,
19 TfLiteContext* tfLiteContext,
20 const armnn::TensorInfo& inputInfo,
21 const armnn::TensorInfo& alphaInfo,
22 const armnn::TensorInfo& outputInfo)
23 {
24 bool isSupported = false;
25 auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26 {
27 FORWARD_LAYER_SUPPORT_FUNC("PRELU",
28 tfLiteContext,
29 IsPreluSupported,
30 delegateData.m_Backends,
31 isSupported,
32 armnn::BackendId(),
33 inputInfo,
34 alphaInfo,
35 outputInfo);
36 };
37
38 validateFunc(outputInfo, isSupported);
39 return isSupported ? kTfLiteOk : kTfLiteError;
40 }
41
VisitPreluOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)42 TfLiteStatus VisitPreluOperator(DelegateData& delegateData,
43 TfLiteContext* tfLiteContext,
44 TfLiteNode* tfLiteNode,
45 int nodeIndex,
46 int32_t operatorCode)
47 {
48 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
49 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
50
51 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52
53 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
54 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
55 {
56 return kTfLiteError;
57 }
58
59 const TfLiteTensor& tfLiteAlphaTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
60 if (!IsValid(tfLiteContext, tfLiteAlphaTensor, operatorCode, nodeIndex))
61 {
62 return kTfLiteError;
63 }
64
65 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
66 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
67 {
68 return kTfLiteError;
69 }
70
71 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
72 const armnn::TensorInfo& alphaTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAlphaTensor);
73 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
74
75 if (!delegateData.m_Network)
76 {
77 return ValidatePreluOperator(delegateData,
78 tfLiteContext,
79 inputTensorInfo,
80 alphaTensorInfo,
81 outputTensorInfo);
82 }
83
84 armnn::IConnectableLayer* preluLayer = delegateData.m_Network->AddPreluLayer();
85 ARMNN_ASSERT(preluLayer != nullptr);
86
87 bool isConstantAlpha = tflite::IsConstantTensor(&tfLiteAlphaTensor);
88
89 // Add constant layer for constant alpha
90 if (isConstantAlpha)
91 {
92 auto constAlphaTensor = armnn::ConstTensor(alphaTensorInfo, tfLiteAlphaTensor.data.data);
93
94 armnn::IConnectableLayer* constLayer = delegateData.m_Network->AddConstantLayer(constAlphaTensor);
95 ARMNN_ASSERT(constLayer != nullptr);
96
97 constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
98 constLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
99 }
100
101 armnn::IOutputSlot& outputSlot = preluLayer->GetOutputSlot(0);
102 outputSlot.SetTensorInfo(outputTensorInfo);
103
104 // Connect
105 return Connect(preluLayer, tfLiteNode, delegateData);
106 }
107
108 } // namespace armnnDelegate