/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"

#include <cstring>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/string_type.h"

// Note: this is part of TF Lite's Flex delegate code, which is still under
// active development.

// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow Ops that should be executed in order (each of which we call an
// OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow/Eager Op.
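//
// As a concrete illustration (hypothetical tensor/node numbering): if the
// delegate claims a two-node chain
//   tfl_tensor_0 -> TF_Op_A -> tfl_tensor_1 -> TF_Op_B -> tfl_tensor_2
// this kernel builds two OpNodes and, during Eval(), runs A and then B
// through the Eager runtime, bridging tfl_tensor_0 and tfl_tensor_2 to TF
// Lite via the delegate's BufferMap.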

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;

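// Renders a TfLiteIntArray of dims as a debug string; for example, dims
// {2, 3} render as "[2,3]".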
const std::string GetDimsDebugString(const TfLiteIntArray* dims) {
  return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","),
                      "]");
}

namespace tflite {
namespace flex {

struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node.
struct TensorSource {
  OpNode* node;
  int node_output_index;
};

// A list of inputs of a given node of the TensorFlow/Eager graph.
class OpInputs {
 public:
  explicit OpInputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      inputs_.push_back(index);
    }
    forwardable_.resize(inputs_.size());
  }
  ~OpInputs() {}

  int Size() const { return inputs_.size(); }

  int TfLiteIndex(int i) const { return inputs_[i]; }

  // Given a map relating tensors to the node that originates them, populate a
  // list of sources for the tensors in this class.
  void InitializeTensorSources(
      const std::map<int, TensorSource>& tflite_tensor_sources) {
    sources_.clear();
    for (int i : inputs_) {
      auto it = tflite_tensor_sources.find(i);
      if (it == tflite_tensor_sources.end()) {
        sources_.push_back({nullptr, 0});
      } else {
        sources_.push_back(it->second);
      }
    }
  }

  void SetForwardable(int i, bool v) { forwardable_[i] = v; }

  bool IsForwardable(int i) const { return forwardable_[i]; }

  TensorSource GetTensorSource(int i) const { return sources_[i]; }

 private:
  std::vector<int> inputs_;
  std::vector<TensorSource> sources_;

  // List of tensors that can be used by TF in its forwarding optimization.
  // Doing so allows an input tensor to be modified and used as the output
  // tensor. The delegate takes care of not holding any references to tensors
  // in this list while Eager is executing the corresponding op.
  std::vector<int> forwardable_;
};

// A list of outputs of a given node of the TensorFlow/Eager graph, along with
// the actual outputs of the EagerOperation.
class OpOutputs {
 public:
  explicit OpOutputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      outputs_.push_back(index);
    }
    vector_.resize(outputs_.size());
  }
  ~OpOutputs() { ResetTensorHandles(); }

  // Stores information about which of the tensors in this class are also
  // outputs of the subgraph.
  void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    subgraph_outputs_.clear();
    for (int i : outputs_) {
      subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    }
  }

  // Returns true if the tensor given by index 'i' is an output of the entire
  // subgraph.
  bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }

  // Returns a handle to a given tensor and, optionally, removes it from the
  // internal vector.
  tensorflow::TensorHandle* GetHandle(int i, bool remove) {
    auto* handle = vector_[i];
    if (!remove) {
      handle->Ref();
    } else {
      // Don't increase the ref-count. Instead, simply take it out of the
      // vector.
      vector_[i] = nullptr;
    }
    return handle;
  }

  int Size() const { return outputs_.size(); }

  int TfLiteIndex(int i) const { return outputs_[i]; }

  // Carefully unreference all the handles in the eager output vector.
  void ResetTensorHandles() {
    for (int i = 0; i < vector_.size(); ++i) {
      if (vector_[i]) {
        vector_[i]->Unref();
        vector_[i] = nullptr;
      }
    }
  }

  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
  GetTensorHandles() {
    return &vector_;
  }

 private:
  std::vector<int> outputs_;
  std::vector<bool> subgraph_outputs_;
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
};

// A single node within the larger 'op'. Note that this kernel executes many
// TensorFlow ops within a single TF Lite op.
class OpNode {
 public:
  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
      : inputs_(inputs), outputs_(outputs) {}
  ~OpNode() {
    if (op_) ClearEagerInputs();
  }

  const string& name() const { return name_; }
  void set_name(const string& name) { name_ = name; }

  int index() const { return index_; }
  void set_index(int index) { index_ = index; }

  const tensorflow::NodeDef& nodedef() const { return nodedef_; }
  const tensorflow::OpRegistrationData* op_reg_data() const {
    return op_reg_data_;
  }

  const OpInputs& inputs() const { return inputs_; }
  OpInputs* mutable_inputs() { return &inputs_; }

  const OpOutputs& outputs() const { return outputs_; }
  OpOutputs* mutable_outputs() { return &outputs_; }

  int NumInputs() const { return inputs_.Size(); }
  int NumOutputs() const { return outputs_.Size(); }

  tensorflow::EagerOperation* op() { return op_.get(); }

  tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                       int custom_initial_data_size) {
    if (!custom_initial_data) {
      return tensorflow::errors::Internal(
          "Cannot convert empty data into a valid NodeDef");
    }
    // The flexbuffer contains a vector where the first element is the
    // op name and the second is a serialized NodeDef.
    const flexbuffers::Vector& v =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(custom_initial_data),
            custom_initial_data_size)
            .AsVector();

    name_ = v[0].AsString().str();
    if (!nodedef_.ParseFromString(v[1].AsString().str())) {
      nodedef_.Clear();
      return tensorflow::errors::Internal(
          "Failed to parse data into a valid NodeDef");
    }

    // Fill NodeDef with defaults if it's a valid op.
    TF_RETURN_IF_ERROR(
        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
    AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);

    return tensorflow::Status::OK();
  }
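
  // As an illustration only (not part of this kernel): the payload parsed
  // above could be assembled on the producer side roughly like this, where
  // 'node_def' is a hypothetical tensorflow::NodeDef for the wrapped op:
  //
  //   flexbuffers::Builder fbb;
  //   fbb.Vector([&]() {
  //     fbb.String(node_def.op());                 // e.g. "AddV2"
  //     fbb.String(node_def.SerializeAsString());  // serialized NodeDef
  //   });
  //   fbb.Finish();
  //   // fbb.GetBuffer() is what arrives here as 'custom_initial_data'.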

  // Builds the new EagerOperation. In case of error, the returned 'op' is
  // guaranteed to be 'nullptr'.
  tensorflow::Status BuildEagerOp(
      tensorflow::EagerContext* eager_context,
      tensorflow::CancellationManager* cancellation_manager) {
    op_.reset(new tensorflow::EagerOperation(eager_context));
    TF_RETURN_IF_ERROR(op_->Reset(name_.c_str(), nullptr, false, nullptr));
    if (op_->is_function()) {
      op_.reset();
      return tensorflow::errors::NotFound(
          "Operation '", name_,
          "' is not registered.  (while processing attributes of '", name_,
          "')");
    }

    op_->MutableAttrs()->NumInputs(inputs_.Size());
    for (const auto& attr : nodedef_.attr()) {
      op_->MutableAttrs()->Set(attr.first, attr.second);
    }

    // Precalculating a cache key saves about 10% of inference time for very
    // small models.
    op_->MutableAttrs()->CacheKey(op_->DeviceName());

    op_->SetCancellationManager(cancellation_manager);

    return tensorflow::Status::OK();
  }

  void ClearEagerInputs() { op_->Clear(); }

  tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
    absl::InlinedVector<tensorflow::TensorHandle*, 4>* op_inputs;
    TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs));
    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
      if (!s.node) {
        // This input is not produced by this Eager subgraph (it could be a TF
        // Lite native buffer, or could be produced by a separate subgraph). We
        // need to fetch it from the delegate's buffer_map.
        if (!buffer_map->HasTensor(input_index)) {
          return tensorflow::errors::Internal(
              "Cannot read from invalid tensor index ", input_index);
        }
        tensorflow::TensorHandle* handle =
            tensorflow::TensorHandle::CreateLocalHandle(
                buffer_map->GetTensor(input_index));
        op_inputs->push_back(handle);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        bool unref_handle = inputs_.IsForwardable(i);
        auto* handle =
            s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
        op_inputs->push_back(handle);
      }
    }
    return tensorflow::Status::OK();
  }

  tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
    auto* handles = outputs_.GetTensorHandles();
    for (int i = 0; i < outputs_.Size(); ++i) {
      if (outputs_.IsSubgraphOutput(i)) {
        const tensorflow::Tensor* tensor = nullptr;
        TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
        buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  OpNode(const OpNode&) = delete;
  OpNode& operator=(const OpNode&) = delete;

  // The name of the TensorFlow op to execute.
  string name_;
  // Index of this node into TF Lite's operator list.
  int index_;
  // The corresponding NodeDef, containing the attributes for the op.
  tensorflow::NodeDef nodedef_;
  // The corresponding OpRegistrationData pointer.
  const tensorflow::OpRegistrationData* op_reg_data_;
  // List of inputs, as TF Lite tensor indices.
  OpInputs inputs_;
  // List of outputs, as TF Lite tensor indices.
  OpOutputs outputs_;

  std::unique_ptr<tensorflow::EagerOperation> op_;
};

// Executes the TensorFlow op represented by 'node_data'. Inputs are fetched
// from, and subgraph outputs persisted to, the 'buffer_map'.
tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
                                 OpNode* node_data) {
  TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
                                  " (while executing '", node_data->name(),
                                  "' via Eager)");

  node_data->mutable_outputs()->ResetTensorHandles();
  int num_retvals = node_data->NumOutputs();
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      node_data->op()->Execute(
          absl::MakeSpan(
              reinterpret_cast<tensorflow::AbstractTensorHandle**>(
                  node_data->mutable_outputs()->GetTensorHandles()->data()),
              num_retvals),
          &num_retvals),
      " (while executing '", node_data->name(), "' via Eager)");

  if (num_retvals != node_data->NumOutputs()) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from EagerExecute");
  }

  TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));

  node_data->ClearEagerInputs();

  return tensorflow::Status::OK();
}

// The larger 'op', which contains all the nodes in a supported subgraph.
struct OpData {
  tensorflow::EagerContext* eager_context;
  tensorflow::CancellationManager* cancellation_manager;
  BufferMap* buffer_map;
  std::vector<std::unique_ptr<OpNode>> nodes;
  std::vector<int> subgraph_inputs;
  std::vector<int> subgraph_outputs;
};

DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
DelegateKernel::~DelegateKernel() {}

TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
                                  const TfLiteDelegateParams* params) {
  auto* flex_delegate_data =
      reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
  op_data_->eager_context = flex_delegate_data->GetEagerContext();
  op_data_->cancellation_manager = flex_delegate_data->GetCancellationManager();
  op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);

  CHECK(params->output_tensors);
  std::set<int> output_set;
  for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    op_data_->subgraph_outputs.push_back(tensor_index);
    output_set.insert(tensor_index);
  }

  CHECK(params->input_tensors);
  for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    op_data_->subgraph_inputs.push_back(tensor_index);
  }

  CHECK(params->nodes_to_replace);
  op_data_->nodes.reserve(params->nodes_to_replace->size);

  tensorflow::Status status;
  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    TfLiteNode* node;
    TfLiteRegistration* reg;
    context->GetNodeAndRegistration(context, node_index, &node, &reg);

    op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    OpNode& node_data = *op_data_->nodes.back();

    node_data.set_index(node_index);
    node_data.set_name("");

    status = node_data.InitializeNodeDef(node->custom_initial_data,
                                         node->custom_initial_data_size);
    if (!status.ok()) break;
    status = node_data.BuildEagerOp(op_data_->eager_context,
                                    op_data_->cancellation_manager);
    if (!status.ok()) break;
  }

  TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));

  // Maps a TfLite tensor index to the OpNode that produces it, along with
  // its index into that OpNode's list of outputs.
  std::map<int, TensorSource> tflite_tensor_sources;

  // Find out how each tensor is produced. This does not account for
  // tensors that are not produced by eager ops.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    for (int i = 0; i < node_data->outputs().Size(); ++i) {
      int output_index = node_data->outputs().TfLiteIndex(i);
      tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    }
  }

  // For each node, resolve the inputs, so we can keep pointers to the nodes
  // that produce them.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_MSG(
      context, op_data_->eager_context != nullptr,
      "Failed to initialize eager context. This often happens when a CPU "
      "device has not been registered, presumably because some symbols from "
      "tensorflow/core:core_cpu_impl were not linked into the binary.");

  // We will keep track of the number of references to each tensor in the
  // graph, so we can make them "forwardable" if there is only one reference.
  std::map<int, int> tensor_ref_count;

  // Whenever we find a constant tensor, insert it in the buffer map.
  BufferMap* buffer_map = op_data_->buffer_map;
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (IsConstantTensor(tensor)) {
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }

    // Input tensors should never be forwarded, so we increment their ref
    // counts twice: once for this subgraph and once more for the possibility
    // of them being used by another subgraph, or being an output of the full
    // graph.
    tensor_ref_count[tensor_index] += 2;
  }

  const bool shapes_are_valid =
      (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
  if (shapes_are_valid) {
    TFLITE_LOG(tflite::TFLITE_LOG_INFO,
               "FlexDelegate: All tensor shapes are consistent.");
  } else {
    TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
               "FlexDelegate: Some tensor shapes are inconsistent.");
  }

  // All output tensors are allocated by TensorFlow/Eager, so we
  // mark them as kTfLiteDynamic.
  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!shapes_are_valid) {
      SetTensorToDynamic(&context->tensors[tensor_index]);
    }
    ++tensor_ref_count[tensor_index];
  }

  for (const auto& node_data : op_data_->nodes) {
    if (node_data->nodedef().op().empty()) {
      context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
                           node_data->name().c_str());
      return kTfLiteError;
    }
    TF_LITE_ENSURE(context, node_data->op());

    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    }
  }

  // All tensors that are referenced exactly once are marked as "forwardable",
  // meaning that we will allow TensorFlow to reuse its buffer as the output of
  // an op.
  for (auto& node_data : op_data_->nodes) {
    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
      node_data->mutable_inputs()->SetForwardable(i, f);
    }
  }
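
  // For example, an intermediate tensor consumed by exactly one node ends up
  // with a ref count of 1 and may have its buffer reused in place by TF; a
  // subgraph input (counted twice above) is never forwarded.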

  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
    TfLiteContext* context) const {
  for (const auto& node_data : op_data_->nodes) {
    auto op_name = node_data->name().c_str();
    // Create an InferenceContext object.
    auto num_inputs = node_data->inputs().Size();
    std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
                                                                nullptr);
    InferenceContext c(
        TF_GRAPH_DEF_VERSION, node_data->nodedef(),
        node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
        input_tensors_vector, {},
        std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

    // Set input_shapes for ShapeInferenceFn.
    for (int i = 0; i < num_inputs; ++i) {
      const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
      // Provide constant input tensors since some ops (e.g. "RFFT") need them
      // to calculate the output shape.
      if (IsConstantTensor(tfl_tensor)) {
        input_tensors_vector[i] =
            op_data_->buffer_map->GetTensorPtr(input_tensor_index);
      }
      const auto dims_array = tfl_tensor->dims;
      std::vector<DimensionHandle> dims(dims_array->size);
      for (int j = 0; j < dims_array->size; ++j) {
        dims[j] = c.MakeDim(dims_array->data[j]);
      }
      c.SetInput(i, c.MakeShape(dims));
    }
    c.set_input_tensors(input_tensors_vector);

    tensorflow::Status status = c.construction_status();
    if (!status.ok()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Shape construction failed for op '%s'", op_name);
      return kTfLiteError;
    }

    // Run ShapeInferenceFn to calculate output shapes.
    if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "No shape inference function exists for op '%s'", op_name);
      return kTfLiteError;
    }
    status = c.Run(node_data->op_reg_data()->shape_inference_fn);

    // Compare the calculated output shapes with node_data->outputs.
    auto num_outputs = node_data->outputs().Size();
    if (num_outputs != c.num_outputs()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Number of output tensors is mismatched for op '%s': %d != %d",
                 op_name, num_outputs, c.num_outputs());
      return kTfLiteError;
    }
    for (int i = 0; i < num_outputs; ++i) {
      const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
      // tfl_tensor->dims only has valid information if the given model is
      // converted by the MLIR converter. Also, when ResizeInputTensor() is
      // called, the dims information becomes invalid.
      const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims);
      const std::string calculated_shape_string = c.DebugString(c.output(i));
      // Getting a shape string via c.DebugString() is the easiest way to get
      // the shape information of the given ShapeHandle for now.
      // TODO(b/169017408): Find a better approach without using debug string.
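      // For example, a TFL tensor with dims {1, 224, 224, 3} matches only if
      // shape inference also yields "[1,224,224,3]"; a dimension that
      // inference cannot resolve would render as "?" and be flagged below.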
      if (tfl_shape_string != calculated_shape_string) {
        TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                   "op '%s' output%d tensor#%d shape mismatch for %s != %s",
                   op_name, i, output_tensor_index, tfl_shape_string.c_str(),
                   calculated_shape_string.c_str());
        return kTfLiteError;
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
  BufferMap* buffer_map = op_data_->buffer_map;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }
  }

  // Execute the TensorFlow Ops sequentially.
  for (auto& node_data : op_data_->nodes) {
    TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
        reinterpret_cast<Profiler*>(context->profiler),
        node_data->name().c_str(), node_data->index());

    if (op_data_->cancellation_manager != nullptr &&
        op_data_->cancellation_manager->IsCancelled()) {
      TF_LITE_KERNEL_LOG(context, "Client requested cancel during Invoke()");
      return kTfLiteError;
    }

    auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
    TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
  }

  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!buffer_map->HasTensor(tensor_index)) {
      context->ReportError(context, "Cannot write to invalid tensor index %d",
                           tensor_index);
      return kTfLiteError;
    }

    // Copy TF tensor data to the TFL-allocated buffer for non-dynamic tensors.
    // For dynamic tensors, copy the shape and set buffer_handle for the later
    // CopyFromBufferHandle() call.
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
    if (tensor->allocation_type == kTfLiteDynamic) {
      TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
      tensor->buffer_handle = tensor_index;
      tensor->data_is_stale = true;
      continue;
    }
    // If the tensor isn't dynamic, we can copy data directly to the buffer of
    // the tensor. Before copying the data, check that the target buffer has
    // the expected size.
    if (tf_tensor.NumElements() != NumElements(tensor) ||
        tf_tensor.TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(
          context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %ld(%ld)",
          tensor->name, tensor_index, tf_tensor.TotalBytes(),
          tf_tensor.NumElements(), tensor->bytes, NumElements(tensor));
      return kTfLiteError;
    }
    tensorflow::StringPiece t_data = tf_tensor.tensor_data();
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
  }

  return kTfLiteOk;
}

}  // namespace flex
}  // namespace tflite