• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/kernel.h"
16 
17 #include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
18 #include "tensorflow/core/common_runtime/eager/context.h"
19 #include "tensorflow/core/common_runtime/eager/execute.h"
20 #include "tensorflow/core/common_runtime/eager/tensor_handle.h"
21 #include "tensorflow/core/framework/node_def.pb.h"
22 #include "tensorflow/core/framework/node_def_util.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/lite/builtin_ops.h"
25 #include "tensorflow/lite/c/c_api_internal.h"
26 #include "tensorflow/lite/context_util.h"
27 #include "tensorflow/lite/delegates/flex/delegate_data.h"
28 #include "tensorflow/lite/delegates/flex/util.h"
29 #include "tensorflow/lite/kernels/kernel_util.h"
30 #include "tensorflow/lite/profiling/profiler.h"
31 #include "tensorflow/lite/string.h"
32 
33 // Note: this is part of TF Lite's Flex delegation code which is to be
34 // completed soon.
35 
36 // This is the TF Lite op that is created by the flex delegate to handle
37 // execution of a supported subgraph. The usual flow is that the delegate
38 // informs the interpreter of supported nodes in a graph, and each supported
39 // subgraph is replaced with one instance of this kernel.
40 //
41 // The kernel is initialized with TfLiteDelegateParams from which we retrieve
42 // the global EagerContext and BufferMap, as well as a list of inputs and
43 // outputs to the subgraph. Those are used to build the OpData, with a list of
44 // TensorFlow Ops that should be executed in order (which we call an OpNode).
45 //
46 // For each node included in the subgraph, we query the interpreter and
47 // retrieve the associated NodeDef, which is then used to configure the
48 // corresponding TensorFlow/Eager Op.
49 
50 namespace tflite {
51 namespace flex {
52 namespace kernel {
53 
struct OpNode;

// Identifies where a TF Lite tensor comes from inside this subgraph: output
// number 'node_output_index' of the upstream 'node'. A null 'node' means the
// tensor is not produced by any node of this subgraph.
struct TensorSource {
  OpNode* node;           // Producing node, or nullptr if produced elsewhere.
  int node_output_index;  // Position within the producer's list of outputs.
};
62 
63 // A list of inputs of a given node of the TensorFlow/Eager graph.
64 class OpInputs {
65  public:
OpInputs(const TfLiteIntArray * indexes)66   explicit OpInputs(const TfLiteIntArray* indexes) {
67     for (int index : TfLiteIntArrayView(indexes)) {
68       inputs_.push_back(index);
69     }
70     forwardable_.resize(inputs_.size());
71   }
~OpInputs()72   ~OpInputs() {}
73 
Size() const74   int Size() const { return inputs_.size(); }
75 
TfLiteIndex(int i) const76   int TfLiteIndex(int i) const { return inputs_[i]; }
77 
78   // Given a map relating tensors to the node that originates them, populate a
79   // list of sources for the tensors in this class.
InitializeTensorSources(const std::map<int,TensorSource> & tflite_tensor_sources)80   void InitializeTensorSources(
81       const std::map<int, TensorSource>& tflite_tensor_sources) {
82     sources_.clear();
83     for (int i : inputs_) {
84       auto it = tflite_tensor_sources.find(i);
85       if (it == tflite_tensor_sources.end()) {
86         sources_.push_back({nullptr, 0});
87       } else {
88         sources_.push_back(it->second);
89       }
90     }
91   }
92 
SetForwardable(int i,bool v)93   void SetForwardable(int i, bool v) { forwardable_[i] = v; }
94 
IsForwardable(int i) const95   bool IsForwardable(int i) const { return forwardable_[i]; }
96 
GetTensorSource(int i) const97   TensorSource GetTensorSource(int i) const { return sources_[i]; }
98 
99  private:
100   std::vector<int> inputs_;
101   std::vector<TensorSource> sources_;
102 
103   // List of tensors that can be used by TF in its forwarding optimization.
104   // Doing so allows an input tensor to be modified and used as the output
105   // tensor. The delegate takes care of not holding any references to tensors
106   // in this list while Eager is executing the corresponding op.
107   std::vector<int> forwardable_;
108 };
109 
110 // A list of outputs of a given node of the TensorFlow/Eager graph, along with
111 // the actual outputs of the EagerOperation.
112 class OpOutputs {
113  public:
OpOutputs(const TfLiteIntArray * indexes)114   explicit OpOutputs(const TfLiteIntArray* indexes) {
115     for (int index : TfLiteIntArrayView(indexes)) {
116       outputs_.push_back(index);
117     }
118     vector_.resize(outputs_.size());
119   }
~OpOutputs()120   ~OpOutputs() { ResetTensorHandles(); }
121 
122   // Stores information about which of the tensors in this class are also
123   // outputs of the sugbraph.
InitializeGraphOutputs(const std::set<int> & subgraph_outputs)124   void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
125     subgraph_outputs_.clear();
126     for (int i : outputs_) {
127       subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
128     }
129   }
130 
131   // Returns true if the tensor given by index 'i' is an output of the entire
132   // subgraph.
IsSubgraphOutput(int i) const133   bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }
134 
135   // Returns a handle to a given tensor and, optionally, remove it from the
136   // internal vector.
GetHandle(int i,bool remove)137   tensorflow::TensorHandle* GetHandle(int i, bool remove) {
138     auto* handle = vector_[i];
139     if (!remove) {
140       handle->Ref();
141     } else {
142       // Don't increase the ref-count. Instead, simply take it out of the
143       // vector.
144       vector_[i] = nullptr;
145     }
146     return handle;
147   }
148 
Size() const149   int Size() const { return outputs_.size(); }
150 
TfLiteIndex(int i) const151   int TfLiteIndex(int i) const { return outputs_[i]; }
152 
153   // Carefully unreference all the handles in the eager output vector.
ResetTensorHandles()154   void ResetTensorHandles() {
155     for (int i = 0; i < vector_.size(); ++i) {
156       if (vector_[i]) {
157         vector_[i]->Unref();
158         vector_[i] = nullptr;
159       }
160     }
161   }
162 
163   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
GetTensorHandles()164   GetTensorHandles() {
165     return &vector_;
166   }
167 
168  private:
169   std::vector<int> outputs_;
170   std::vector<bool> subgraph_outputs_;
171   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
172 };
173 
174 // A single node within the larger 'op'. Note that this kernel executes many
175 // TensorFlow ops within a single TF Lite op.
176 class OpNode {
177  public:
OpNode(const TfLiteIntArray * inputs,const TfLiteIntArray * outputs)178   OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
179       : inputs_(inputs), outputs_(outputs) {}
~OpNode()180   ~OpNode() {
181     if (op_) ClearEagerInputs();
182   }
183 
name() const184   const string& name() const { return name_; }
set_name(const string & name)185   void set_name(const string& name) { name_ = name; }
186 
index() const187   int index() const { return index_; }
set_index(int index)188   void set_index(int index) { index_ = index; }
189 
nodedef() const190   const tensorflow::NodeDef& nodedef() const { return nodedef_; }
191 
inputs() const192   const OpInputs& inputs() const { return inputs_; }
mutable_inputs()193   OpInputs* mutable_inputs() { return &inputs_; }
194 
outputs() const195   const OpOutputs& outputs() const { return outputs_; }
mutable_outputs()196   OpOutputs* mutable_outputs() { return &outputs_; }
197 
NumInputs() const198   int NumInputs() const { return inputs_.Size(); }
NumOutputs() const199   int NumOutputs() const { return outputs_.Size(); }
200 
op()201   tensorflow::EagerOperation* op() { return op_.get(); }
202 
InitializeNodeDef(const void * custom_initial_data,int custom_initial_data_size)203   tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
204                                        int custom_initial_data_size) {
205     if (!custom_initial_data) {
206       return tensorflow::errors::Internal(
207           "Cannot convert empty data into a valid NodeDef");
208     }
209     // The flexbuffer contains a vector where the first elements is the
210     // op name and the second is a serialized NodeDef.
211     const flexbuffers::Vector& v =
212         flexbuffers::GetRoot(
213             reinterpret_cast<const uint8_t*>(custom_initial_data),
214             custom_initial_data_size)
215             .AsVector();
216 
217     name_ = v[0].AsString().str();
218     if (!nodedef_.ParseFromString(v[1].AsString().str())) {
219       nodedef_.Clear();
220       return tensorflow::errors::Internal(
221           "Failed to parse data into a valid NodeDef");
222     }
223 
224     // Fill NodeDef with defaults if it's a valid op.
225     const tensorflow::OpRegistrationData* op_reg_data;
226     TF_RETURN_IF_ERROR(
227         tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data));
228     AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_);
229 
230     return tensorflow::Status::OK();
231   }
232 
233   // Build thew new EagerOperation. In case of error, the returned 'op' is
234   // guaranteed to be 'nullptr'.
BuildEagerOp(tensorflow::EagerContext * eager_context)235   tensorflow::Status BuildEagerOp(tensorflow::EagerContext* eager_context) {
236     op_.reset();
237 
238     const tensorflow::AttrTypeMap* attr_types;
239     bool is_function = false;
240     TF_RETURN_WITH_CONTEXT_IF_ERROR(
241         tensorflow::AttrTypeMapForOp(name_.c_str(), &attr_types, &is_function),
242         " (while processing attributes of '", name_, "')");
243     if (is_function) {
244       return tensorflow::errors::NotFound(
245           "Operation '", name_,
246           "' is not registered.  (while processing attributes of '", name_,
247           "')");
248     }
249 
250     op_.reset(new tensorflow::EagerOperation(eager_context, name_.c_str(),
251                                              /*is_function=*/false,
252                                              attr_types));
253 
254     op_->MutableAttrs()->NumInputs(inputs_.Size());
255     for (const auto& attr : nodedef_.attr()) {
256       op_->MutableAttrs()->Set(attr.first, attr.second);
257     }
258 
259     // Precalculating a cache key saves about 10% of inference time for very
260     // small models.
261     tensorflow::Device* device = op_->Device();
262     op_->MutableAttrs()->CacheKey(device == nullptr ? "unspecified"
263                                                     : device->name());
264 
265     return tensorflow::Status::OK();
266   }
267 
ClearEagerInputs()268   void ClearEagerInputs() {
269     for (tensorflow::TensorHandle* h : *op_->MutableInputs()) {
270       if (h) h->Unref();
271     }
272     op_->MutableInputs()->clear();
273   }
274 
BuildEagerInputs(const BufferMap * buffer_map)275   tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
276     for (int i = 0; i < inputs_.Size(); ++i) {
277       int input_index = inputs_.TfLiteIndex(i);
278       TensorSource s = inputs_.GetTensorSource(i);
279       if (!s.node) {
280         // This input is not produced by this Eager subgraph (it could be a TF
281         // Lite native buffer, or could be produced by a separater subgraph). We
282         // need to fetch it from the delegate's buffer_map.
283         if (!buffer_map->HasTensor(input_index)) {
284           return tensorflow::errors::Internal(
285               "Cannot read from invalid tensor index ", input_index);
286         }
287         auto* handle = new tensorflow::TensorHandle(
288             buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
289         op_->MutableInputs()->push_back(handle);
290       } else {
291         // If this is a forwardable tensor, we will remove it from the previous
292         // op's list, giving TF the opportunity to reuse its buffer.
293         bool unref_handle = inputs_.IsForwardable(i);
294         auto* handle =
295             s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
296         op_->MutableInputs()->push_back(handle);
297       }
298     }
299     return tensorflow::Status::OK();
300   }
301 
PersistEagerOutputs(BufferMap * buffer_map)302   tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
303     auto* handles = outputs_.GetTensorHandles();
304     for (int i = 0; i < outputs_.Size(); ++i) {
305       if (outputs_.IsSubgraphOutput(i)) {
306         const tensorflow::Tensor* tensor = nullptr;
307         TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
308         buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
309       }
310     }
311     return tensorflow::Status::OK();
312   }
313 
314  private:
315   OpNode(const OpNode&) = delete;
316   OpNode& operator=(const OpNode&) = delete;
317 
318   // The name of the TensorFlow op to execute.
319   string name_;
320   // Index of this node into TF Lite's operator list.
321   int index_;
322   // The corresponding NodeDef, containing the attributes for the op.
323   tensorflow::NodeDef nodedef_;
324   // List of inputs, as TF Lite tensor indices.
325   OpInputs inputs_;
326   // List of outputs, as TF Lite tensor indices.
327   OpOutputs outputs_;
328 
329   std::unique_ptr<tensorflow::EagerOperation> op_;
330 };
331 
332 // Executes the TensorFlow op given by 'op_name', with the attributes specified
333 // in 'nodedef'. Inputs and outputs are given as indices into the 'buffer_map'.
ExecuteFlexOp(TfLiteContext * context,BufferMap * buffer_map,OpNode * node_data)334 tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
335                                  OpNode* node_data) {
336   TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
337                                   " (while executing '", node_data->name(),
338                                   "' via Eager)");
339 
340   node_data->mutable_outputs()->ResetTensorHandles();
341   int num_retvals = node_data->NumOutputs();
342   TF_RETURN_WITH_CONTEXT_IF_ERROR(
343       EagerExecute(node_data->op(),
344                    node_data->mutable_outputs()->GetTensorHandles(),
345                    &num_retvals),
346       " (while executing '", node_data->name(), "' via Eager)");
347 
348   if (num_retvals != node_data->NumOutputs()) {
349     return tensorflow::errors::Internal(
350         "Unexpected number of outputs from EagerExecute");
351   }
352 
353   TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));
354 
355   node_data->ClearEagerInputs();
356 
357   return tensorflow::Status::OK();
358 }
359 
360 // The larger 'op', which contains all the nodes in a supported subgraph.
361 struct OpData {
362   tensorflow::EagerContext* eager_context;
363   BufferMap* buffer_map;
364   std::vector<std::unique_ptr<OpNode>> nodes;
365   std::vector<int> subgraph_inputs;
366   std::vector<int> subgraph_outputs;
367 };
368 
Init(TfLiteContext * context,const char * buffer,size_t length)369 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
370   auto* op_data = new OpData;
371 
372   const TfLiteDelegateParams* params =
373       reinterpret_cast<const TfLiteDelegateParams*>(buffer);
374   CHECK(params);
375   CHECK(params->delegate);
376   CHECK(params->delegate->data_);
377   op_data->eager_context =
378       reinterpret_cast<DelegateData*>(params->delegate->data_)
379           ->GetEagerContext();
380   op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
381                             ->GetBufferMap(context);
382 
383   CHECK(params->output_tensors);
384   std::set<int> output_set;
385   for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
386     op_data->subgraph_outputs.push_back(tensor_index);
387     output_set.insert(tensor_index);
388   }
389 
390   CHECK(params->input_tensors);
391   for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
392     op_data->subgraph_inputs.push_back(tensor_index);
393   }
394 
395   op_data->nodes.reserve(params->nodes_to_replace->size);
396 
397   CHECK(params->nodes_to_replace);
398   tensorflow::Status status;
399   for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
400     TfLiteNode* node;
401     TfLiteRegistration* reg;
402     context->GetNodeAndRegistration(context, node_index, &node, &reg);
403 
404     op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
405     OpNode& node_data = *op_data->nodes.back();
406 
407     node_data.set_index(node_index);
408     node_data.set_name("");
409 
410     status = node_data.InitializeNodeDef(node->custom_initial_data,
411                                          node->custom_initial_data_size);
412     if (!status.ok()) break;
413     status = node_data.BuildEagerOp(op_data->eager_context);
414     if (!status.ok()) break;
415   }
416 
417   if (ConvertStatus(context, status) != kTfLiteOk) {
418     // We can't return an error from this function but ConvertStatus will
419     // report them and we will stop processing in Prepare() if anything went
420     // wrong.
421     return op_data;
422   }
423 
424   // Given a TfLite tensor index, return the OpNode that produces it,
425   // along with it index into that OpNodes list of outputs.
426   std::map<int, TensorSource> tflite_tensor_sources;
427 
428   // Find out how each tensor is produced. This does not account for
429   // tensors that are not produce by eager ops.
430   for (auto& node_data : op_data->nodes) {
431     node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
432     for (int i = 0; i < node_data->outputs().Size(); ++i) {
433       int output_index = node_data->outputs().TfLiteIndex(i);
434       tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
435     }
436   }
437 
438   // For each node, resolve the inputs, so we can keep pointers to the nodes
439   // that produces them.
440   for (auto& node_data : op_data->nodes) {
441     node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
442   }
443 
444   return op_data;
445 }
446 
Free(TfLiteContext * context,void * buffer)447 void Free(TfLiteContext* context, void* buffer) {
448   delete reinterpret_cast<OpData*>(buffer);
449 }
450 
Prepare(TfLiteContext * context,TfLiteNode * node)451 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
452   const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
453   TF_LITE_ENSURE_MSG(
454       context, op_data->eager_context != nullptr,
455       "Failed to initialize eager context. This often happens when a CPU "
456       "device has not been registered, presumably because some symbols from "
457       "tensorflow/core:core_cpu_impl were not linked into the binary.");
458 
459   // We will keep track of the number of references to each tensor in the
460   // graph, so we can make them "forwardable" if there is only one reference.
461   std::map<int, int> tensor_ref_count;
462 
463   // Whenever we find a constant tensor, insert it in the buffer map.
464   BufferMap* buffer_map = op_data->buffer_map;
465   for (auto tensor_index : op_data->subgraph_inputs) {
466     TfLiteTensor* tensor = &context->tensors[tensor_index];
467     if (IsConstantTensor(tensor)) {
468       if (!buffer_map->HasTensor(tensor_index)) {
469         buffer_map->SetFromTfLite(tensor_index, tensor);
470       }
471     }
472 
473     // Input tensors should never be forwarded so we increment their ref counts
474     // twice: once for this graph and another for the possibility of them being
475     // used by another subgraph, or being an output of the full graph.
476     tensor_ref_count[tensor_index] += 2;
477   }
478 
479   // All output tensors are allocated by TensorFlow/Eager, so we
480   // mark them as kTfLiteDynamic.
481   for (auto tensor_index : op_data->subgraph_outputs) {
482     SetTensorToDynamic(&context->tensors[tensor_index]);
483     ++tensor_ref_count[tensor_index];
484   }
485 
486   for (const auto& node_data : op_data->nodes) {
487     if (node_data->nodedef().op().empty()) {
488       context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
489                            node_data->name().c_str());
490       return kTfLiteError;
491     }
492     TF_LITE_ENSURE(context, node_data->op());
493 
494     for (int i = 0; i < node_data->inputs().Size(); ++i) {
495       ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
496     }
497   }
498 
499   // All tensors that are referenced exactly once are marked as "forwardable",
500   // meaning that we will allow TensorFlow to reuse its buffer as the output of
501   // an op.
502   for (auto& node_data : op_data->nodes) {
503     for (int i = 0; i < node_data->inputs().Size(); ++i) {
504       bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
505       node_data->mutable_inputs()->SetForwardable(i, f);
506     }
507   }
508 
509   return kTfLiteOk;
510 }
511 
Eval(TfLiteContext * context,TfLiteNode * node)512 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
513   auto* op_data = reinterpret_cast<OpData*>(node->user_data);
514   BufferMap* buffer_map = op_data->buffer_map;
515 
516   // Insert a tensor in the buffer map for all inputs that are not constant.
517   // Constants were handled in Prepare() already.
518   for (auto tensor_index : op_data->subgraph_inputs) {
519     TfLiteTensor* tensor = &context->tensors[tensor_index];
520     if (!IsConstantTensor(tensor)) {
521       // If this tensor is part of an earlier TF subgraph we should not add it
522       // to the BufferMap again, because TF already knows about it and its
523       // contents are kept automatically up-to-date.
524       if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
525         buffer_map->SetFromTfLite(tensor_index, tensor);
526       }
527     }
528   }
529 
530   // Execute the TensorFlow Ops sequentially.
531   for (auto& node_data : op_data->nodes) {
532     SCOPED_TAGGED_OPERATOR_PROFILE(
533         reinterpret_cast<profiling::Profiler*>(context->profiler),
534         node_data->name().c_str(), node_data->index());
535 
536     auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
537     TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
538   }
539 
540   for (auto tensor_index : op_data->subgraph_outputs) {
541     if (!buffer_map->HasTensor(tensor_index)) {
542       context->ReportError(context, "Cannot write to invalid tensor index %d",
543                            tensor_index);
544       return kTfLiteError;
545     }
546 
547     TfLiteTensor* tensor = &context->tensors[tensor_index];
548     TF_LITE_ENSURE_OK(
549         context,
550         CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
551     tensor->buffer_handle = tensor_index;
552     tensor->data_is_stale = true;
553   }
554 
555   return kTfLiteOk;
556 }
557 
558 }  // namespace kernel
559 
GetKernel()560 TfLiteRegistration GetKernel() {
561   TfLiteRegistration registration{&kernel::Init,    &kernel::Free,
562                                   &kernel::Prepare, &kernel::Eval,
563                                   nullptr,          kTfLiteBuiltinDelegate};
564   return registration;
565 }
566 
567 }  // namespace flex
568 }  // namespace tflite
569