• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ClassicDelegateUtils.hpp>
9 
10 #include <tensorflow/lite/builtin_ops.h>
11 #include <tensorflow/lite/c/builtin_op_data.h>
12 #include <tensorflow/lite/c/common.h>
13 #include <tensorflow/lite/minimal_logging.h>
14 #include <numeric>
15 
16 namespace armnnDelegate
17 {
18 
VisitShapeOperator(DelegateData & delegateData,TfLiteContext * tfLiteContext,TfLiteNode * tfLiteNode,int nodeIndex,int32_t operatorCode)19 TfLiteStatus VisitShapeOperator(DelegateData& delegateData,
20                                TfLiteContext* tfLiteContext,
21                                TfLiteNode* tfLiteNode,
22                                int nodeIndex,
23                                int32_t operatorCode)
24 {
25     TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
26     TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
27 
28     const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
29     const TfLiteTensor& tfLiteInputTensor = tfLiteTensors[tfLiteNode->inputs->data[0]];
30     if (!IsValid(tfLiteContext, tfLiteInputTensor, operatorCode, nodeIndex))
31     {
32         return kTfLiteError;
33     }
34 
35     const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
36     if (!IsValid(tfLiteContext, tfLiteOutputTensor, operatorCode, nodeIndex))
37     {
38         return kTfLiteError;
39     }
40 
41     const armnn::TensorInfo& inputTensorInfo  = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
42     const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor, true);
43 
44     auto* shapeParameters = reinterpret_cast<TfLiteShapeParams*>(tfLiteNode->builtin_data);
45     if ( shapeParameters->out_type != kTfLiteInt32 && shapeParameters->out_type != kTfLiteInt64 )
46     {
47         TF_LITE_MAYBE_KERNEL_LOG(
48             tfLiteContext,
49             "TfLiteArmnnDelegate: output_type data type is not supported in operator #%d node #%d: ",
50             operatorCode, nodeIndex);
51         return kTfLiteError;
52     }
53 
54     bool isSupported = false;
55     armnn::BackendId setBackend;
56     auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
57     {
58         FORWARD_LAYER_SUPPORT_FUNC("SHAPE",
59                                    tfLiteContext,
60                                    IsShapeSupported,
61                                    delegateData.m_Backends,
62                                    isSupported,
63                                    setBackend,
64                                    inputTensorInfo,
65                                    outInfo);
66     };
67 
68     // If the m_Network is a nullptr, this signals that a prerequisite TfLite callback is required to clarify the
69     // support for the operator
70     // If supported, VisitShapeOperator will be called again to add the layer to the network as seen further below
71     if (!delegateData.m_Network)
72     {
73         validateFunc(outputTensorInfo, isSupported);
74         return isSupported ? kTfLiteOk : kTfLiteError;
75     }
76 
77     // Add a Shape layer
78     armnn::IConnectableLayer* layer = delegateData.m_Network->AddShapeLayer();
79     layer->SetBackendId(setBackend);
80     ARMNN_ASSERT(layer != nullptr);
81 
82     armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
83     outputSlot.SetTensorInfo(outputTensorInfo);
84 
85     // try to connect the Constant Inputs if there are any
86     if(ProcessInputs(layer,delegateData, tfLiteContext, tfLiteNode) != kTfLiteOk )
87     {
88         return kTfLiteError;
89     }
90 
91     // Connect
92     return Connect(layer, tfLiteNode, delegateData);
93 }
94 
95 } // namespace armnnDelegate
96