• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 
15 namespace armnnDelegate
16 {
17 
ValidatePreluOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,const armnn::TensorInfo & inputInfo,const armnn::TensorInfo & alphaInfo,const armnn::TensorInfo & outputInfo)18 TfLiteStatus ValidatePreluOperator(DelegateData& delegateData,
19                                    TfLiteContext* tfLiteContext,
20                                    const armnn::TensorInfo& inputInfo,
21                                    const armnn::TensorInfo& alphaInfo,
22                                    const armnn::TensorInfo& outputInfo)
23 {
24     bool isSupported = false;
25     auto validateFunc = [&](const armnn::TensorInfo& outputInfo, bool& isSupported)
26     {
27         FORWARD_LAYER_SUPPORT_FUNC("PRELU",
28                                    tfLiteContext,
29                                    IsPreluSupported,
30                                    delegateData.m_Backends,
31                                    isSupported,
32                                    armnn::BackendId(),
33                                    inputInfo,
34                                    alphaInfo,
35                                    outputInfo);
36     };
37 
38     validateFunc(outputInfo, isSupported);
39     return isSupported ? kTfLiteOk : kTfLiteError;
40 }
41 
VisitPreluOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)42 TfLiteStatus VisitPreluOperator(DelegateData& delegateData,
43                                 TfLiteContext* tfLiteContext,
44                                 TfLiteNode* tfLiteNode,
45                                 int nodeIndex,
46                                 int32_t operatorCode)
47 {
48     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
49     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
50 
51     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
52 
53     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
54     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
55     {
56         return kTfLiteError;
57     }
58 
59     const TfLiteTensor& tfLiteAlphaTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
60     if (!IsValid(tfLiteContext, tfLiteAlphaTensor, operatorCode, nodeIndex))
61     {
62         return kTfLiteError;
63     }
64 
65     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
66     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
67     {
68         return kTfLiteError;
69     }
70 
71     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
72     const armnn::TensorInfo& alphaTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteAlphaTensor);
73     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
74 
75     if (!delegateData.m_Network)
76     {
77         return ValidatePreluOperator(delegateData,
78                                      tfLiteContext,
79                                      inputTensorInfo,
80                                      alphaTensorInfo,
81                                      outputTensorInfo);
82     }
83 
84     armnn::IConnectableLayer* preluLayer = delegateData.m_Network->AddPreluLayer();
85     ARMNN_ASSERT(preluLayer != nullptr);
86 
87     bool isConstantAlpha = tflite::IsConstantTensor(&tfLiteAlphaTensor);
88 
89     // Add constant layer for constant alpha
90     if (isConstantAlpha)
91     {
92         auto constAlphaTensor = armnn::ConstTensor(alphaTensorInfo, tfLiteAlphaTensor.data.data);
93 
94         armnn::IConnectableLayer* constLayer = delegateData.m_Network->AddConstantLayer(constAlphaTensor);
95         ARMNN_ASSERT(constLayer != nullptr);
96 
97         constLayer->GetOutputSlot(0).SetTensorInfo(alphaTensorInfo);
98         constLayer->GetOutputSlot(0).Connect(preluLayer->GetInputSlot(1));
99     }
100 
101     armnn::IOutputSlot& outputSlot = preluLayer->GetOutputSlot(0);
102     outputSlot.SetTensorInfo(outputTensorInfo);
103 
104     // Connect
105     return Connect(preluLayer, tfLiteNode, delegateData);
106 }
107 
108 } // namespace armnnDelegate