• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/kernel.h"
16 
17 #include "flatbuffers/flexbuffers.h"  // from @flatbuffers
18 #include "tensorflow/core/common_runtime/eager/context.h"
19 #include "tensorflow/core/common_runtime/eager/execute.h"
20 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/lite/builtin_ops.h"
25 #include "tensorflow/lite/c/common.h"
26 #include "tensorflow/lite/context_util.h"
27 #include "tensorflow/lite/core/api/profiler.h"
28 #include "tensorflow/lite/delegates/flex/delegate.h"
29 #include "tensorflow/lite/delegates/flex/delegate_data.h"
30 #include "tensorflow/lite/delegates/flex/util.h"
31 #include "tensorflow/lite/kernels/kernel_util.h"
32 #include "tensorflow/lite/minimal_logging.h"
33 #include "tensorflow/lite/string_type.h"
34 
35 // Note: this is part of TF Lite's Flex delegation code which is to be
36 // completed soon.
37 
38 // This is the TF Lite op that is created by the flex delegate to handle
39 // execution of a supported subgraph. The usual flow is that the delegate
40 // informs the interpreter of supported nodes in a graph, and each supported
41 // subgraph is replaced with one instance of this kernel.
42 //
43 // The kernel is initialized with TfLiteDelegateParams from which we retrieve
44 // the global EagerContext and BufferMap, as well as a list of inputs and
45 // outputs to the subgraph. Those are used to build the OpData, with a list of
46 // TensorFlow Ops that should be executed in order (which we call an OpNode).
47 //
48 // For each node included in the subgraph, we query the interpreter and
49 // retrieve the associated NodeDef, which is then used to configure the
50 // corresponding TensorFlow/Eager Op.
51 
52 using tensorflow::shape_inference::DimensionHandle;
53 using tensorflow::shape_inference::InferenceContext;
54 using tensorflow::shape_inference::ShapeAndType;
55 using tensorflow::shape_inference::ShapeHandle;
56 
GetDimsDebugString(const TfLiteIntArray * dims)57 const std::string GetDimsDebugString(const TfLiteIntArray* dims) {
58   return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","),
59                       "]");
60 }
61 
62 namespace tflite {
63 namespace flex {
64 
struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node. A null 'node' means the tensor is not produced inside
// this Eager subgraph (it must then be fetched from the BufferMap).
//
// Members carry default initializers so a default-constructed TensorSource
// is a well-defined "no producer" value rather than indeterminate memory.
struct TensorSource {
  OpNode* node = nullptr;
  int node_output_index = 0;
};
73 
74 // A list of inputs of a given node of the TensorFlow/Eager graph.
75 class OpInputs {
76  public:
OpInputs(const TfLiteIntArray * indexes)77   explicit OpInputs(const TfLiteIntArray* indexes) {
78     for (int index : TfLiteIntArrayView(indexes)) {
79       inputs_.push_back(index);
80     }
81     forwardable_.resize(inputs_.size());
82   }
~OpInputs()83   ~OpInputs() {}
84 
Size() const85   int Size() const { return inputs_.size(); }
86 
TfLiteIndex(int i) const87   int TfLiteIndex(int i) const { return inputs_[i]; }
88 
89   // Given a map relating tensors to the node that originates them, populate a
90   // list of sources for the tensors in this class.
InitializeTensorSources(const std::map<int,TensorSource> & tflite_tensor_sources)91   void InitializeTensorSources(
92       const std::map<int, TensorSource>& tflite_tensor_sources) {
93     sources_.clear();
94     for (int i : inputs_) {
95       auto it = tflite_tensor_sources.find(i);
96       if (it == tflite_tensor_sources.end()) {
97         sources_.push_back({nullptr, 0});
98       } else {
99         sources_.push_back(it->second);
100       }
101     }
102   }
103 
SetForwardable(int i,bool v)104   void SetForwardable(int i, bool v) { forwardable_[i] = v; }
105 
IsForwardable(int i) const106   bool IsForwardable(int i) const { return forwardable_[i]; }
107 
GetTensorSource(int i) const108   TensorSource GetTensorSource(int i) const { return sources_[i]; }
109 
110  private:
111   std::vector<int> inputs_;
112   std::vector<TensorSource> sources_;
113 
114   // List of tensors that can be used by TF in its forwarding optimization.
115   // Doing so allows an input tensor to be modified and used as the output
116   // tensor. The delegate takes care of not holding any references to tensors
117   // in this list while Eager is executing the corresponding op.
118   std::vector<int> forwardable_;
119 };
120 
121 // A list of outputs of a given node of the TensorFlow/Eager graph, along with
122 // the actual outputs of the EagerOperation.
123 class OpOutputs {
124  public:
OpOutputs(const TfLiteIntArray * indexes)125   explicit OpOutputs(const TfLiteIntArray* indexes) {
126     for (int index : TfLiteIntArrayView(indexes)) {
127       outputs_.push_back(index);
128     }
129     vector_.resize(outputs_.size());
130   }
~OpOutputs()131   ~OpOutputs() { ResetTensorHandles(); }
132 
133   // Stores information about which of the tensors in this class are also
134   // outputs of the sugbraph.
InitializeGraphOutputs(const std::set<int> & subgraph_outputs)135   void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
136     subgraph_outputs_.clear();
137     for (int i : outputs_) {
138       subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
139     }
140   }
141 
142   // Returns true if the tensor given by index 'i' is an output of the entire
143   // subgraph.
IsSubgraphOutput(int i) const144   bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }
145 
146   // Returns a handle to a given tensor and, optionally, remove it from the
147   // internal vector.
GetHandle(int i,bool remove)148   tensorflow::TensorHandle* GetHandle(int i, bool remove) {
149     auto* handle = vector_[i];
150     if (!remove) {
151       handle->Ref();
152     } else {
153       // Don't increase the ref-count. Instead, simply take it out of the
154       // vector.
155       vector_[i] = nullptr;
156     }
157     return handle;
158   }
159 
Size() const160   int Size() const { return outputs_.size(); }
161 
TfLiteIndex(int i) const162   int TfLiteIndex(int i) const { return outputs_[i]; }
163 
164   // Carefully unreference all the handles in the eager output vector.
ResetTensorHandles()165   void ResetTensorHandles() {
166     for (int i = 0; i < vector_.size(); ++i) {
167       if (vector_[i]) {
168         vector_[i]->Unref();
169         vector_[i] = nullptr;
170       }
171     }
172   }
173 
174   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
GetTensorHandles()175   GetTensorHandles() {
176     return &vector_;
177   }
178 
179  private:
180   std::vector<int> outputs_;
181   std::vector<bool> subgraph_outputs_;
182   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
183 };
184 
185 // A single node within the larger 'op'. Note that this kernel executes many
186 // TensorFlow ops within a single TF Lite op.
187 class OpNode {
188  public:
OpNode(const TfLiteIntArray * inputs,const TfLiteIntArray * outputs)189   OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
190       : inputs_(inputs), outputs_(outputs) {}
~OpNode()191   ~OpNode() {
192     if (op_) ClearEagerInputs();
193   }
194 
name() const195   const string& name() const { return name_; }
set_name(const string & name)196   void set_name(const string& name) { name_ = name; }
197 
index() const198   int index() const { return index_; }
set_index(int index)199   void set_index(int index) { index_ = index; }
200 
nodedef() const201   const tensorflow::NodeDef& nodedef() const { return nodedef_; }
op_reg_data() const202   const tensorflow::OpRegistrationData* op_reg_data() const {
203     return op_reg_data_;
204   }
205 
inputs() const206   const OpInputs& inputs() const { return inputs_; }
mutable_inputs()207   OpInputs* mutable_inputs() { return &inputs_; }
208 
outputs() const209   const OpOutputs& outputs() const { return outputs_; }
mutable_outputs()210   OpOutputs* mutable_outputs() { return &outputs_; }
211 
NumInputs() const212   int NumInputs() const { return inputs_.Size(); }
NumOutputs() const213   int NumOutputs() const { return outputs_.Size(); }
214 
op()215   tensorflow::EagerOperation* op() { return op_.get(); }
216 
InitializeNodeDef(const void * custom_initial_data,int custom_initial_data_size)217   tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
218                                        int custom_initial_data_size) {
219     if (!custom_initial_data) {
220       return tensorflow::errors::Internal(
221           "Cannot convert empty data into a valid NodeDef");
222     }
223     // The flexbuffer contains a vector where the first elements is the
224     // op name and the second is a serialized NodeDef.
225     const flexbuffers::Vector& v =
226         flexbuffers::GetRoot(
227             reinterpret_cast<const uint8_t*>(custom_initial_data),
228             custom_initial_data_size)
229             .AsVector();
230 
231     name_ = v[0].AsString().str();
232     if (!nodedef_.ParseFromString(v[1].AsString().str())) {
233       nodedef_.Clear();
234       return tensorflow::errors::Internal(
235           "Failed to parse data into a valid NodeDef");
236     }
237 
238     // Fill NodeDef with defaults if it's a valid op.
239     TF_RETURN_IF_ERROR(
240         tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
241     AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);
242 
243     return tensorflow::Status::OK();
244   }
245 
246   // Build thew new EagerOperation. In case of error, the returned 'op' is
247   // guaranteed to be 'nullptr'.
BuildEagerOp(tensorflow::EagerContext * eager_context)248   tensorflow::Status BuildEagerOp(tensorflow::EagerContext* eager_context) {
249     op_.reset(new tensorflow::EagerOperation(eager_context));
250     TF_RETURN_IF_ERROR(op_->Reset(name_.c_str(), nullptr, false, nullptr));
251     if (op_->is_function()) {
252       op_.reset();
253       return tensorflow::errors::NotFound(
254           "Operation '", name_,
255           "' is not registered.  (while processing attributes of '", name_,
256           "')");
257     }
258 
259     op_->MutableAttrs()->NumInputs(inputs_.Size());
260     for (const auto& attr : nodedef_.attr()) {
261       op_->MutableAttrs()->Set(attr.first, attr.second);
262     }
263 
264     // Precalculating a cache key saves about 10% of inference time for very
265     // small models.
266     op_->MutableAttrs()->CacheKey(op_->DeviceName());
267 
268     return tensorflow::Status::OK();
269   }
270 
ClearEagerInputs()271   void ClearEagerInputs() { op_->Clear(); }
272 
BuildEagerInputs(const BufferMap * buffer_map)273   tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
274     absl::InlinedVector<tensorflow::TensorHandle*, 4>* op_inputs;
275     TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs));
276     for (int i = 0; i < inputs_.Size(); ++i) {
277       int input_index = inputs_.TfLiteIndex(i);
278       TensorSource s = inputs_.GetTensorSource(i);
279       if (!s.node) {
280         // This input is not produced by this Eager subgraph (it could be a TF
281         // Lite native buffer, or could be produced by a separater subgraph). We
282         // need to fetch it from the delegate's buffer_map.
283         if (!buffer_map->HasTensor(input_index)) {
284           return tensorflow::errors::Internal(
285               "Cannot read from invalid tensor index ", input_index);
286         }
287         tensorflow::TensorHandle* handle =
288             tensorflow::TensorHandle::CreateLocalHandle(
289                 buffer_map->GetTensor(input_index));
290         op_inputs->push_back(handle);
291       } else {
292         // If this is a forwardable tensor, we will remove it from the previous
293         // op's list, giving TF the opportunity to reuse its buffer.
294         bool unref_handle = inputs_.IsForwardable(i);
295         auto* handle =
296             s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
297         op_inputs->push_back(handle);
298       }
299     }
300     return tensorflow::Status::OK();
301   }
302 
PersistEagerOutputs(BufferMap * buffer_map)303   tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
304     auto* handles = outputs_.GetTensorHandles();
305     for (int i = 0; i < outputs_.Size(); ++i) {
306       if (outputs_.IsSubgraphOutput(i)) {
307         const tensorflow::Tensor* tensor = nullptr;
308         TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
309         buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
310       }
311     }
312     return tensorflow::Status::OK();
313   }
314 
315  private:
316   OpNode(const OpNode&) = delete;
317   OpNode& operator=(const OpNode&) = delete;
318 
319   // The name of the TensorFlow op to execute.
320   string name_;
321   // Index of this node into TF Lite's operator list.
322   int index_;
323   // The corresponding NodeDef, containing the attributes for the op.
324   tensorflow::NodeDef nodedef_;
325   // The corresponding OpRegistrationData pointer.
326   const tensorflow::OpRegistrationData* op_reg_data_;
327   // List of inputs, as TF Lite tensor indices.
328   OpInputs inputs_;
329   // List of outputs, as TF Lite tensor indices.
330   OpOutputs outputs_;
331 
332   std::unique_ptr<tensorflow::EagerOperation> op_;
333 };
334 
335 // Executes the TensorFlow op given by 'op_name', with the attributes specified
336 // in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
ExecuteFlexOp(TfLiteContext * context,BufferMap * buffer_map,OpNode * node_data)337 tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
338                                  OpNode* node_data) {
339   TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
340                                   " (while executing '", node_data->name(),
341                                   "' via Eager)");
342 
343   node_data->mutable_outputs()->ResetTensorHandles();
344   int num_retvals = node_data->NumOutputs();
345   TF_RETURN_WITH_CONTEXT_IF_ERROR(
346       node_data->op()->Execute(
347           absl::MakeSpan(
348               reinterpret_cast<tensorflow::AbstractTensorHandle**>(
349                   node_data->mutable_outputs()->GetTensorHandles()->data()),
350               num_retvals),
351           &num_retvals),
352       " (while executing '", node_data->name(), "' via Eager)");
353 
354   if (num_retvals != node_data->NumOutputs()) {
355     return tensorflow::errors::Internal(
356         "Unexpected number of outputs from EagerExecute");
357   }
358 
359   TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));
360 
361   node_data->ClearEagerInputs();
362 
363   return tensorflow::Status::OK();
364 }
365 
366 // The larger 'op', which contains all the nodes in a supported subgraph.
367 struct OpData {
368   tensorflow::EagerContext* eager_context;
369   BufferMap* buffer_map;
370   std::vector<std::unique_ptr<OpNode>> nodes;
371   std::vector<int> subgraph_inputs;
372   std::vector<int> subgraph_outputs;
373 };
374 
DelegateKernel()375 DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
~DelegateKernel()376 DelegateKernel::~DelegateKernel() {}
377 
Init(TfLiteContext * context,const TfLiteDelegateParams * params)378 TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
379                                   const TfLiteDelegateParams* params) {
380   auto* flex_delegate_data =
381       reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
382   op_data_->eager_context = flex_delegate_data->GetEagerContext();
383   op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);
384 
385   CHECK(params->output_tensors);
386   std::set<int> output_set;
387   for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
388     op_data_->subgraph_outputs.push_back(tensor_index);
389     output_set.insert(tensor_index);
390   }
391 
392   CHECK(params->input_tensors);
393   for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
394     op_data_->subgraph_inputs.push_back(tensor_index);
395   }
396 
397   op_data_->nodes.reserve(params->nodes_to_replace->size);
398 
399   CHECK(params->nodes_to_replace);
400   tensorflow::Status status;
401   for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
402     TfLiteNode* node;
403     TfLiteRegistration* reg;
404     context->GetNodeAndRegistration(context, node_index, &node, &reg);
405 
406     op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
407     OpNode& node_data = *op_data_->nodes.back();
408 
409     node_data.set_index(node_index);
410     node_data.set_name("");
411 
412     status = node_data.InitializeNodeDef(node->custom_initial_data,
413                                          node->custom_initial_data_size);
414     if (!status.ok()) break;
415     status = node_data.BuildEagerOp(op_data_->eager_context);
416     if (!status.ok()) break;
417   }
418 
419   TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));
420 
421   // Given a TfLite tensor index, return the OpNode that produces it,
422   // along with it index into that OpNodes list of outputs.
423   std::map<int, TensorSource> tflite_tensor_sources;
424 
425   // Find out how each tensor is produced. This does not account for
426   // tensors that are not produce by eager ops.
427   for (auto& node_data : op_data_->nodes) {
428     node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
429     for (int i = 0; i < node_data->outputs().Size(); ++i) {
430       int output_index = node_data->outputs().TfLiteIndex(i);
431       tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
432     }
433   }
434 
435   // For each node, resolve the inputs, so we can keep pointers to the nodes
436   // that produces them.
437   for (auto& node_data : op_data_->nodes) {
438     node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
439   }
440   return kTfLiteOk;
441 }
442 
Prepare(TfLiteContext * context,TfLiteNode * node)443 TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
444   TF_LITE_ENSURE_MSG(
445       context, op_data_->eager_context != nullptr,
446       "Failed to initialize eager context. This often happens when a CPU "
447       "device has not been registered, presumably because some symbols from "
448       "tensorflow/core:core_cpu_impl were not linked into the binary.");
449 
450   // We will keep track of the number of references to each tensor in the
451   // graph, so we can make them "forwardable" if there is only one reference.
452   std::map<int, int> tensor_ref_count;
453 
454   // Whenever we find a constant tensor, insert it in the buffer map.
455   BufferMap* buffer_map = op_data_->buffer_map;
456   for (auto tensor_index : op_data_->subgraph_inputs) {
457     TfLiteTensor* tensor = &context->tensors[tensor_index];
458     if (IsConstantTensor(tensor)) {
459       if (!buffer_map->HasTensor(tensor_index)) {
460         buffer_map->SetFromTfLite(tensor_index, tensor);
461       }
462     }
463 
464     // Input tensors should never be forwarded so we increment their ref counts
465     // twice: once for this graph and another for the possibility of them being
466     // used by another subgraph, or being an output of the full graph.
467     tensor_ref_count[tensor_index] += 2;
468   }
469 
470   const bool shapes_are_valid =
471       (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
472   if (shapes_are_valid) {
473     TFLITE_LOG(tflite::TFLITE_LOG_INFO,
474                "FlexDelegate: All tensor shapes are consistent.");
475   } else {
476     TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
477                "FlexDelegate: Some tensor shapes are inconsistent.");
478   }
479 
480   // All output tensors are allocated by TensorFlow/Eager, so we
481   // mark them as kTfLiteDynamic.
482   for (auto tensor_index : op_data_->subgraph_outputs) {
483     if (!shapes_are_valid) {
484       SetTensorToDynamic(&context->tensors[tensor_index]);
485     }
486     ++tensor_ref_count[tensor_index];
487   }
488 
489   for (const auto& node_data : op_data_->nodes) {
490     if (node_data->nodedef().op().empty()) {
491       context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
492                            node_data->name().c_str());
493       return kTfLiteError;
494     }
495     TF_LITE_ENSURE(context, node_data->op());
496 
497     for (int i = 0; i < node_data->inputs().Size(); ++i) {
498       ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
499     }
500   }
501 
502   // All tensors that are referenced exactly once are marked as "forwardable",
503   // meaning that we will allow TensorFlow to reuse its buffer as the output of
504   // an op.
505   for (auto& node_data : op_data_->nodes) {
506     for (int i = 0; i < node_data->inputs().Size(); ++i) {
507       bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
508       node_data->mutable_inputs()->SetForwardable(i, f);
509     }
510   }
511 
512   return kTfLiteOk;
513 }
514 
ValidateOutputTensorShapeConsistency(TfLiteContext * context) const515 TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
516     TfLiteContext* context) const {
517   for (const auto& node_data : op_data_->nodes) {
518     auto op_name = node_data->name().c_str();
519     // Create an InferenceContext object.
520     auto num_inputs = node_data->inputs().Size();
521     std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
522                                                                 nullptr);
523     InferenceContext c(
524         TF_GRAPH_DEF_VERSION, node_data->nodedef(),
525         node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
526         input_tensors_vector, {},
527         std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
528 
529     // Set input_shapes for ShapeInferenceFn.
530     for (int i = 0; i < num_inputs; ++i) {
531       const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
532       TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
533       // Provide constant input tensors since some op ("RFFT") needs it to
534       // calculate the output shape.
535       if (IsConstantTensor(tfl_tensor)) {
536         input_tensors_vector[i] =
537             op_data_->buffer_map->GetTensorPtr(input_tensor_index);
538       }
539       const auto dims_array = tfl_tensor->dims;
540       std::vector<DimensionHandle> dims(dims_array->size);
541       for (int j = 0; j < dims_array->size; ++j) {
542         dims[j] = c.MakeDim(dims_array->data[j]);
543       }
544       c.SetInput(i, c.MakeShape(dims));
545     }
546     c.set_input_tensors(input_tensors_vector);
547 
548     tensorflow::Status status = c.construction_status();
549     if (!status.ok()) {
550       TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
551                  "Shape construction failed for op '%s'", op_name);
552       return kTfLiteError;
553     }
554 
555     // Run ShapeInferenceFn to calculate output shapes.
556     if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
557       TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
558                  "No shape inference function exists for op '%s'", op_name);
559       return kTfLiteError;
560     }
561     status = c.Run(node_data->op_reg_data()->shape_inference_fn);
562 
563     // Compare calculated output shapes with node_data->outputs
564     auto num_outputs = node_data->outputs().Size();
565     if (num_outputs != c.num_outputs()) {
566       TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
567                  "Number of output tensors are mismatched for op '%s' %d != %d",
568                  op_name, num_outputs, c.num_outputs());
569       return kTfLiteError;
570     }
571     for (int i = 0; i < num_outputs; ++i) {
572       const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
573       TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
574       // tfl_tensor->dims only has valid information if the given model is
575       // converted by the MLIR converter. Also when ResizeInputTensor() is
576       // called the dims information becomes invalid.
577       const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims);
578       const std::string calculated_shape_string = c.DebugString(c.output(i));
579       // Getting a shape string via c.DebugString() is the easiest way to get
580       // the shape information of the given ShapeHandle for now.
581       // TODO(b/169017408): Find a better approach without using debug string.
582       if (tfl_shape_string != calculated_shape_string) {
583         TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
584                    "op '%s' output%d tensor#%d shape mismatch for  %s != %s",
585                    op_name, i, output_tensor_index, tfl_shape_string.c_str(),
586                    calculated_shape_string.c_str());
587         return kTfLiteError;
588       }
589     }
590   }
591   return kTfLiteOk;
592 }
593 
Eval(TfLiteContext * context,TfLiteNode * node)594 TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
595   BufferMap* buffer_map = op_data_->buffer_map;
596 
597   // Insert a tensor in the buffer map for all inputs that are not constant.
598   // Constants were handled in Prepare() already.
599   for (auto tensor_index : op_data_->subgraph_inputs) {
600     TfLiteTensor* tensor = &context->tensors[tensor_index];
601     if (!IsConstantTensor(tensor)) {
602       // If this tensor is part of an earlier TF subgraph we should not add it
603       // to the BufferMap again, because TF already knows about it and its
604       // contents are kept automatically up-to-date.
605       if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
606         buffer_map->SetFromTfLite(tensor_index, tensor);
607       }
608     }
609   }
610 
611   // Execute the TensorFlow Ops sequentially.
612   for (auto& node_data : op_data_->nodes) {
613     TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
614         reinterpret_cast<Profiler*>(context->profiler),
615         node_data->name().c_str(), node_data->index());
616 
617     auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
618     TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
619   }
620 
621   for (auto tensor_index : op_data_->subgraph_outputs) {
622     if (!buffer_map->HasTensor(tensor_index)) {
623       context->ReportError(context, "Cannot write to invalid tensor index %d",
624                            tensor_index);
625       return kTfLiteError;
626     }
627 
628     // Copy TF tensor data to TFL allocated buffer for non dynamic tensors.
629     // For dynamic tensors, copy shape and put buffer_handle for the later
630     // CopyFromBufferHandle() call.
631     TfLiteTensor* tensor = &context->tensors[tensor_index];
632     const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
633     if (tensor->allocation_type == kTfLiteDynamic) {
634       TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
635       tensor->buffer_handle = tensor_index;
636       tensor->data_is_stale = true;
637       continue;
638     }
639     // If the tensor isn't dynamic, we can copy data directly to the buffer of
640     // the tensor. Before copying the data, check if the target buffer has
641     // expected size.
642     if (tf_tensor.NumElements() != NumElements(tensor) ||
643         tf_tensor.TotalBytes() != tensor->bytes) {
644       TF_LITE_KERNEL_LOG(
645           context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %ld(%ld)",
646           tensor->name, tensor_index, tf_tensor.TotalBytes(),
647           tf_tensor.NumElements(), tensor->bytes, NumElements(tensor));
648       return kTfLiteError;
649     }
650     tensorflow::StringPiece t_data = tf_tensor.tensor_data();
651     memcpy(tensor->data.raw, t_data.data(), t_data.size());
652   }
653 
654   return kTfLiteOk;
655 }
656 
657 }  // namespace flex
658 }  // namespace tflite
659