• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #pragma once
7 
8 #include <tensorflow/lite/builtin_ops.h>
9 #include <tensorflow/lite/c/builtin_op_data.h>
10 #include <tensorflow/lite/c/common.h>
11 #include <tensorflow/lite/kernels/internal/tensor_ctypes.h>
12 #include <tensorflow/lite/minimal_logging.h>
13 
14 namespace armnnDelegate
15 {
16 
VisitArgMinMaxOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t argMinMaxOperatorCode)17 TfLiteStatus VisitArgMinMaxOperator(DelegateData& delegateData,
18                                     TfLiteContext* tfLiteContext,
19                                     TfLiteNode* tfLiteNode,
20                                     int nodeIndex,
21                                     int32_t argMinMaxOperatorCode)
22 {
23     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
24     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
25 
26     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
27     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
28     if (!IsValid(tfLiteContext, tfLiteInputTensor, argMinMaxOperatorCode, nodeIndex))
29     {
30         return kTfLiteError;
31     }
32 
33     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
34     if (!IsValid(tfLiteContext, tfLiteOutputTensor, argMinMaxOperatorCode, nodeIndex))
35     {
36         return kTfLiteError;
37     }
38 
39     const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
40     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
41 
42     // Get const axis value from model and set it to descriptor.
43     const TfLiteTensor& tfLiteAxisTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
44     if (!IsValid(tfLiteContext, tfLiteAxisTensor, argMinMaxOperatorCode, nodeIndex))
45     {
46         return kTfLiteError;
47     }
48 
49     armnn::ArgMinMaxDescriptor desc;
50     // Get the axis value from the input tensor
51     switch (tfLiteAxisTensor.type)
52     {
53         case kTfLiteInt32:
54         case kTfLiteInt64:
55             desc.m_Axis = tflite::GetTensorData<int>(&tfLiteAxisTensor)[0];
56             break;
57         default:
58             TF_LITE_MAYBE_KERNEL_LOG(
59                 tfLiteContext,
60                 "TfLiteArmnnDelegate: Axis value data type is not supported in operator #%d node #%d: ",
61                 argMinMaxOperatorCode, nodeIndex);
62             return kTfLiteError;
63     }
64 
65     // If output_type is int32 then set Signed32 else Signed64. Default type is Signed64.
66     if (argMinMaxOperatorCode == kTfLiteBuiltinArgMax)
67     {
68         desc.m_Function = armnn::ArgMinMaxFunction::Max;
69         auto* argMaxParameters = reinterpret_cast<TfLiteArgMaxParams*>(tfLiteNode->builtin_data);
70         if (argMaxParameters->output_type != kTfLiteInt32 && argMaxParameters->output_type != kTfLiteInt64)
71         {
72             TF_LITE_MAYBE_KERNEL_LOG(
73                 tfLiteContext,
74                 "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
75                 argMinMaxOperatorCode, nodeIndex);
76             return kTfLiteError;
77         }
78     }
79     else
80     {
81         desc.m_Function = armnn::ArgMinMaxFunction::Min;
82         auto* argMinParameters = reinterpret_cast<TfLiteArgMinParams*>(tfLiteNode->builtin_data);
83         if (argMinParameters->output_type != kTfLiteInt32 && argMinParameters->output_type != kTfLiteInt64)
84         {
85             TF_LITE_MAYBE_KERNEL_LOG(
86                     tfLiteContext,
87                     "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
88                     argMinMaxOperatorCode, nodeIndex);
89             return kTfLiteError;
90         }
91     }
92 
93     bool isSupported = false;
94     armnn::BackendId setBackend;
95     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
96     {
97         FORWARD_LAYER_SUPPORT_FUNC("ARGMINMAX",
98                                    tfLiteContext,
99                                    IsArgMinMaxSupported,
100                                    delegateData.m_Backends,
101                                    isSupported,
102                                    setBackend,
103                                    inputTensorInfo,
104                                    outInfo,
105                                    desc);
106     };
107 
108     if (!delegateData.m_Network)
109     {
110         validateFunc(outputTensorInfo, isSupported);
111         return isSupported ? kTfLiteOk : kTfLiteError;
112     }
113 
114     // Add an ArgMinMax layer
115     armnn::IConnectableLayer* layer = delegateData.m_Network->AddArgMinMaxLayer(desc);
116     layer->SetBackendId(setBackend);
117     ARMNN_ASSERT(layer != nullptr);
118 
119     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
120     outputSlot.SetTensorInfo(outputTensorInfo);
121 
122     // try to connect the Constant Inputs if there are any
123     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
124     {
125         return kTfLiteError;
126     }
127 
128     // Connect
129     return Connect(layer, tfLiteNode, delegateData);
130 }
131 
132 } // namespace armnnDelegate
133