/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"

#include <cstring>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/string_type.h"

// Note: this is part of TF Lite's Flex delegation code which is to be
// completed soon.

// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow Ops that should be executed in order (which we call an OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow/Eager Op.

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;

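// Renders tensor dims for logging, e.g. dims {2, 3} becomes "[2,3]".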
const std::string GetDimsDebugString(const TfLiteIntArray* dims) {
  return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","),
                      "]");
}

namespace tflite {
namespace flex {

struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node.
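// A null 'node' means the tensor is not produced by an Eager node inside this
// subgraph (see OpInputs::InitializeTensorSources below).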
struct TensorSource {
  OpNode* node;
  int node_output_index;
};

// A list of inputs of a given node of the TensorFlow/Eager graph.
class OpInputs {
 public:
  explicit OpInputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      inputs_.push_back(index);
    }
    forwardable_.resize(inputs_.size());
  }
  ~OpInputs() {}

  int Size() const { return inputs_.size(); }

  int TfLiteIndex(int i) const { return inputs_[i]; }

  // Given a map relating tensors to the node that originates them, populates a
  // list of sources for the tensors in this class.
  void InitializeTensorSources(
      const std::map<int, TensorSource>& tflite_tensor_sources) {
    sources_.clear();
    for (int i : inputs_) {
      auto it = tflite_tensor_sources.find(i);
      if (it == tflite_tensor_sources.end()) {
        sources_.push_back({nullptr, 0});
      } else {
        sources_.push_back(it->second);
      }
    }
  }

  void SetForwardable(int i, bool v) { forwardable_[i] = v; }

  bool IsForwardable(int i) const { return forwardable_[i]; }

  TensorSource GetTensorSource(int i) const { return sources_[i]; }

 private:
  std::vector<int> inputs_;
  std::vector<TensorSource> sources_;

  // List of tensors that can be used by TF in its forwarding optimization.
  // Doing so allows an input tensor to be modified and used as the output
  // tensor. The delegate takes care of not holding any references to tensors
  // in this list while Eager is executing the corresponding op.
  std::vector<int> forwardable_;
};

// A list of outputs of a given node of the TensorFlow/Eager graph, along with
// the actual outputs of the EagerOperation.
class OpOutputs {
 public:
  explicit OpOutputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      outputs_.push_back(index);
    }
    vector_.resize(outputs_.size());
  }
  ~OpOutputs() { ResetTensorHandles(); }

  // Stores information about which of the tensors in this class are also
  // outputs of the subgraph.
  void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    subgraph_outputs_.clear();
    for (int i : outputs_) {
      subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    }
  }

  // Returns true if the tensor given by index 'i' is an output of the entire
  // subgraph.
  bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }

  // Returns a handle to a given tensor and, optionally, removes it from the
  // internal vector.
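  // The caller always receives ownership of one reference: either a fresh
  // Ref() on the retained handle or, when 'remove' is true, the vector's own
  // reference.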
  tensorflow::TensorHandle* GetHandle(int i, bool remove) {
    auto* handle = vector_[i];
    if (!remove) {
      handle->Ref();
    } else {
      // Don't increase the ref-count. Instead, simply take it out of the
      // vector.
      vector_[i] = nullptr;
    }
    return handle;
  }

  int Size() const { return outputs_.size(); }

  int TfLiteIndex(int i) const { return outputs_[i]; }

  // Carefully unreference all the handles in the eager output vector.
  void ResetTensorHandles() {
    for (int i = 0; i < vector_.size(); ++i) {
      if (vector_[i]) {
        vector_[i]->Unref();
        vector_[i] = nullptr;
      }
    }
  }

  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
  GetTensorHandles() {
    return &vector_;
  }

 private:
  std::vector<int> outputs_;
  std::vector<bool> subgraph_outputs_;
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
};

// A single node within the larger 'op'. Note that this kernel executes many
// TensorFlow ops within a single TF Lite op.
class OpNode {
 public:
  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
      : inputs_(inputs), outputs_(outputs) {}
  ~OpNode() {
    if (op_) ClearEagerInputs();
  }

  const string& name() const { return name_; }
  void set_name(const string& name) { name_ = name; }

  int index() const { return index_; }
  void set_index(int index) { index_ = index; }

  const tensorflow::NodeDef& nodedef() const { return nodedef_; }
  const tensorflow::OpRegistrationData* op_reg_data() const {
    return op_reg_data_;
  }

  const OpInputs& inputs() const { return inputs_; }
  OpInputs* mutable_inputs() { return &inputs_; }

  const OpOutputs& outputs() const { return outputs_; }
  OpOutputs* mutable_outputs() { return &outputs_; }

  int NumInputs() const { return inputs_.Size(); }
  int NumOutputs() const { return outputs_.Size(); }

  tensorflow::EagerOperation* op() { return op_.get(); }

  tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                       int custom_initial_data_size) {
    if (!custom_initial_data) {
      return tensorflow::errors::Internal(
          "Cannot convert empty data into a valid NodeDef");
    }
    // The flexbuffer contains a vector where the first element is the
    // op name and the second is a serialized NodeDef.
    const flexbuffers::Vector& v =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(custom_initial_data),
            custom_initial_data_size)
            .AsVector();

    name_ = v[0].AsString().str();
    if (!nodedef_.ParseFromString(v[1].AsString().str())) {
      nodedef_.Clear();
      return tensorflow::errors::Internal(
          "Failed to parse data into a valid NodeDef");
    }

    // Fill NodeDef with defaults if it's a valid op.
    TF_RETURN_IF_ERROR(
        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
    AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);

    return tensorflow::Status::OK();
  }
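
  // For reference, the custom data parsed above could be produced with the
  // flexbuffers API roughly like this (a sketch, not the converter's exact
  // code; 'node_def' stands for an illustrative tensorflow::NodeDef):
  //
  //   flexbuffers::Builder fbb;
  //   fbb.Vector([&]() {
  //     fbb.String(node_def.op());
  //     fbb.String(node_def.SerializeAsString());
  //   });
  //   fbb.Finish();
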
  // Builds the new EagerOperation. In case of error, the returned 'op' is
  // guaranteed to be 'nullptr'.
  tensorflow::Status BuildEagerOp(
      tensorflow::EagerContext* eager_context,
      tensorflow::CancellationManager* cancellation_manager) {
    op_.reset(new tensorflow::EagerOperation(eager_context));
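    // Reset() resolves 'name_' against the op registry; the null device name
    // and executor leave device placement and executor selection to Eager,
    // and 'remote' is false because execution happens in-process.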
    TF_RETURN_IF_ERROR(op_->Reset(name_.c_str(), nullptr, false, nullptr));
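    // A name that does not resolve to a registered op is treated by Eager as
    // a function name; functions are not supported by this kernel.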
    if (op_->is_function()) {
      op_.reset();
      return tensorflow::errors::NotFound(
          "Operation '", name_,
          "' is not registered. (while processing attributes of '", name_,
          "')");
    }

    op_->MutableAttrs()->NumInputs(inputs_.Size());
    for (const auto& attr : nodedef_.attr()) {
      op_->MutableAttrs()->Set(attr.first, attr.second);
    }

    // Precalculating a cache key saves about 10% of inference time for very
    // small models.
    op_->MutableAttrs()->CacheKey(op_->DeviceName());

    op_->SetCancellationManager(cancellation_manager);

    return tensorflow::Status::OK();
  }

  void ClearEagerInputs() { op_->Clear(); }

  tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
    absl::InlinedVector<tensorflow::TensorHandle*, 4>* op_inputs;
    TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs));
    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
      if (!s.node) {
        // This input is not produced by this Eager subgraph (it could be a TF
        // Lite native buffer, or could be produced by a separate subgraph). We
        // need to fetch it from the delegate's buffer_map.
        if (!buffer_map->HasTensor(input_index)) {
          return tensorflow::errors::Internal(
              "Cannot read from invalid tensor index ", input_index);
        }
        tensorflow::TensorHandle* handle =
            tensorflow::TensorHandle::CreateLocalHandle(
                buffer_map->GetTensor(input_index));
        op_inputs->push_back(handle);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        bool unref_handle = inputs_.IsForwardable(i);
        auto* handle =
            s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
        op_inputs->push_back(handle);
      }
    }
    return tensorflow::Status::OK();
  }

  tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
    auto* handles = outputs_.GetTensorHandles();
    for (int i = 0; i < outputs_.Size(); ++i) {
      if (outputs_.IsSubgraphOutput(i)) {
        const tensorflow::Tensor* tensor = nullptr;
        TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
        buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  OpNode(const OpNode&) = delete;
  OpNode& operator=(const OpNode&) = delete;

  // The name of the TensorFlow op to execute.
  string name_;
  // Index of this node into TF Lite's operator list.
  int index_;
  // The corresponding NodeDef, containing the attributes for the op.
  tensorflow::NodeDef nodedef_;
  // The corresponding OpRegistrationData pointer.
  const tensorflow::OpRegistrationData* op_reg_data_;
  // List of inputs, as TF Lite tensor indices.
  OpInputs inputs_;
  // List of outputs, as TF Lite tensor indices.
  OpOutputs outputs_;

  std::unique_ptr<tensorflow::EagerOperation> op_;
};

// Executes the TensorFlow op given by 'node_data', reading inputs from and
// writing outputs to the 'buffer_map' as needed.
tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
                                 OpNode* node_data) {
  TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
                                  " (while executing '", node_data->name(),
                                  "' via Eager)");

  node_data->mutable_outputs()->ResetTensorHandles();
  int num_retvals = node_data->NumOutputs();
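  // Execute() treats 'num_retvals' as in/out: on return it holds the number
  // of outputs actually produced, which is verified below.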
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      node_data->op()->Execute(
          absl::MakeSpan(
              reinterpret_cast<tensorflow::AbstractTensorHandle**>(
                  node_data->mutable_outputs()->GetTensorHandles()->data()),
              num_retvals),
          &num_retvals),
      " (while executing '", node_data->name(), "' via Eager)");

  if (num_retvals != node_data->NumOutputs()) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from EagerExecute");
  }

  TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));

  node_data->ClearEagerInputs();

  return tensorflow::Status::OK();
}

// The larger 'op', which contains all the nodes in a supported subgraph.
struct OpData {
  tensorflow::EagerContext* eager_context;
  tensorflow::CancellationManager* cancellation_manager;
  BufferMap* buffer_map;
  std::vector<std::unique_ptr<OpNode>> nodes;
  std::vector<int> subgraph_inputs;
  std::vector<int> subgraph_outputs;
};

DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
DelegateKernel::~DelegateKernel() {}

TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
                                  const TfLiteDelegateParams* params) {
  auto* flex_delegate_data =
      reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
  op_data_->eager_context = flex_delegate_data->GetEagerContext();
  op_data_->cancellation_manager =
      flex_delegate_data->GetCancellationManager();
  op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);

  CHECK(params->output_tensors);
  std::set<int> output_set;
  for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    op_data_->subgraph_outputs.push_back(tensor_index);
    output_set.insert(tensor_index);
  }

  CHECK(params->input_tensors);
  for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    op_data_->subgraph_inputs.push_back(tensor_index);
  }

  CHECK(params->nodes_to_replace);
  op_data_->nodes.reserve(params->nodes_to_replace->size);

  tensorflow::Status status;
  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    TfLiteNode* node;
    TfLiteRegistration* reg;
    context->GetNodeAndRegistration(context, node_index, &node, &reg);

    op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    OpNode& node_data = *op_data_->nodes.back();

    node_data.set_index(node_index);
    node_data.set_name("");

    status = node_data.InitializeNodeDef(node->custom_initial_data,
                                         node->custom_initial_data_size);
    if (!status.ok()) break;
    status = node_data.BuildEagerOp(op_data_->eager_context,
                                    op_data_->cancellation_manager);
    if (!status.ok()) break;
  }

  TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));

  // Given a TfLite tensor index, return the OpNode that produces it,
  // along with its index into that OpNode's list of outputs.
  std::map<int, TensorSource> tflite_tensor_sources;

  // Find out how each tensor is produced. This does not account for
  // tensors that are not produced by eager ops.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    for (int i = 0; i < node_data->outputs().Size(); ++i) {
      int output_index = node_data->outputs().TfLiteIndex(i);
      tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    }
  }

  // For each node, resolve the inputs, so we can keep pointers to the nodes
  // that produce them.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_MSG(
      context, op_data_->eager_context != nullptr,
      "Failed to initialize eager context. This often happens when a CPU "
      "device has not been registered, presumably because some symbols from "
      "tensorflow/core:core_cpu_impl were not linked into the binary.");

  // We will keep track of the number of references to each tensor in the
  // graph, so we can make them "forwardable" if there is only one reference.
  std::map<int, int> tensor_ref_count;

  // Whenever we find a constant tensor, insert it in the buffer map.
  BufferMap* buffer_map = op_data_->buffer_map;
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (IsConstantTensor(tensor)) {
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }

    // Input tensors should never be forwarded so we increment their ref counts
    // twice: once for this graph and another for the possibility of them being
    // used by another subgraph, or being an output of the full graph.
    tensor_ref_count[tensor_index] += 2;
  }

  const bool shapes_are_valid =
      (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
  if (shapes_are_valid) {
    TFLITE_LOG(tflite::TFLITE_LOG_INFO,
               "FlexDelegate: All tensor shapes are consistent.");
  } else {
    TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
               "FlexDelegate: Some tensor shapes are inconsistent.");
  }

  // All output tensors are allocated by TensorFlow/Eager, so we mark them as
  // kTfLiteDynamic, unless their shapes were validated as consistent above.
  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!shapes_are_valid) {
      SetTensorToDynamic(&context->tensors[tensor_index]);
    }
    ++tensor_ref_count[tensor_index];
  }

  for (const auto& node_data : op_data_->nodes) {
    if (node_data->nodedef().op().empty()) {
      context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
                           node_data->name().c_str());
      return kTfLiteError;
    }
    TF_LITE_ENSURE(context, node_data->op());

    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    }
  }

  // All tensors that are referenced exactly once are marked as "forwardable",
  // meaning that we will allow TensorFlow to reuse its buffer as the output of
  // an op.
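  // (An input consumed by two Eager nodes, or one that is also a subgraph
  // input or output, ends up with a count of at least 2 and is therefore
  // never forwarded.)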
  for (auto& node_data : op_data_->nodes) {
    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
      node_data->mutable_inputs()->SetForwardable(i, f);
    }
  }

  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
    TfLiteContext* context) const {
  for (const auto& node_data : op_data_->nodes) {
    auto op_name = node_data->name().c_str();
    // Create an InferenceContext object.
    auto num_inputs = node_data->inputs().Size();
    std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
                                                                nullptr);
    InferenceContext c(
        TF_GRAPH_DEF_VERSION, node_data->nodedef(),
        node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
        input_tensors_vector, {},
        std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());
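    // The ShapeHandle vector passed above holds 'num_inputs' unknown shapes;
    // the actual input shapes are filled in via SetInput() below.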

    // Set input_shapes for ShapeInferenceFn.
    for (int i = 0; i < num_inputs; ++i) {
      const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
      // Provide constant input tensors since some ops (e.g. "RFFT") need them
      // to calculate the output shape.
      if (IsConstantTensor(tfl_tensor)) {
        input_tensors_vector[i] =
            op_data_->buffer_map->GetTensorPtr(input_tensor_index);
      }
      const auto dims_array = tfl_tensor->dims;
      std::vector<DimensionHandle> dims(dims_array->size);
      for (int j = 0; j < dims_array->size; ++j) {
        dims[j] = c.MakeDim(dims_array->data[j]);
      }
      c.SetInput(i, c.MakeShape(dims));
    }
    c.set_input_tensors(input_tensors_vector);

    tensorflow::Status status = c.construction_status();
    if (!status.ok()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Shape construction failed for op '%s'", op_name);
      return kTfLiteError;
    }

    // Run ShapeInferenceFn to calculate output shapes.
    if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "No shape inference function exists for op '%s'", op_name);
      return kTfLiteError;
    }
    status = c.Run(node_data->op_reg_data()->shape_inference_fn);

    // Compare the calculated output shapes with the shapes recorded in
    // node_data->outputs.
    auto num_outputs = node_data->outputs().Size();
    if (num_outputs != c.num_outputs()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Number of output tensors are mismatched for op '%s' %d != %d",
                 op_name, num_outputs, c.num_outputs());
      return kTfLiteError;
    }
    for (int i = 0; i < num_outputs; ++i) {
      const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
      // tfl_tensor->dims only has valid information if the given model is
      // converted by the MLIR converter. Also, when ResizeInputTensor() is
      // called, the dims information becomes invalid.
      const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims);
      const std::string calculated_shape_string = c.DebugString(c.output(i));
      // Getting a shape string via c.DebugString() is the easiest way to get
      // the shape information of the given ShapeHandle for now.
      // TODO(b/169017408): Find a better approach without using debug string.
      if (tfl_shape_string != calculated_shape_string) {
        TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                   "op '%s' output%d tensor#%d shape mismatch for %s != %s",
                   op_name, i, output_tensor_index, tfl_shape_string.c_str(),
                   calculated_shape_string.c_str());
        return kTfLiteError;
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
  BufferMap* buffer_map = op_data_->buffer_map;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!tensor->data_is_stale || !buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }
  }

  // Execute the TensorFlow Ops sequentially.
  for (auto& node_data : op_data_->nodes) {
    TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
        reinterpret_cast<Profiler*>(context->profiler),
        node_data->name().c_str(), node_data->index());

    if (op_data_->cancellation_manager != nullptr &&
        op_data_->cancellation_manager->IsCancelled()) {
      TF_LITE_KERNEL_LOG(context, "Client requested cancel during Invoke()");
      return kTfLiteError;
    }

    auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
    TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
  }

  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!buffer_map->HasTensor(tensor_index)) {
      context->ReportError(context, "Cannot write to invalid tensor index %d",
                           tensor_index);
      return kTfLiteError;
    }

    // Copy TF tensor data to the TFL-allocated buffer for non-dynamic tensors.
    // For dynamic tensors, copy the shape and put a buffer_handle in place for
    // the later CopyFromBufferHandle() call.
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
    if (tensor->allocation_type == kTfLiteDynamic) {
      TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
      tensor->buffer_handle = tensor_index;
      tensor->data_is_stale = true;
      continue;
    }
    // If the tensor isn't dynamic, we can copy data directly to the buffer of
    // the tensor. Before copying the data, check that the target buffer has
    // the expected size.
    if (tf_tensor.NumElements() != NumElements(tensor) ||
        tf_tensor.TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(
          context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %ld(%ld)",
          tensor->name, tensor_index, tf_tensor.TotalBytes(),
          tf_tensor.NumElements(), tensor->bytes, NumElements(tensor));
      return kTfLiteError;
    }
    tensorflow::StringPiece t_data = tf_tensor.tensor_data();
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
  }

  return kTfLiteOk;
}

}  // namespace flex
}  // namespace tflite