/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"

#include "flatbuffers/flexbuffers.h"  // TF:flatbuffers
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/profiling/profiler.h"
#include "tensorflow/lite/string.h"

// Note: this is part of TF Lite's Flex delegation code, which is still
// being completed.

// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams, from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow ops that should be executed in order (each wrapped in what we
// call an OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow/Eager op.
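//
// For example, if the delegate reports that nodes {4, 5, 6} of a graph are
// supported, the interpreter replaces them with a single instance of this
// kernel; Init() then builds one OpNode per replaced node, and Eval() runs
// them sequentially through the Eager runtime.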

namespace tflite {
namespace flex {
namespace kernel {

struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node.
struct TensorSource {
  OpNode* node;
  int node_output_index;
};

// A list of inputs of a given node of the TensorFlow/Eager graph.
class OpInputs {
 public:
  explicit OpInputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      inputs_.push_back(index);
    }
    forwardable_.resize(inputs_.size());
  }
  ~OpInputs() {}

  int Size() const { return inputs_.size(); }

  int TfLiteIndex(int i) const { return inputs_[i]; }

  // Given a map relating tensors to the node that originates them, populate a
  // list of sources for the tensors in this class.
  void InitializeTensorSources(
      const std::map<int, TensorSource>& tflite_tensor_sources) {
    sources_.clear();
    for (int i : inputs_) {
      auto it = tflite_tensor_sources.find(i);
      if (it == tflite_tensor_sources.end()) {
        sources_.push_back({nullptr, 0});
      } else {
        sources_.push_back(it->second);
      }
    }
  }

  void SetForwardable(int i, bool v) { forwardable_[i] = v; }

  bool IsForwardable(int i) const { return forwardable_[i]; }

  TensorSource GetTensorSource(int i) const { return sources_[i]; }

 private:
  std::vector<int> inputs_;
  std::vector<TensorSource> sources_;

  // List of tensors that can be used by TF in its forwarding optimization.
  // Doing so allows an input tensor to be modified and used as the output
  // tensor. The delegate takes care of not holding any references to tensors
  // in this list while Eager is executing the corresponding op.
  std::vector<int> forwardable_;
};

// A list of outputs of a given node of the TensorFlow/Eager graph, along with
// the actual outputs of the EagerOperation.
class OpOutputs {
 public:
  explicit OpOutputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      outputs_.push_back(index);
    }
    vector_.resize(outputs_.size());
  }
  ~OpOutputs() { ResetTensorHandles(); }

  // Stores information about which of the tensors in this class are also
  // outputs of the subgraph.
  void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    subgraph_outputs_.clear();
    for (int i : outputs_) {
      subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    }
  }

  // Returns true if the tensor given by index 'i' is an output of the entire
  // subgraph.
  bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }

  // Returns a handle to a given tensor and, optionally, removes it from the
  // internal vector.
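  // Either way the caller ends up owning one reference: with 'remove' set,
  // the reference previously held by the vector is transferred to the
  // caller; otherwise a new reference is added via Ref().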
  tensorflow::TensorHandle* GetHandle(int i, bool remove) {
    auto* handle = vector_[i];
    if (!remove) {
      handle->Ref();
    } else {
      // Don't increase the ref-count. Instead, simply take it out of the
      // vector.
      vector_[i] = nullptr;
    }
    return handle;
  }

  int Size() const { return outputs_.size(); }

  int TfLiteIndex(int i) const { return outputs_[i]; }

  // Carefully unreference all the handles in the eager output vector.
  void ResetTensorHandles() {
    for (int i = 0; i < vector_.size(); ++i) {
      if (vector_[i]) {
        vector_[i]->Unref();
        vector_[i] = nullptr;
      }
    }
  }

  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
  GetTensorHandles() {
    return &vector_;
  }

 private:
  std::vector<int> outputs_;
  std::vector<bool> subgraph_outputs_;
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
};

// A single node within the larger 'op'. Note that this kernel executes many
// TensorFlow ops within a single TF Lite op.
class OpNode {
 public:
  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
      : inputs_(inputs), outputs_(outputs) {}
  ~OpNode() {
    if (op_) ClearEagerInputs();
  }

  const string& name() const { return name_; }
  void set_name(const string& name) { name_ = name; }

  int index() const { return index_; }
  void set_index(int index) { index_ = index; }

  const tensorflow::NodeDef& nodedef() const { return nodedef_; }

  const OpInputs& inputs() const { return inputs_; }
  OpInputs* mutable_inputs() { return &inputs_; }

  const OpOutputs& outputs() const { return outputs_; }
  OpOutputs* mutable_outputs() { return &outputs_; }

  int NumInputs() const { return inputs_.Size(); }
  int NumOutputs() const { return outputs_.Size(); }

  tensorflow::EagerOperation* op() { return op_.get(); }

  tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                       int custom_initial_data_size) {
    if (!custom_initial_data) {
      return tensorflow::errors::Internal(
          "Cannot convert empty data into a valid NodeDef");
    }
    // The flexbuffer contains a vector where the first element is the
    // op name and the second is a serialized NodeDef.
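    //
    // A minimal sketch of how such a buffer could be produced with the
    // flexbuffers API (the actual serialization happens on the converter
    // side; 'node_def' here is an illustrative NodeDef):
    //
    //   flexbuffers::Builder fbb;
    //   fbb.Vector([&]() {
    //     fbb.String(node_def.op());
    //     fbb.String(node_def.SerializeAsString());
    //   });
    //   fbb.Finish();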
    const flexbuffers::Vector& v =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(custom_initial_data),
            custom_initial_data_size)
            .AsVector();

    name_ = v[0].AsString().str();
    if (!nodedef_.ParseFromString(v[1].AsString().str())) {
      nodedef_.Clear();
      return tensorflow::errors::Internal(
          "Failed to parse data into a valid NodeDef");
    }

    // Fill NodeDef with defaults if it's a valid op.
    const tensorflow::OpRegistrationData* op_reg_data;
    TF_RETURN_IF_ERROR(
        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data));
    AddDefaultsToNodeDef(op_reg_data->op_def, &nodedef_);

    return tensorflow::Status::OK();
  }

  // Builds the new EagerOperation. In case of error, the returned 'op' is
  // guaranteed to be 'nullptr'.
  tensorflow::Status BuildEagerOp(tensorflow::EagerContext* eager_context) {
    op_.reset();

    const tensorflow::AttrTypeMap* attr_types;
    bool is_function = false;
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        tensorflow::AttrTypeMapForOp(name_.c_str(), &attr_types, &is_function),
        " (while processing attributes of '", name_, "')");
    if (is_function) {
      return tensorflow::errors::NotFound(
          "Operation '", name_,
          "' is not registered. (while processing attributes of '", name_,
          "')");
    }

    op_.reset(new tensorflow::EagerOperation(eager_context, name_.c_str(),
                                             /*is_function=*/false,
                                             attr_types));

    op_->MutableAttrs()->NumInputs(inputs_.Size());
    for (const auto& attr : nodedef_.attr()) {
      op_->MutableAttrs()->Set(attr.first, attr.second);
    }

    // Precalculating a cache key saves about 10% of inference time for very
    // small models.
    tensorflow::Device* device = op_->Device();
    op_->MutableAttrs()->CacheKey(device == nullptr ? "unspecified"
                                                    : device->name());

    return tensorflow::Status::OK();
  }

  void ClearEagerInputs() {
    for (tensorflow::TensorHandle* h : *op_->MutableInputs()) {
      if (h) h->Unref();
    }
    op_->MutableInputs()->clear();
  }

  tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
      if (!s.node) {
        // This input is not produced by this Eager subgraph (it could be a TF
        // Lite native buffer, or could be produced by a separate subgraph). We
        // need to fetch it from the delegate's buffer_map.
        if (!buffer_map->HasTensor(input_index)) {
          return tensorflow::errors::Internal(
              "Cannot read from invalid tensor index ", input_index);
        }
        auto* handle = new tensorflow::TensorHandle(
            buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
        op_->MutableInputs()->push_back(handle);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        bool unref_handle = inputs_.IsForwardable(i);
        auto* handle =
            s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
        op_->MutableInputs()->push_back(handle);
      }
    }
    return tensorflow::Status::OK();
  }

  tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
    auto* handles = outputs_.GetTensorHandles();
    for (int i = 0; i < outputs_.Size(); ++i) {
      if (outputs_.IsSubgraphOutput(i)) {
        const tensorflow::Tensor* tensor = nullptr;
        TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
        buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  OpNode(const OpNode&) = delete;
  OpNode& operator=(const OpNode&) = delete;

  // The name of the TensorFlow op to execute.
  string name_;
  // Index of this node into TF Lite's operator list.
  int index_;
  // The corresponding NodeDef, containing the attributes for the op.
  tensorflow::NodeDef nodedef_;
  // List of inputs, as TF Lite tensor indices.
  OpInputs inputs_;
  // List of outputs, as TF Lite tensor indices.
  OpOutputs outputs_;

  std::unique_ptr<tensorflow::EagerOperation> op_;
};

// Executes the TensorFlow op described by 'node_data', reading its inputs
// from (and persisting subgraph outputs to) the 'buffer_map'.
tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
                                 OpNode* node_data) {
  TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
                                  " (while executing '", node_data->name(),
                                  "' via Eager)");

  node_data->mutable_outputs()->ResetTensorHandles();
  int num_retvals = node_data->NumOutputs();
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      EagerExecute(node_data->op(),
                   node_data->mutable_outputs()->GetTensorHandles(),
                   &num_retvals),
      " (while executing '", node_data->name(), "' via Eager)");

  if (num_retvals != node_data->NumOutputs()) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from EagerExecute");
  }

  TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));

  node_data->ClearEagerInputs();

  return tensorflow::Status::OK();
}

// The larger 'op', which contains all the nodes in a supported subgraph.
struct OpData {
  tensorflow::EagerContext* eager_context;
  BufferMap* buffer_map;
  std::vector<std::unique_ptr<OpNode>> nodes;
  std::vector<int> subgraph_inputs;
  std::vector<int> subgraph_outputs;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData;

  const TfLiteDelegateParams* params =
      reinterpret_cast<const TfLiteDelegateParams*>(buffer);
  CHECK(params);
  CHECK(params->delegate);
  CHECK(params->delegate->data_);
  op_data->eager_context =
      reinterpret_cast<DelegateData*>(params->delegate->data_)
          ->GetEagerContext();
  op_data->buffer_map = reinterpret_cast<DelegateData*>(params->delegate->data_)
                            ->GetBufferMap(context);

  CHECK(params->output_tensors);
  std::set<int> output_set;
  for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    op_data->subgraph_outputs.push_back(tensor_index);
    output_set.insert(tensor_index);
  }

  CHECK(params->input_tensors);
  for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    op_data->subgraph_inputs.push_back(tensor_index);
  }

  // Check the node list before dereferencing it for the reserve() below.
  CHECK(params->nodes_to_replace);
  op_data->nodes.reserve(params->nodes_to_replace->size);

  tensorflow::Status status;
  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    TfLiteNode* node;
    TfLiteRegistration* reg;
    context->GetNodeAndRegistration(context, node_index, &node, &reg);

    op_data->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    OpNode& node_data = *op_data->nodes.back();

    node_data.set_index(node_index);
    node_data.set_name("");

    status = node_data.InitializeNodeDef(node->custom_initial_data,
                                         node->custom_initial_data_size);
    if (!status.ok()) break;
    status = node_data.BuildEagerOp(op_data->eager_context);
    if (!status.ok()) break;
  }

  if (ConvertStatus(context, status) != kTfLiteOk) {
    // We can't return an error from this function, but ConvertStatus will
    // report it and we will stop processing in Prepare() if anything went
    // wrong.
    return op_data;
  }

  // Given a TfLite tensor index, return the OpNode that produces it,
  // along with its index into that OpNode's list of outputs.
  std::map<int, TensorSource> tflite_tensor_sources;

  // Find out how each tensor is produced. This does not account for
  // tensors that are not produced by eager ops.
  for (auto& node_data : op_data->nodes) {
    node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    for (int i = 0; i < node_data->outputs().Size(); ++i) {
      int output_index = node_data->outputs().TfLiteIndex(i);
      tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    }
  }

  // For each node, resolve the inputs, so we can keep pointers to the nodes
  // that produce them.
  for (auto& node_data : op_data->nodes) {
    node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
  }

  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE_MSG(
      context, op_data->eager_context != nullptr,
      "Failed to initialize eager context. This often happens when a CPU "
      "device has not been registered, presumably because some symbols from "
      "tensorflow/core:core_cpu_impl were not linked into the binary.");

  // We will keep track of the number of references to each tensor in the
  // graph, so we can make them "forwardable" if there is only one reference.
  std::map<int, int> tensor_ref_count;

  // Whenever we find a constant tensor, insert it in the buffer map.
  BufferMap* buffer_map = op_data->buffer_map;
  for (auto tensor_index : op_data->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (IsConstantTensor(tensor)) {
      if (!buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }

    // Input tensors should never be forwarded so we increment their ref counts
    // twice: once for this graph and another for the possibility of them being
    // used by another subgraph, or being an output of the full graph.
    tensor_ref_count[tensor_index] += 2;
  }

  // All output tensors are allocated by TensorFlow/Eager, so we
  // mark them as kTfLiteDynamic.
  for (auto tensor_index : op_data->subgraph_outputs) {
    SetTensorToDynamic(&context->tensors[tensor_index]);
    ++tensor_ref_count[tensor_index];
  }

  for (const auto& node_data : op_data->nodes) {
    if (node_data->nodedef().op().empty()) {
      context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
                           node_data->name().c_str());
      return kTfLiteError;
    }
    TF_LITE_ENSURE(context, node_data->op());

    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    }
  }

  // All tensors that are referenced exactly once are marked as "forwardable",
  // meaning that we will allow TensorFlow to reuse its buffer as the output of
  // an op.
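  //
  // For example, an intermediate tensor consumed by exactly one downstream
  // eager op (and that is neither a subgraph input nor a subgraph output)
  // ends up with a ref count of 1 and may be forwarded; a subgraph input
  // always has a count of at least 2 and is never forwarded.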
  for (auto& node_data : op_data->nodes) {
    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
      node_data->mutable_inputs()->SetForwardable(i, f);
    }
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* op_data = reinterpret_cast<OpData*>(node->user_data);
  BufferMap* buffer_map = op_data->buffer_map;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }
  }

  // Execute the TensorFlow Ops sequentially.
  for (auto& node_data : op_data->nodes) {
    SCOPED_TAGGED_OPERATOR_PROFILE(
        reinterpret_cast<profiling::Profiler*>(context->profiler),
        node_data->name().c_str(), node_data->index());

    auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
    TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
  }

  for (auto tensor_index : op_data->subgraph_outputs) {
    if (!buffer_map->HasTensor(tensor_index)) {
      context->ReportError(context, "Cannot write to invalid tensor index %d",
                           tensor_index);
      return kTfLiteError;
    }

    TfLiteTensor* tensor = &context->tensors[tensor_index];
    TF_LITE_ENSURE_OK(
        context,
        CopyShapeAndType(context, buffer_map->GetTensor(tensor_index), tensor));
    tensor->buffer_handle = tensor_index;
    tensor->data_is_stale = true;
  }

  return kTfLiteOk;
}

}  // namespace kernel

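// Illustrative only: the Flex delegate registers the kernel returned here for
// each supported subgraph, typically with a call along the lines of
//
//   context->ReplaceNodeSubsetsWithDelegateKernels(
//       context, GetKernel(), supported_nodes, delegate);
//
// where 'supported_nodes' lists the node indices the kernel should absorb.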
TfLiteRegistration GetKernel() {
  TfLiteRegistration registration{&kernel::Init, &kernel::Free,
                                  &kernel::Prepare, &kernel::Eval,
                                  nullptr, kTfLiteBuiltinDelegate};
  return registration;
}

}  // namespace flex
}  // namespace tflite