//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#pragma once

#include <armnn/utility/IgnoreUnused.hpp>

#include "DelegateUtils.hpp"

#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
#include <tensorflow/lite/c/common.h>
#include <tensorflow/lite/minimal_logging.h>
#include <numeric>

namespace armnnDelegate
{

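// Sets reshapeDesc.m_TargetShape from the requested target shape, resolving at most
// one -1 (stretch) dimension from the input tensor's total element count.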
TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
                                     const std::vector<int32_t>& targetShape,
                                     armnn::ReshapeDescriptor& reshapeDesc)
{
    std::vector<unsigned int> outputDims(targetShape.begin(), targetShape.end());
    const auto stretchDim = std::find(targetShape.begin(), targetShape.end(), -1);

    if (stretchDim != targetShape.end())
    {
        if (std::find(std::next(stretchDim), targetShape.end(), -1) != targetShape.end())
        {
            // More than one -1 (stretch) dimension is not allowed; return kTfLiteError
            // and let the caller log the error.
            return kTfLiteError;
        }

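        // Multiply the target dimensions together; seeding std::accumulate with -1
        // cancels out the single -1 entry, leaving the product of the known dimensions.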
        auto targetNumElements =
            armnn::numeric_cast<unsigned int>(
                std::accumulate(targetShape.begin(), targetShape.end(), -1, std::multiplies<int32_t>()));

        auto stretchIndex = static_cast<size_t>(std::distance(targetShape.begin(), stretchDim));
        outputDims[stretchIndex] = inputTensorInfo.GetNumElements() / targetNumElements;
    }

    armnn::TensorShape outputShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()),
                                                        outputDims.data());
    reshapeDesc.m_TargetShape = outputShape;
    return kTfLiteOk;
}

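// Validates a TfLite RESHAPE node and, when a network is being built, adds the
// corresponding Arm NN reshape layer and connects it.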
TfLiteStatus VisitReshapeOperator(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex,
                                  int32_t operatorCode)
{
    auto numInputs = tfLiteNode->inputs->size;

    if (numInputs == 2)
    {
        TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 2, nodeIndex));
    }
    else
    {
        TF_LITE_ENSURE_STATUS(ValidateNumInputs(tfLiteContext, tfLiteNode, 1, nodeIndex));
    }
    TF_LITE_ENSURE_STATUS(ValidateNumOutputs(tfLiteContext, tfLiteNode, 1, nodeIndex));

    const TfLiteTensor* tfLiteTensors = tfLiteContext->tensors;
    const TfLiteTensor& tfLiteInputTensor0 = tfLiteTensors[tfLiteNode->inputs->data[0]];
    if (IsDynamicTensor(tfLiteInputTensor0))
    {
        TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
                                 "TfLiteArmnnDelegate: Dynamic input tensors are not supported in "
                                 "operator #%d node #%d: ",
                                 operatorCode, nodeIndex);
        return kTfLiteError;
    }

    const TfLiteTensor& tfLiteOutputTensor = tfLiteTensors[tfLiteNode->outputs->data[0]];
    if (IsDynamicTensor(tfLiteOutputTensor))
    {
        TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
                                 "TfLiteArmnnDelegate: Dynamic output tensors are not supported in "
                                 "operator #%d node #%d: ",
                                 operatorCode, nodeIndex);
        return kTfLiteError;
    }

    const armnn::TensorInfo& inputTensorInfo0 = GetTensorInfoForTfLiteTensor(tfLiteInputTensor0);
    const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);

    armnn::ReshapeDescriptor reshapeDesc;
    std::vector<int32_t> targetShape;

    // The new shape can be defined either by a second input tensor or by a builtin option, so check for both.
    if (numInputs == 2)
    {
        // Get shape from the second input tensor
        const TfLiteTensor& tfLiteShapeInputTensor = tfLiteTensors[tfLiteNode->inputs->data[1]];
        if (IsDynamicTensor(tfLiteShapeInputTensor))
        {
            TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
                                     "TfLiteArmnnDelegate: Dynamic input tensors are not supported in "
                                     "operator #%d node #%d: ",
                                     operatorCode, nodeIndex);
            return kTfLiteError;
        }

        if (tfLiteShapeInputTensor.dims->size != 1)
        {
            TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
                                     "TfLiteArmnnDelegate: Target 'shape' input is not a 1D tensor in "
                                     "operator #%d node #%d: ",
                                     operatorCode, nodeIndex);
            return kTfLiteError;
        }

        // Get the shape data out of the input tensor
        auto* shapeTensorDataPtr = tflite::GetTensorData<int32_t>(&tfLiteShapeInputTensor);
        auto shapeTensorNumValues = tfLiteShapeInputTensor.dims->data[0];
        for (auto i = 0; i < shapeTensorNumValues; ++i)
        {
            targetShape.push_back(*(shapeTensorDataPtr + i));
        }
    }
    else
    {
        // Get shape from the builtin data
        TfLiteReshapeParams* reshapeOptions = reinterpret_cast<TfLiteReshapeParams*>(tfLiteNode->builtin_data);

        if (reshapeOptions != nullptr)
        {
            // Options might be set without valid data, so check that the number of dimensions is in a valid range.
            if (reshapeOptions->num_dimensions > 0 && reshapeOptions->num_dimensions <= 8)
            {
                for (int i = 0; i < reshapeOptions->num_dimensions; ++i)
                {
                    targetShape.push_back(reshapeOptions->shape[i]);
                }
            }
        }
        else
        {
            TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
                                     "Target shape not defined in reshape parameters or input tensor. "
                                     "At least one method required in operator #%d node #%d: ",
                                     operatorCode, nodeIndex);
            return kTfLiteError;
        }
    }

    // Use the data to create the required tensor shape.
    if (CreateOutputTensorShape(inputTensorInfo0, targetShape, reshapeDesc) != kTfLiteOk)
    {
        TF_LITE_MAYBE_KERNEL_LOG(tfLiteContext,
                                 "TfLiteArmnnDelegate: At most one component of shape can be -1 in: "
                                 "operator #%d node #%d: ",
                                 operatorCode, nodeIndex);
        return kTfLiteError;
    }

    if (reshapeDesc.m_TargetShape.GetNumElements() != inputTensorInfo0.GetNumElements())
    {
        TF_LITE_MAYBE_KERNEL_LOG(
            tfLiteContext,
            "TfLiteArmnnDelegate: Reshape, number of elements in output shape does not match input "
            "operator #%d node #%d: ",
            operatorCode, nodeIndex);
        return kTfLiteError;
    }

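    // Check whether the configured backends support this reshape configuration.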
    bool isSupported = false;
    auto validateFunc = [&](const armnn::TensorInfo& outInfo, bool& isSupported)
    {
        FORWARD_LAYER_SUPPORT_FUNC(__func__,
                                   tfLiteContext,
                                   IsReshapeSupported,
                                   delegateData.m_Backends,
                                   isSupported,
                                   inputTensorInfo0,
                                   outInfo,
                                   reshapeDesc);
    };

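    // With no network being constructed, we are only checking whether the node is supported.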
    if (!delegateData.m_Network)
    {
        validateFunc(outputTensorInfo, isSupported);
        return isSupported ? kTfLiteOk : kTfLiteError;
    }

    armnn::IConnectableLayer* layer = delegateData.m_Network->AddReshapeLayer(reshapeDesc);
    ARMNN_ASSERT(layer != nullptr);

    armnn::IOutputSlot& outputSlot = layer->GetOutputSlot(0);
    outputSlot.SetTensorInfo(outputTensorInfo);

    // Connect
    return Connect(layer, tfLiteNode, delegateData);
}

TfLiteStatus VisitSqueezeOperator(DelegateData& delegateData,
                                  TfLiteContext* tfLiteContext,
                                  TfLiteNode* tfLiteNode,
                                  int nodeIndex,
                                  int32_t operatorCode)
{
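    // Squeeze is not yet implemented in this delegate, so the node is always rejected.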
    armnn::IgnoreUnused(delegateData,
                        tfLiteContext,
                        tfLiteNode,
                        nodeIndex,
                        operatorCode);

    return kTfLiteError;
}

TfLiteStatus VisitExpandDimsOperator(DelegateData& delegateData,
                                     TfLiteContext* tfLiteContext,
                                     TfLiteNode* tfLiteNode,
                                     int nodeIndex,
                                     int32_t operatorCode)
{
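    // ExpandDims is not yet implemented in this delegate, so the node is always rejected.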
    armnn::IgnoreUnused(delegateData,
                        tfLiteContext,
                        tfLiteNode,
                        nodeIndex,
                        operatorCode);

    return kTfLiteError;
}

} // namespace armnnDelegate