1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <ClassicDelegateUtils.hpp>
9
10 #include <armnn/utility/IgnoreUnused.hpp>
11
12 #include <tensorflow/lite/builtin_ops.h>
13 #include <tensorflow/lite/c/builtin_op_data.h>
14 #include <tensorflow/lite/c/common.h>
15 #include <tensorflow/lite/minimal_logging.h>
16
17 namespace armnnDelegate
18 {
19
VisitComparisonOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t tfLiteComparisonOperatorCode)20 TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
21 TfLiteContext* tfLiteContext,
22 TfLiteNode* tfLiteNode,
23 int nodeIndex,
24 int32_t tfLiteComparisonOperatorCode)
25 {
26 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
27 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
28
29 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
30 const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
31 if (IsDynamicTensor(tfLiteInputTensor0))
32 {
33 TF_LITE_MAYBE_KERNEL_LOG(
34 tfLiteContext,
35 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
36 tfLiteComparisonOperatorCode, nodeIndex);
37
38 return kTfLiteError;
39 }
40
41 const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
42 if (IsDynamicTensor(tfLiteInputTensor1))
43 {
44 TF_LITE_MAYBE_KERNEL_LOG(
45 tfLiteContext,
46 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
47 tfLiteComparisonOperatorCode, nodeIndex);
48 return kTfLiteError;
49 }
50
51 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
52 if (IsDynamicTensor(tfLiteOutputTensor))
53 {
54 TF_LITE_MAYBE_KERNEL_LOG(
55 tfLiteContext,
56 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
57 tfLiteComparisonOperatorCode, nodeIndex);
58 return kTfLiteError;
59 }
60
61 armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
62 armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
63 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
64
65 // Check if we need to expand the dims of any of the input tensor infos.
66 // This is required for a few of the backends.
67 if(inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
68 {
69 ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
70 }
71
72 armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
73 switch(tfLiteComparisonOperatorCode)
74 {
75 case kTfLiteBuiltinEqual:
76 comparisonOperation = armnn::ComparisonOperation::Equal;
77 break;
78 case kTfLiteBuiltinGreater:
79 comparisonOperation = armnn::ComparisonOperation::Greater;
80 break;
81 case kTfLiteBuiltinGreaterEqual:
82 comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
83 break;
84 case kTfLiteBuiltinLess:
85 comparisonOperation = armnn::ComparisonOperation::Less;
86 break;
87 case kTfLiteBuiltinLessEqual:
88 comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
89 break;
90 case kTfLiteBuiltinNotEqual:
91 comparisonOperation = armnn::ComparisonOperation::NotEqual;
92 break;
93 default:
94 return kTfLiteError;
95 }
96
97 armnn::ComparisonDescriptor descriptor(comparisonOperation);
98 bool isSupported = false;
99 armnn::BackendId setBackend;
100 auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
101 {
102 FORWARD_LAYER_SUPPORT_FUNC("COMPARISON",
103 tfLiteContext,
104 IsComparisonSupported,
105 delegateData.m_Backends,
106 isSupported,
107 setBackend,
108 inputTensorInfo0,
109 inputTensorInfo1,
110 outputTensorInfo,
111 descriptor);
112 };
113
114 if (!delegateData.m_Network)
115 {
116 validateFunc(outputTensorInfo, isSupported);
117 return isSupported ? kTfLiteOk : kTfLiteError;
118 }
119
120 armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
121 comparisonLayer->SetBackendId(setBackend);
122 ARMNN_ASSERT(comparisonLayer != nullptr);
123
124 armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
125 outputSlot.SetTensorInfo(outputTensorInfo);
126
127 // try to connect the Constant Inputs if there are any
128 if(ProcessInputs(comparisonLayer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
129 {
130 return kTfLiteError;
131 }
132
133 return Connect(comparisonLayer, tfLiteNode, delegateData);
134 }
135
136 } // namespace armnnDelegate
137