/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/flex/kernel.h"

#include <cstring>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/span.h"
#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/execute.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/delegate_data.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/string_type.h"

// Note: this file is part of TF Lite's Flex delegate, which lets a TF Lite
// graph fall back to TensorFlow to execute ops unsupported by TF Lite.

// This is the TF Lite op that is created by the flex delegate to handle
// execution of a supported subgraph. The usual flow is that the delegate
// informs the interpreter of supported nodes in a graph, and each supported
// subgraph is replaced with one instance of this kernel.
//
// The kernel is initialized with TfLiteDelegateParams from which we retrieve
// the global EagerContext and BufferMap, as well as a list of inputs and
// outputs to the subgraph. Those are used to build the OpData, with a list of
// TensorFlow Ops that should be executed in order (each of which we call an
// OpNode).
//
// For each node included in the subgraph, we query the interpreter and
// retrieve the associated NodeDef, which is then used to configure the
// corresponding TensorFlow/Eager Op.
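//
// As an illustration of how the interpreter reaches this class: a delegate
// kernel is registered through a TfLiteRegistration whose 'init' callback
// receives the TfLiteDelegateParams passed as 'buffer'. A minimal sketch,
// assuming the wiring that the flex delegate itself performs:
//
//   TfLiteRegistration reg{};
//   reg.init = [](TfLiteContext* context, const char* buffer,
//                 size_t) -> void* {
//     auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
//     auto* kernel = new tflite::flex::DelegateKernel;
//     if (kernel->Init(context, params) != kTfLiteOk) {
//       delete kernel;
//       return nullptr;
//     }
//     return kernel;
//   };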

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeAndType;
using tensorflow::shape_inference::ShapeHandle;

const std::string GetDimsDebugString(const TfLiteIntArray* dims) {
  return absl::StrCat("[", absl::StrJoin(tflite::TfLiteIntArrayView(dims), ","),
                      "]");
}

namespace tflite {
namespace flex {

struct OpNode;

// Represents the origin of a given tensor as a reference to the output
// of an upstream node.
struct TensorSource {
  OpNode* node;
  int node_output_index;
};

// A list of inputs of a given node of the TensorFlow/Eager graph.
class OpInputs {
 public:
  explicit OpInputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      inputs_.push_back(index);
    }
    forwardable_.resize(inputs_.size());
  }
  ~OpInputs() {}

  int Size() const { return inputs_.size(); }

  int TfLiteIndex(int i) const { return inputs_[i]; }

  // Given a map relating tensors to the node that originates them, populate a
  // list of sources for the tensors in this class.
  void InitializeTensorSources(
      const std::map<int, TensorSource>& tflite_tensor_sources) {
    sources_.clear();
    for (int i : inputs_) {
      auto it = tflite_tensor_sources.find(i);
      if (it == tflite_tensor_sources.end()) {
        sources_.push_back({nullptr, 0});
      } else {
        sources_.push_back(it->second);
      }
    }
  }

  void SetForwardable(int i, bool v) { forwardable_[i] = v; }

  bool IsForwardable(int i) const { return forwardable_[i]; }

  TensorSource GetTensorSource(int i) const { return sources_[i]; }

 private:
  std::vector<int> inputs_;
  std::vector<TensorSource> sources_;

  // List of tensors that can be used by TF in its forwarding optimization.
  // Doing so allows an input tensor to be modified and used as the output
  // tensor. The delegate takes care of not holding any references to tensors
  // in this list while Eager is executing the corresponding op.
  std::vector<int> forwardable_;
};
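
// Example of the forwarding optimization coordinated by OpInputs/OpOutputs:
// if node A produces TF Lite tensor #5 and node B is its only consumer (and
// #5 is neither a subgraph input nor a subgraph output), input #5 of B is
// marked forwardable. When B builds its eager inputs, it takes A's
// TensorHandle out of A's output vector instead of Ref-ing it, so TensorFlow
// may reuse the tensor's buffer for B's own output.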

// A list of outputs of a given node of the TensorFlow/Eager graph, along with
// the actual outputs of the EagerOperation.
class OpOutputs {
 public:
  explicit OpOutputs(const TfLiteIntArray* indexes) {
    for (int index : TfLiteIntArrayView(indexes)) {
      outputs_.push_back(index);
    }
    vector_.resize(outputs_.size());
  }
  ~OpOutputs() { ResetTensorHandles(); }

  // Stores information about which of the tensors in this class are also
  // outputs of the subgraph.
  void InitializeGraphOutputs(const std::set<int>& subgraph_outputs) {
    subgraph_outputs_.clear();
    for (int i : outputs_) {
      subgraph_outputs_.push_back(subgraph_outputs.count(i) > 0);
    }
  }

  // Returns true if the tensor given by index 'i' is an output of the entire
  // subgraph.
  bool IsSubgraphOutput(int i) const { return subgraph_outputs_[i]; }

  // Returns a handle to a given tensor and, optionally, removes it from the
  // internal vector.
  tensorflow::TensorHandle* GetHandle(int i, bool remove) {
    auto* handle = vector_[i];
    if (!remove) {
      handle->Ref();
    } else {
      // Don't increase the ref-count. Instead, simply take it out of the
      // vector.
      vector_[i] = nullptr;
    }
    return handle;
  }

  int Size() const { return outputs_.size(); }

  int TfLiteIndex(int i) const { return outputs_[i]; }

  // Carefully unreference all the handles in the eager output vector.
  void ResetTensorHandles() {
    for (int i = 0; i < vector_.size(); ++i) {
      if (vector_[i]) {
        vector_[i]->Unref();
        vector_[i] = nullptr;
      }
    }
  }

  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2>*
  GetTensorHandles() {
    return &vector_;
  }

 private:
  std::vector<int> outputs_;
  std::vector<bool> subgraph_outputs_;
  tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> vector_;
};

// A single node within the larger 'op'. Note that this kernel executes many
// TensorFlow ops within a single TF Lite op.
class OpNode {
 public:
  OpNode(const TfLiteIntArray* inputs, const TfLiteIntArray* outputs)
      : inputs_(inputs), outputs_(outputs) {}
  ~OpNode() {
    if (op_) ClearEagerInputs();
  }

  const string& name() const { return name_; }
  void set_name(const string& name) { name_ = name; }

  int index() const { return index_; }
  void set_index(int index) { index_ = index; }

  const tensorflow::NodeDef& nodedef() const { return nodedef_; }
  const tensorflow::OpRegistrationData* op_reg_data() const {
    return op_reg_data_;
  }

  const OpInputs& inputs() const { return inputs_; }
  OpInputs* mutable_inputs() { return &inputs_; }

  const OpOutputs& outputs() const { return outputs_; }
  OpOutputs* mutable_outputs() { return &outputs_; }

  int NumInputs() const { return inputs_.Size(); }
  int NumOutputs() const { return outputs_.Size(); }

  tensorflow::EagerOperation* op() { return op_.get(); }

  tensorflow::Status InitializeNodeDef(const void* custom_initial_data,
                                       int custom_initial_data_size) {
    if (!custom_initial_data) {
      return tensorflow::errors::Internal(
          "Cannot convert empty data into a valid NodeDef");
    }
    // The flexbuffer contains a vector where the first element is the
    // op name and the second is a serialized NodeDef.
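    // For reference, a writer-side sketch of that layout (hypothetical; in
    // practice the model converter produces this buffer when emitting a flex
    // op):
    //
    //   flexbuffers::Builder fbb;
    //   fbb.Vector([&]() {
    //     fbb.String(node_def.op());                 // element 0: op name
    //     fbb.String(node_def.SerializeAsString());  // element 1: NodeDef
    //   });
    //   fbb.Finish();  // fbb.GetBuffer() is what arrives here as
    //                  // 'custom_initial_data'.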
    const flexbuffers::Vector& v =
        flexbuffers::GetRoot(
            reinterpret_cast<const uint8_t*>(custom_initial_data),
            custom_initial_data_size)
            .AsVector();

    name_ = v[0].AsString().str();
    if (!nodedef_.ParseFromString(v[1].AsString().str())) {
      nodedef_.Clear();
      return tensorflow::errors::Internal(
          "Failed to parse data into a valid NodeDef");
    }

    // Fill NodeDef with defaults if it's a valid op.
    TF_RETURN_IF_ERROR(
        tensorflow::OpRegistry::Global()->LookUp(nodedef_.op(), &op_reg_data_));
    AddDefaultsToNodeDef(op_reg_data_->op_def, &nodedef_);

    return tensorflow::Status::OK();
  }

  // Builds the new EagerOperation. In case of error, the returned 'op' is
  // guaranteed to be 'nullptr'.
  tensorflow::Status BuildEagerOp(tensorflow::EagerContext* eager_context) {
    op_.reset(new tensorflow::EagerOperation(eager_context));
    TF_RETURN_IF_ERROR(op_->Reset(name_.c_str(), nullptr, false, nullptr));
    if (op_->is_function()) {
      op_.reset();
      return tensorflow::errors::NotFound(
          "Operation '", name_,
          "' is not registered. (while processing attributes of '", name_,
          "')");
    }

    op_->MutableAttrs()->NumInputs(inputs_.Size());
    for (const auto& attr : nodedef_.attr()) {
      op_->MutableAttrs()->Set(attr.first, attr.second);
    }

    // Precalculating a cache key saves about 10% of inference time for very
    // small models.
    op_->MutableAttrs()->CacheKey(op_->DeviceName());

    return tensorflow::Status::OK();
  }

  void ClearEagerInputs() { op_->Clear(); }

  tensorflow::Status BuildEagerInputs(const BufferMap* buffer_map) {
    absl::InlinedVector<tensorflow::TensorHandle*, 4>* op_inputs;
    TF_RETURN_IF_ERROR(op_->MutableTensorHandleInputs(&op_inputs));
    for (int i = 0; i < inputs_.Size(); ++i) {
      int input_index = inputs_.TfLiteIndex(i);
      TensorSource s = inputs_.GetTensorSource(i);
      if (!s.node) {
        // This input is not produced by this Eager subgraph (it could be a TF
        // Lite native buffer, or could be produced by a separate subgraph). We
        // need to fetch it from the delegate's buffer_map.
        if (!buffer_map->HasTensor(input_index)) {
          return tensorflow::errors::Internal(
              "Cannot read from invalid tensor index ", input_index);
        }
        tensorflow::TensorHandle* handle =
            tensorflow::TensorHandle::CreateLocalHandle(
                buffer_map->GetTensor(input_index));
        op_inputs->push_back(handle);
      } else {
        // If this is a forwardable tensor, we will remove it from the previous
        // op's list, giving TF the opportunity to reuse its buffer.
        bool unref_handle = inputs_.IsForwardable(i);
        auto* handle =
            s.node->outputs_.GetHandle(s.node_output_index, unref_handle);
        op_inputs->push_back(handle);
      }
    }
    return tensorflow::Status::OK();
  }

  tensorflow::Status PersistEagerOutputs(BufferMap* buffer_map) {
    auto* handles = outputs_.GetTensorHandles();
    for (int i = 0; i < outputs_.Size(); ++i) {
      if (outputs_.IsSubgraphOutput(i)) {
        const tensorflow::Tensor* tensor = nullptr;
        TF_RETURN_IF_ERROR(handles->at(i)->Tensor(&tensor));
        buffer_map->SetFromTensorFlow(outputs_.TfLiteIndex(i), *tensor);
      }
    }
    return tensorflow::Status::OK();
  }

 private:
  OpNode(const OpNode&) = delete;
  OpNode& operator=(const OpNode&) = delete;

  // The name of the TensorFlow op to execute.
  string name_;
  // Index of this node into TF Lite's operator list.
  int index_;
  // The corresponding NodeDef, containing the attributes for the op.
  tensorflow::NodeDef nodedef_;
  // The corresponding OpRegistrationData pointer.
  const tensorflow::OpRegistrationData* op_reg_data_;
  // List of inputs, as TF Lite tensor indices.
  OpInputs inputs_;
  // List of outputs, as TF Lite tensor indices.
  OpOutputs outputs_;

  std::unique_ptr<tensorflow::EagerOperation> op_;
};

// Executes the TensorFlow op described by 'node_data', reading its inputs
// from the 'buffer_map' (or from upstream nodes) and persisting any subgraph
// outputs back into the 'buffer_map'.
tensorflow::Status ExecuteFlexOp(TfLiteContext* context, BufferMap* buffer_map,
                                 OpNode* node_data) {
  TF_RETURN_WITH_CONTEXT_IF_ERROR(node_data->BuildEagerInputs(buffer_map),
                                  " (while executing '", node_data->name(),
                                  "' via Eager)");

  node_data->mutable_outputs()->ResetTensorHandles();
  int num_retvals = node_data->NumOutputs();
  TF_RETURN_WITH_CONTEXT_IF_ERROR(
      node_data->op()->Execute(
          absl::MakeSpan(
              reinterpret_cast<tensorflow::AbstractTensorHandle**>(
                  node_data->mutable_outputs()->GetTensorHandles()->data()),
              num_retvals),
          &num_retvals),
      " (while executing '", node_data->name(), "' via Eager)");

  if (num_retvals != node_data->NumOutputs()) {
    return tensorflow::errors::Internal(
        "Unexpected number of outputs from EagerExecute");
  }

  TF_RETURN_IF_ERROR(node_data->PersistEagerOutputs(buffer_map));

  node_data->ClearEagerInputs();

  return tensorflow::Status::OK();
}

// The larger 'op', which contains all the nodes in a supported subgraph.
struct OpData {
  tensorflow::EagerContext* eager_context;
  BufferMap* buffer_map;
  std::vector<std::unique_ptr<OpNode>> nodes;
  std::vector<int> subgraph_inputs;
  std::vector<int> subgraph_outputs;
};

DelegateKernel::DelegateKernel() : op_data_(new OpData) {}
DelegateKernel::~DelegateKernel() {}

TfLiteStatus DelegateKernel::Init(TfLiteContext* context,
                                  const TfLiteDelegateParams* params) {
  auto* flex_delegate_data =
      reinterpret_cast<FlexDelegate*>(params->delegate->data_)->mutable_data();
  op_data_->eager_context = flex_delegate_data->GetEagerContext();
  op_data_->buffer_map = flex_delegate_data->GetBufferMap(context);

  CHECK(params->output_tensors);
  std::set<int> output_set;
  for (auto tensor_index : TfLiteIntArrayView(params->output_tensors)) {
    op_data_->subgraph_outputs.push_back(tensor_index);
    output_set.insert(tensor_index);
  }

  CHECK(params->input_tensors);
  for (auto tensor_index : TfLiteIntArrayView(params->input_tensors)) {
    op_data_->subgraph_inputs.push_back(tensor_index);
  }

  CHECK(params->nodes_to_replace);
  op_data_->nodes.reserve(params->nodes_to_replace->size);

  tensorflow::Status status;
  for (auto node_index : TfLiteIntArrayView(params->nodes_to_replace)) {
    TfLiteNode* node;
    TfLiteRegistration* reg;
    context->GetNodeAndRegistration(context, node_index, &node, &reg);

    op_data_->nodes.emplace_back(new OpNode(node->inputs, node->outputs));
    OpNode& node_data = *op_data_->nodes.back();

    node_data.set_index(node_index);
    node_data.set_name("");

    status = node_data.InitializeNodeDef(node->custom_initial_data,
                                         node->custom_initial_data_size);
    if (!status.ok()) break;
    status = node_data.BuildEagerOp(op_data_->eager_context);
    if (!status.ok()) break;
  }

  TF_LITE_ENSURE_STATUS(ConvertStatus(context, status));

  // Given a TfLite tensor index, return the OpNode that produces it,
  // along with its index into that OpNode's list of outputs.
  std::map<int, TensorSource> tflite_tensor_sources;

  // Find out how each tensor is produced. This does not account for
  // tensors that are not produced by eager ops.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_outputs()->InitializeGraphOutputs(output_set);
    for (int i = 0; i < node_data->outputs().Size(); ++i) {
      int output_index = node_data->outputs().TfLiteIndex(i);
      tflite_tensor_sources[output_index] = TensorSource{node_data.get(), i};
    }
  }

  // For each node, resolve the inputs, so we can keep pointers to the nodes
  // that produce them.
  for (auto& node_data : op_data_->nodes) {
    node_data->mutable_inputs()->InitializeTensorSources(tflite_tensor_sources);
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_MSG(
      context, op_data_->eager_context != nullptr,
      "Failed to initialize eager context. This often happens when a CPU "
      "device has not been registered, presumably because some symbols from "
      "tensorflow/core:core_cpu_impl were not linked into the binary.");

  // We will keep track of the number of references to each tensor in the
  // graph, so we can make them "forwardable" if there is only one reference.
  std::map<int, int> tensor_ref_count;

  // Whenever we find a constant tensor, insert it in the buffer map.
  BufferMap* buffer_map = op_data_->buffer_map;
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (IsConstantTensor(tensor)) {
      if (!buffer_map->HasTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }

    // Input tensors should never be forwarded, so we increment their ref
    // counts twice: once for this graph and another for the possibility of
    // them being used by another subgraph, or being an output of the full
    // graph.
    tensor_ref_count[tensor_index] += 2;
  }

  const bool shapes_are_valid =
      (ValidateOutputTensorShapeConsistency(context) == kTfLiteOk);
  if (shapes_are_valid) {
    TFLITE_LOG(tflite::TFLITE_LOG_INFO,
               "FlexDelegate: All tensor shapes are consistent.");
  } else {
    TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
               "FlexDelegate: Some tensor shapes are inconsistent.");
  }

  // Output tensors are allocated by TensorFlow/Eager, so when their shapes
  // could not be validated ahead of time we mark them as kTfLiteDynamic.
  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!shapes_are_valid) {
      SetTensorToDynamic(&context->tensors[tensor_index]);
    }
    ++tensor_ref_count[tensor_index];
  }

  for (const auto& node_data : op_data_->nodes) {
    if (node_data->nodedef().op().empty()) {
      context->ReportError(context, "Invalid NodeDef in Flex op '%s'",
                           node_data->name().c_str());
      return kTfLiteError;
    }
    TF_LITE_ENSURE(context, node_data->op());

    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      ++tensor_ref_count[node_data->inputs().TfLiteIndex(i)];
    }
  }

  // All tensors that are referenced exactly once are marked as "forwardable",
  // meaning that we will allow TensorFlow to reuse its buffer as the output of
  // an op.
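  // For example, in a chain A -> B where A's output tensor feeds only B, that
  // tensor's ref count stays at 1 and B marks the input forwardable, whereas
  // subgraph inputs (counted twice above) and tensors with multiple consumers
  // are never forwarded.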
  for (auto& node_data : op_data_->nodes) {
    for (int i = 0; i < node_data->inputs().Size(); ++i) {
      bool f = (tensor_ref_count[node_data->inputs().TfLiteIndex(i)] == 1);
      node_data->mutable_inputs()->SetForwardable(i, f);
    }
  }

  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::ValidateOutputTensorShapeConsistency(
    TfLiteContext* context) const {
  for (const auto& node_data : op_data_->nodes) {
    auto op_name = node_data->name().c_str();
    // Create an InferenceContext object.
    auto num_inputs = node_data->inputs().Size();
    std::vector<const tensorflow::Tensor*> input_tensors_vector(num_inputs,
                                                                nullptr);
    InferenceContext c(
        TF_GRAPH_DEF_VERSION, node_data->nodedef(),
        node_data->op_reg_data()->op_def, std::vector<ShapeHandle>(num_inputs),
        input_tensors_vector, {},
        std::vector<std::unique_ptr<std::vector<ShapeAndType>>>());

    // Set input_shapes for ShapeInferenceFn.
    for (int i = 0; i < num_inputs; ++i) {
      const auto input_tensor_index = node_data->inputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[input_tensor_index];
      // Provide constant input tensors since some ops (e.g. "RFFT") need them
      // to calculate the output shape.
      if (IsConstantTensor(tfl_tensor)) {
        input_tensors_vector[i] =
            op_data_->buffer_map->GetTensorPtr(input_tensor_index);
      }
      const auto dims_array = tfl_tensor->dims;
      std::vector<DimensionHandle> dims(dims_array->size);
      for (int j = 0; j < dims_array->size; ++j) {
        dims[j] = c.MakeDim(dims_array->data[j]);
      }
      c.SetInput(i, c.MakeShape(dims));
    }
    c.set_input_tensors(input_tensors_vector);

    tensorflow::Status status = c.construction_status();
    if (!status.ok()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Shape construction failed for op '%s'", op_name);
      return kTfLiteError;
    }

    // Run ShapeInferenceFn to calculate output shapes.
    if (node_data->op_reg_data()->shape_inference_fn == nullptr) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "No shape inference function exists for op '%s'", op_name);
      return kTfLiteError;
    }
    status = c.Run(node_data->op_reg_data()->shape_inference_fn);

    // Compare the calculated output shapes with node_data->outputs.
    auto num_outputs = node_data->outputs().Size();
    if (num_outputs != c.num_outputs()) {
      TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                 "Mismatched number of output tensors for op '%s': %d != %d",
                 op_name, num_outputs, c.num_outputs());
      return kTfLiteError;
    }
    for (int i = 0; i < num_outputs; ++i) {
      const auto output_tensor_index = node_data->outputs().TfLiteIndex(i);
      TfLiteTensor* tfl_tensor = &context->tensors[output_tensor_index];
      // tfl_tensor->dims only has valid information if the given model is
      // converted by the MLIR converter. Also, when ResizeInputTensor() is
      // called, the dims information becomes invalid.
      const std::string tfl_shape_string = GetDimsDebugString(tfl_tensor->dims);
      const std::string calculated_shape_string = c.DebugString(c.output(i));
      // Getting a shape string via c.DebugString() is the easiest way to get
      // the shape information of the given ShapeHandle for now.
      // TODO(b/169017408): Find a better approach without using debug string.
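      // For example, a tensor with dims {1, 2, 2} stringifies to "[1,2,2]" on
      // both sides when they agree, while an inferred shape with an unknown
      // dimension prints as "[1,?,2]" and fails the comparison below.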
      if (tfl_shape_string != calculated_shape_string) {
        TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
                   "op '%s' output%d tensor#%d shape mismatch for %s != %s",
                   op_name, i, output_tensor_index, tfl_shape_string.c_str(),
                   calculated_shape_string.c_str());
        return kTfLiteError;
      }
    }
  }
  return kTfLiteOk;
}

TfLiteStatus DelegateKernel::Eval(TfLiteContext* context, TfLiteNode* node) {
  BufferMap* buffer_map = op_data_->buffer_map;

  // Insert a tensor in the buffer map for all inputs that are not constant.
  // Constants were handled in Prepare() already.
  for (auto tensor_index : op_data_->subgraph_inputs) {
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    if (!IsConstantTensor(tensor)) {
      // If this tensor is part of an earlier TF subgraph we should not add it
      // to the BufferMap again, because TF already knows about it and its
      // contents are kept automatically up-to-date.
      if (!buffer_map->IsTensorFlowTensor(tensor_index)) {
        buffer_map->SetFromTfLite(tensor_index, tensor);
      }
    }
  }

  // Execute the TensorFlow Ops sequentially.
  for (auto& node_data : op_data_->nodes) {
    TFLITE_SCOPED_DELEGATE_OPERATOR_PROFILE(
        reinterpret_cast<Profiler*>(context->profiler),
        node_data->name().c_str(), node_data->index());

    auto status = ExecuteFlexOp(context, buffer_map, node_data.get());
    TF_LITE_ENSURE_OK(context, ConvertStatus(context, status));
  }

  for (auto tensor_index : op_data_->subgraph_outputs) {
    if (!buffer_map->HasTensor(tensor_index)) {
      context->ReportError(context, "Cannot write to invalid tensor index %d",
                           tensor_index);
      return kTfLiteError;
    }

    // Copy TF tensor data to the TFL-allocated buffer for non-dynamic tensors.
    // For dynamic tensors, copy the shape and set buffer_handle for the later
    // CopyFromBufferHandle() call.
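    // (CopyFromBufferHandle() is implemented by the flex delegate; the
    // interpreter invokes it lazily once the client actually reads this
    // stale output.)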
    TfLiteTensor* tensor = &context->tensors[tensor_index];
    const tensorflow::Tensor& tf_tensor = buffer_map->GetTensor(tensor_index);
    if (tensor->allocation_type == kTfLiteDynamic) {
      TF_LITE_ENSURE_OK(context, CopyShapeAndType(context, tf_tensor, tensor));
      tensor->buffer_handle = tensor_index;
      tensor->data_is_stale = true;
      continue;
    }
    // If the tensor isn't dynamic, we can copy data directly to the buffer of
    // the tensor. Before copying the data, check that the target buffer has
    // the expected size.
    if (tf_tensor.NumElements() != NumElements(tensor) ||
        tf_tensor.TotalBytes() != tensor->bytes) {
      TF_LITE_KERNEL_LOG(
          context, "Tensor: %s(%d) buffer size mismatch %zu(%lld) != %zu(%lld)",
          tensor->name, tensor_index, tf_tensor.TotalBytes(),
          tf_tensor.NumElements(), tensor->bytes, NumElements(tensor));
      return kTfLiteError;
    }
    tensorflow::StringPiece t_data = tf_tensor.tensor_data();
    memcpy(tensor->data.raw, t_data.data(), t_data.size());
  }

  return kTfLiteOk;
}

}  // namespace flex
}  // namespace tflite