1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <ClassicDelegateUtils.hpp>
9
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 #include <numeric>
15
16 namespace armnnDelegate
17 {
18
VisitShapeOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)19 TfLiteStatus VisitShapeOperator(DelegateData& delegateData,
20 TfLiteContext* tfLiteContext,
21 TfLiteNode* tfLiteNode,
22 int nodeIndex,
23 int32_t operatorCode)
24 {
25 TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26 TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27
28 const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29 const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
30 if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
31 {
32 return kTfLiteError;
33 }
34
35 const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
36 if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
37 {
38 return kTfLiteError;
39 }
40
41 const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
42 const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
43
44 auto* shapeParameters = reinterpret_cast<TfLiteShapeParams*>(tfLiteNode->builtin_data);
45 if ( shapeParameters->out_type != kTfLiteInt32 && shapeParameters->out_type != kTfLiteInt64 )
46 {
47 TF_LITE_MAYBE_KERNEL_LOG(
48 tfLiteContext,
49 "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
50 operatorCode, nodeIndex);
51 return kTfLiteError;
52 }
53
54 bool isSupported = false;
55 armnn::BackendId setBackend;
56 auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
57 {
58 FORWARD_LAYER_SUPPORT_FUNC("SHAPE",
59 tfLiteContext,
60 IsShapeSupported,
61 delegateData.m_Backends,
62 isSupported,
63 setBackend,
64 inputTensorInfo,
65 outInfo);
66 };
67
68 // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
69 // support for the operator
70 // If supported, VisitShapeOperator will be called again to add the layer to the network as seen further below
71 if (!delegateData.m_Network)
72 {
73 validateFunc(outputTensorInfo, isSupported);
74 return isSupported ? kTfLiteOk : kTfLiteError;
75 }
76
77 // Add a Shape layer
78 armnn::IConnectableLayer* layer = delegateData.m_Network->AddShapeLayer();
79 layer->SetBackendId(setBackend);
80 ARMNN_ASSERT(layer != nullptr);
81
82 armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
83 outputSlot.SetTensorInfo(outputTensorInfo);
84
85 // try to connect the Constant Inputs if there are any
86 if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
87 {
88 return kTfLiteError;
89 }
90
91 // Connect
92 return Connect(layer, tfLiteNode, delegateData);
93 }
94
95 } // namespace armnnDelegate
96