//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <ClassicDelegateUtils.hpp>

#include <armnn/utility/IgnoreUnused.hpp>

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>

namespace armnnDelegate
{
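
// Checks whether the given TfLite comparison node can be handled and, when a network is
// being built, adds the corresponding Arm NN comparison layer for it.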
TfLiteStatus VisitComparisonOperator(DelegateData& delegateData,
                                     TfLiteContext* tfLiteContext,
                                     TfLiteNode* tfLiteNode,
                                     int nodeIndex,
                                     int32_t tfLiteComparisonOperatorCode)
{
    TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

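    // Dynamic tensors are not supported: reject the node if either input or the output is dynamic.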
    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (IsDynamicTensor(tfLiteInputTensor0))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
            tfLiteComparisonOperatorCode, nodeIndex);

        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteInputTensor1 = tfLiteTensors[tfLiteNode->inputs->data[1]];
    if (IsDynamicTensor(tfLiteInputTensor1))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic input tensors are not supported in operator #%d node #%d: ",
            tfLiteComparisonOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Dynamic output tensors are not supported in operator #%d node #%d: ",
            tfLiteComparisonOperatorCode, nodeIndex);
        return kTfLiteError;
    }

    armnn::TensorInfo inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
    armnn::TensorInfo inputTensorInfo1 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor1);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);

    // Check if we need to expand the dims of any of the input tensor infos.
    // This is required for a few of the backends.
    if (inputTensorInfo0.GetNumDimensions() != inputTensorInfo1.GetNumDimensions())
    {
        ExpandTensorRankToEqual(inputTensorInfo0, inputTensorInfo1);
    }

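    // Map the TfLite builtin operator code onto the equivalent Arm NN comparison operation.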
    armnn::ComparisonOperation comparisonOperation = armnn::ComparisonOperation::Equal;
    switch (tfLiteComparisonOperatorCode)
    {
        case kTfLiteBuiltinEqual:
            comparisonOperation = armnn::ComparisonOperation::Equal;
            break;
        case kTfLiteBuiltinGreater:
            comparisonOperation = armnn::ComparisonOperation::Greater;
            break;
        case kTfLiteBuiltinGreaterEqual:
            comparisonOperation = armnn::ComparisonOperation::GreaterOrEqual;
            break;
        case kTfLiteBuiltinLess:
            comparisonOperation = armnn::ComparisonOperation::Less;
            break;
        case kTfLiteBuiltinLessEqual:
            comparisonOperation = armnn::ComparisonOperation::LessOrEqual;
            break;
        case kTfLiteBuiltinNotEqual:
            comparisonOperation = armnn::ComparisonOperation::NotEqual;
            break;
        default:
            return kTfLiteError;
    }

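    // Build the comparison descriptor and a validation callback that queries the configured
    // backends for support of this comparison.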
    armnn::ComparisonDescriptor descriptor(comparisonOperation);
    bool isSupported = false;
    armnn::BackendId setBackend;
    auto validateFunc = [&](const armnn::TensorInfo& outputTensorInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC("COMPARISON",
                                   tfLiteContext,
                                   IsComparisonSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   setBackend,
                                   inputTensorInfo0,
                                   inputTensorInfo1,
                                   outputTensorInfo,
                                   descriptor);
    };

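    // No network to build yet: this call is only checking support, so run the validation and report back.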
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

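    // Add the comparison layer to the network being built and set its backend id.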
    armnn::IConnectableLayer* comparisonLayer = delegateData.m_Network->AddComparisonLayer(descriptor);
    ARMNN_ASSERT(comparisonLayer != nullptr);
    comparisonLayer->SetBackendId(setBackend);

    armnn::IOutputSlot& outputSlot = comparisonLayer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Try to connect the Constant Inputs if there are any
    if (ProcessInputs(comparisonLayer, delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk)
    {
        return kTfLiteError;
    }

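    // Wire the node's remaining input and output tensors to the new layer's slots.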
    return Connect(comparisonLayer, tfLiteNode, delegateData);
}

} // namespace armnnDelegate