/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_internal.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace gpu {
namespace {

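// Reads a fully-connected weights tensor (2-D, laid out as
// [num_units, input_size]) and repacks it into the OHWI layout used by the
// GPU backend, with o = num_units and i = input_size. Illustrative example
// (hypothetical shapes): a [128, 64] weights tensor becomes OHWI
// (o=128, h=1, w=1, i=64).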
absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
                                         int bias_tensor_id,
                                         ObjectReader* reader,
                                         FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.i = weights.shape.w;
  reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();  // optional
  return absl::OkStatus();
}

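// Casts TfLiteNode::builtin_data to the op-specific params struct. Call sites
// rely on template argument deduction, e.g.:
//   const TfLiteAddParams* tf_options;
//   RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
// RetrieveCustomInitialData below does the same for custom_initial_data.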
template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve custom_initial_data.");
  }
  return absl::OkStatus();
}

// Creates a simple node that holds a tensor value.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
  ConstTensorAttributes attr;
  attr.tensor = std::move(t);
  Node* node = graph->NewNode();
  node->operation.attributes = attr;
  node->operation.type = ToString(OperationType::CONSTANT);
  *value = graph->NewValue();
  RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
  // Keep data inside this tensor.
  (*value)->tensor.ref = attr.tensor.id;
  (*value)->tensor.type = attr.tensor.kType;
  (*value)->tensor.shape = attr.tensor.shape;
  return absl::OkStatus();
}

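// Adds the runtime input(s) of a binary elementwise op to `node`; when exactly
// one input is constant, the constant is stored in `tensor_or_scalar` instead:
// a single-element constant becomes a scalar, a channel-shaped constant a
// Linear tensor, and anything else an HWC tensor. Illustrative example
// (hypothetical shapes): ADD(runtime BHWC(1, 8, 8, 16), const [16]) stores the
// constant as a Linear tensor of 16 values.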
absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
                                        TensorOrScalar* tensor_or_scalar) {
  const std::string& opname = node->operation.type;

  // Determine runtime/constant tensors.
  const TfLiteTensor* input0 = reader->GetInputTensor(0);
  if (!input0) {
    return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
                                      opname);
  }
  const TfLiteTensor* input1 = reader->GetInputTensor(1);
  if (!input1) {
    return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
                                      opname);
  }
  const bool constant_tensor0 = IsConstantTensor(input0);
  const bool constant_tensor1 = IsConstantTensor(input1);
  if (constant_tensor0 && constant_tensor1) {
    return absl::InvalidArgumentError("No runtime input tensors for " + opname);
  }
  const bool runtime_tensor0 = !constant_tensor0;
  const bool runtime_tensor1 = !constant_tensor1;

  if (runtime_tensor0 && runtime_tensor1) {
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
  } else {
    int runtime_tensor = 0;
    int constant_tensor = 1;
    TfLiteIntArray* constant_dims = input1->dims;
    if (constant_tensor0 && runtime_tensor1) {
      runtime_tensor = 1;
      constant_tensor = 0;
      constant_dims = input0->dims;
    }
    RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
    if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
      Tensor<Scalar, DataType::FLOAT32> tensor;
      RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
      *tensor_or_scalar = tensor.data[0];
    } else {
      if (CheckIfLinearConvertible(constant_dims).ok()) {
        Tensor<Linear, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      } else {
        Tensor<HWC, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      }
    }
  }
  return absl::OkStatus();
}

struct TensorInfo {
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> producers;
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> consumers;
};

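// Walks the interpreter's execution plan and records every node that produces
// or consumes the tensor with the given id.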
absl::Status GetTensorInfo(const TfLiteContext* context, int tensor_id,
                           TensorInfo* result) {
  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(const_cast<TfLiteContext*>(context),
                                &execution_plan) != kTfLiteOk) {
    return absl::UnavailableError("Unable to get graph execution plan.");
  }
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(const_cast<TfLiteContext*>(context),
                                        node_index, &node,
                                        &registration) != kTfLiteOk) {
      return absl::UnavailableError(
          "Unable to get node and registration for node.");
    }
    for (int j = 0; j < node->inputs->size; ++j) {
      if (tensor_id == node->inputs->data[j]) {
        result->consumers.push_back({node, registration});
      }
    }
    for (int j = 0; j < node->outputs->size; ++j) {
      if (tensor_id == node->outputs->data[j]) {
        result->producers.push_back({node, registration});
      }
    }
  }
  return absl::OkStatus();
}

bool IsLogicalCode(int32_t builtin_code) {
  return builtin_code == kTfLiteBuiltinGreater ||
         builtin_code == kTfLiteBuiltinGreaterEqual ||
         builtin_code == kTfLiteBuiltinLess ||
         builtin_code == kTfLiteBuiltinLessEqual ||
         builtin_code == kTfLiteBuiltinEqual ||
         builtin_code == kTfLiteBuiltinNotEqual;
}

bool IsLogicalOp(tflite::gpu::OperationType op_type) {
  return op_type == tflite::gpu::OperationType::GREATER ||
         op_type == tflite::gpu::OperationType::GREATER_EQUAL ||
         op_type == tflite::gpu::OperationType::LESS ||
         op_type == tflite::gpu::OperationType::LESS_EQUAL ||
         op_type == tflite::gpu::OperationType::EQUAL ||
         op_type == tflite::gpu::OperationType::NOT_EQUAL;
}

class AddOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    // TODO(eignasheva): Add shapes check.
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // TFLite currently only supports 2 input ADDs. Thus, the logic below only
    // considers 2 input cases. The underlying GPU shader programs can accept
    // more inputs, but the logic below would have to be expanded.

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    ElementwiseAttributes attr;
    RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
    node->operation.attributes = std::move(attr);
    const TfLiteAddParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    return MaybeFuseActivation(tf_options->activation, graph, node);
  }
};

class BatchedMatMulOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::BATCHED_MATMUL);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    return absl::OkStatus();
  }
};

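// CAST is delegated only in the narrow case where it consumes the output of a
// logical op (see IsLogicalCode above); there the cast is a no-op for the GPU
// backend and is parsed as an identity RESHAPE.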
class CastOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    TensorInfo input_tensor_info;
    RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->inputs->data[0],
                                  &input_tensor_info));
    if (input_tensor_info.producers.size() != 1 ||
        input_tensor_info.consumers.size() != 1) {
      return absl::UnavailableError("Not supported cast case");
    }
    if (IsLogicalCode(input_tensor_info.producers[0].second->builtin_code)) {
      return CheckGpuDelegateCompatibility(context, tflite_node, registration);
    }
    return absl::UnimplementedError("Not supported Cast case.");
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    // Adding Identity reshape that will be removed.
    node->operation.type = ToString(OperationType::RESHAPE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    // New shape comes from output shape.
    ReshapeAttributes attr;
    attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class ClampOperationsParser : public TFLiteOperationParser {
 public:
  explicit ClampOperationsParser(float clamp_a, float clamp_b)
      : clamp_a_(clamp_a), clamp_b_(clamp_b) {}
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // clamp(v, a, b) = clamp(v - a, 0.0, b - a) + a
    // We replace clamp(...) with a sequence of elementwise ops:
    //   subtraction -> regular ReLU with alpha = 0.0 -> addition.
    // node_sub  = v0 = v - a                 // add op (adds -a)
    // node_relu = v1 = clamp(v0, 0.0, clip)  // relu op, alpha = 0.0,
    //                                        // clip = b - a
    // node_add  = v2 = v1 + a                // add op (adds a)
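    // Illustrative example (hypothetical bounds): for clamp(v, -1.0f, 6.0f),
    // node_sub computes v + 1, node_relu clips the result to [0, 7], and
    // node_add subtracts 1 again, which equals clamp(v, -1, 6).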
    Node* node_sub = graph->NewNode();
    Node* node_relu = graph->NewNode();
    Node* node_add = graph->NewNode();

    ElementwiseAttributes sub_attr;
    sub_attr.param = -clamp_a_;
    node_sub->operation.type = ToString(OperationType::ADD);
    node_sub->operation.attributes = std::move(sub_attr);

    ReLUAttributes relu_attr;
    relu_attr.alpha = 0.0f;
    relu_attr.clip = clamp_b_ - clamp_a_;
    node_relu->operation.type = ToString(OperationType::RELU);
    node_relu->operation.attributes = relu_attr;

    ElementwiseAttributes add_attr;
    add_attr.param = clamp_a_;
    node_add->operation.type = ToString(OperationType::ADD);
    node_add->operation.attributes = std::move(add_attr);

    RETURN_IF_ERROR(reader->AddInput(node_sub, 0));
    auto input = graph->FindInputs(node_sub->id)[0];

    Value* v0 = graph->NewValue();
    Value* v1 = graph->NewValue();
    v0->tensor.type = input->tensor.type;
    v0->tensor.shape = input->tensor.shape;
    v1->tensor.type = input->tensor.type;
    v1->tensor.shape = input->tensor.shape;

    RETURN_IF_ERROR(graph->SetProducer(node_sub->id, v0->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_relu->id, v0->id));
    RETURN_IF_ERROR(graph->SetProducer(node_relu->id, v1->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_add->id, v1->id));

    RETURN_IF_ERROR(reader->AddOutputs(node_add));
    return absl::OkStatus();
  }

 private:
  const float clamp_a_, clamp_b_;
};

class ConcatenationOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));

    // TODO(eignasheva): add proper tensor availability checking
    // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
    //   RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
    // }
    // TODO(eignasheva): add axis checking.
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    ConcatAttributes attr;
    // Read inputs first to make sure const node is added to a graph before
    // concat node to ensure topological order.
    std::vector<const Value*> inputs;
    for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
      Value* value;
      const auto status = reader->ReadValue(idx, &value);
      if (status.ok()) {
        inputs.push_back(value);
      } else {
        TensorFloat32 tensor;
        RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
        Value* value;
        RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
        inputs.push_back(value);
      }
    }

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONCAT);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    for (const Value* input : inputs) {
      RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
    }

    std::vector<BHWC> input_shapes;
    for (auto input : graph->FindInputs(node->id)) {
      input_shapes.push_back(input->tensor.shape);
    }
    RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));

    // Guess axis.
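    // Each input shape is compared against the output shape; the first
    // mismatching dimension wins. Illustrative example (hypothetical shapes):
    // inputs BHWC(1, 8, 8, 3) and BHWC(1, 8, 8, 5) with output
    // BHWC(1, 8, 8, 8) differ from the output only in channels, so the axis
    // becomes CHANNELS.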
    BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    for (auto input : graph->FindInputs(node->id)) {
      if (input->tensor.shape.h != output_shape.h) {
        attr.axis = Axis::HEIGHT;
        break;
      }
      if (input->tensor.shape.w != output_shape.w) {
        attr.axis = Axis::WIDTH;
        break;
      }
      if (input->tensor.shape.c != output_shape.c) {
        attr.axis = Axis::CHANNELS;
        break;
      }
    }
    const TfLiteConcatenationParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
    *axis = Axis::BATCH;
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::HEIGHT;
        break;
      }
    }
    if (*axis == Axis::BATCH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::WIDTH;
        break;
      }
    }
    if (*axis == Axis::HEIGHT) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::CHANNELS;
        break;
      }
    }
    if (*axis == Axis::WIDTH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].h != input_shapes[i].h) {
        return absl::UnimplementedError(
            "Can concatenate tensors only by batch, height, width, or "
            "channels.");
      }
    }
    return absl::OkStatus();
  }
};

class Conv2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONVOLUTION_2D);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    Convolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
    } else {  // runtime_inputs == 1
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional

    const TfLiteConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(tf_options->dilation_height_factor,
                        tf_options->dilation_width_factor);
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }
};

// Doesn't have a kernel implementation.
class DensifyOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DENSIFY);
    const TfLiteTensor* const_tensor = reader->GetInputTensor(0);
    if (!const_tensor->sparsity) {
      return absl::InvalidArgumentError("Input tensor must be sparse.");
    }
    TensorFloat32 sparse_tensor;
    RETURN_IF_ERROR(reader->ReadTensor(0, &sparse_tensor));
    DensifyAttributes attributes;
    attributes.tensor = std::move(sparse_tensor);
    node->operation.attributes = attributes;
    return reader->AddOutputs(node);
  }
};

class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    DepthwiseConvolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
    } else {  // runtime_inputs == 1
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
    const TfLiteDepthwiseConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
                        std::max(1, tf_options->dilation_width_factor));
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    const int depth_multiplier = tf_options->depth_multiplier;
    if (depth_multiplier != 1) {
      const TfLiteTensor* input = reader->GetInputTensor(0);
      const TfLiteTensor* filter = reader->GetInputTensor(1);
      const TfLiteTensor* output = reader->GetOutputTensor(0);
      TransposeWeights(input, filter, output, depth_multiplier, &attr);
    }
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }

 private:
  // TFLite CPU stores weights as:
  //   [1, kernel_height, kernel_width, input_depth * depth_multiplier]
  // TFLite GPU stores weights as:
  //   [depth_multiplier, kernel_height, kernel_width, input_depth]
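  // Illustrative example (hypothetical sizes): with a 1x1 kernel,
  // input_depth = 2 and depth_multiplier = 2, the CPU layout interleaves
  // output channels as [c0m0, c0m1, c1m0, c1m1]; the loop below gathers each
  // output channel j by striding through the CPU data with step output_depth.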
  static void TransposeWeights(const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* output, int depth_multiplier,
                               DepthwiseConvolution2DAttributes* attr) {
    const int input_depth = input->dims->data[3];
    const int filter_height = filter->dims->data[1];
    const int filter_width = filter->dims->data[2];
    const int output_depth = output->dims->data[3];
    Tensor<OHWI, DataType::FLOAT32> weights;
    weights.id = attr->weights.id;
    weights.shape =
        OHWI(output_depth, filter_height, filter_width, input_depth);
    weights.data.resize(weights.shape.DimensionsProduct());
    float* dst = &weights.data[0];
    for (int j = 0; j < output_depth; ++j) {
      const float* src = attr->weights.data.data() + j;
      for (int i = 0; i < filter_height * filter_width; ++i) {
        *dst = *src;
        dst++;
        src += output_depth;
      }
    }
    attr->weights = std::move(weights);
  }
};

class DepthToSpaceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTH_TO_SPACE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    const TfLiteDepthToSpaceParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    SpaceToDepthAttributes attr;
    attr.block_size = tf_options->block_size;
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class DequantizeOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
    // with floating-point versions of the original tensors.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 1) {
      // Non-constant dequantization.
      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else {
      // TODO(b/181274192): Optimize out this constant dequantization from the
      // graph later.
      TensorFloat32 tensor;
      RETURN_IF_ERROR(reader->ReadTensor(0, &tensor));
      Value* value;
      RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
      // Need to retain the quant params from the original constant input.
      const TfLiteTensor* tflite_input = reader->GetInputTensor(0);
      value->quant_params.emplace();
      RETURN_IF_ERROR(
          PopulateQuantParams(*tflite_input, &value->quant_params.value()));
      RETURN_IF_ERROR(graph->AddConsumer(node->id, value->id));
    }
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Quantization attributes should already be present in the input tensor.
    auto input_value = graph->FindInputs(node->id)[0];
    if (!input_value->quant_params) {
      if (runtime_inputs == 1) {
        // The DEQUANTIZE op is preceded by a DENSIFY op and doesn't have any
        // quantization params. The DEQUANTIZE op will later be removed from
        // the graph by the `MergeDensify` graph transformation.
        return absl::OkStatus();
      }
      return absl::InvalidArgumentError(
          "Encountered Dequantize input with no quant params");
    }
    QuantizeAndDequantizeAttributes attr;
    attr.min = input_value->quant_params.value().min;
    attr.max = input_value->quant_params.value().max;
    attr.scale = input_value->quant_params.value().scale;

    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class ElementwiseOperationParser : public TFLiteOperationParser {
 public:
  explicit ElementwiseOperationParser(OperationType operation_type)
      : operation_type_(operation_type) {}

  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    if (IsLogicalOp(operation_type_)) {
      TensorInfo output_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->outputs->data[0],
                                    &output_tensor_info));
      if (output_tensor_info.producers.size() != 1 ||
          output_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError("Not supported logical op case");
      }
      if (output_tensor_info.consumers[0].second->builtin_code ==
          kTfLiteBuiltinCast) {
        return absl::OkStatus();
      } else {
        return absl::UnimplementedError("Not supported logical op case.");
      }
    }
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(operation_type_);

    if (IsOneArgumentOperation()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/0,
                                                        /*outputs=*/1));

      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else if (IsTwoArgumentOperation() &&
               reader
                   ->VerifyInputsConstsOutputs(tflite_node,
                                               /*runtime_inputs=*/2,
                                               /*const_inputs=*/0,
                                               /*outputs=*/1)
                   .ok()) {
      if (tflite_node->inputs->size != 2) {
        return absl::InvalidArgumentError(
            "Operation requires exactly two input tensors");
      }
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));

      TfLiteFusedActivation activation = kTfLiteActNone;
      switch (operation_type_) {
        case OperationType::SUB: {
          const TfLiteSubParams* tf_options;
          if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
            activation = tf_options->activation;
          }
          break;
        }
        case OperationType::DIV: {
          const TfLiteDivParams* tf_options;
          if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
            activation = tf_options->activation;
          }
          break;
        }
        default:
          // No activation expected.
          activation = kTfLiteActNone;
      }

      if (activation) {
        RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
      }
    } else if (IsTwoArgumentOperationWithConst()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/1,
                                                        /*outputs=*/1));
      ElementwiseAttributes attr;
      RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
      attr.runtime_tensor_is_second =
          IsConstantTensor(reader->GetInputTensor(0));
      node->operation.attributes = std::move(attr);
    } else {
      return absl::InvalidArgumentError("Incorrect operation type passed");
    }

    return reader->AddOutputs(node);
  }

 private:
  absl::Status GetActivation(const TfLiteNode* tflite_node,
                             TfLiteFusedActivation* activation) const {
    if (operation_type_ == OperationType::DIV) {
      const TfLiteDivParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }
    if (operation_type_ == OperationType::SUB) {
      const TfLiteSubParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }

    // Return kTfLiteActNone for all other ops: they either have no
    // TfLiteXxxParams at all or no activation field in those params.
    *activation = kTfLiteActNone;
    return absl::OkStatus();
  }

  bool IsOneArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::ABS:
      case OperationType::COPY:
      case OperationType::COS:
      case OperationType::ELU:
      case OperationType::EXP:
      case OperationType::FLOOR:
      case OperationType::LOG:
      case OperationType::NEG:
      case OperationType::RSQRT:
      case OperationType::SIGMOID:
      case OperationType::SIN:
      case OperationType::SQRT:
      case OperationType::SQUARE:
      case OperationType::TANH:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::DIV:
      case OperationType::EQUAL:
      case OperationType::FLOOR_DIV:
      case OperationType::FLOOR_MOD:
      case OperationType::GREATER:
      case OperationType::GREATER_EQUAL:
      case OperationType::LESS:
      case OperationType::LESS_EQUAL:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::NOT_EQUAL:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperationWithConst() const {
    switch (operation_type_) {
      case OperationType::DIV:
      case OperationType::EQUAL:
      case OperationType::FLOOR_DIV:
      case OperationType::FLOOR_MOD:
      case OperationType::GREATER:
      case OperationType::GREATER_EQUAL:
      case OperationType::LESS:
      case OperationType::LESS_EQUAL:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::NOT_EQUAL:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  OperationType operation_type_;
};

class FullyConnectedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
    // TODO(eignasheva): check input shape
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteFullyConnectedParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));

    if (reader->GetNumberOfRuntimeInputs() == 2) {
      // Create Convolution2D, since it supports runtime weights.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));

      Convolution2DAttributes attr;
      reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional

      attr.strides = HW(1, 1);
      attr.dilations = HW(1, 1);
      attr.padding.appended = HW(0, 0);
      attr.padding.prepended = HW(0, 0);
      RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
      node->operation.attributes = std::move(attr);
      return absl::OkStatus();
    }
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));

    if (tf_options->weights_format !=
        kTfLiteFullyConnectedWeightsFormatDefault) {
      return absl::UnimplementedError(
          "Unsupported FullyConnected weights format.");
    }

    FullyConnectedAttributes attr;
    RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
    const int weights_width = attr.weights.shape.i;

    auto input = graph->FindInputs(node->id)[0];
    int batch_size = input->tensor.shape.b;
    if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) {
      return absl::UnimplementedError(
          "Amount of input data should match weights width");
    }

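    // If the input still has spatial extent, flatten it first: insert a
    // RESHAPE to BHWC(b, 1, 1, weights_width) and feed the reshaped value into
    // the FULLY_CONNECTED node created below.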
    Node* conv = node;
    if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
      auto& reshape = node;
      conv = graph->NewNode();  // reset conv pointer!
      Value* reshaped_value = graph->NewValue();
      reshaped_value->tensor.type = DataType::FLOAT32;
      reshaped_value->tensor.shape =
          BHWC(input->tensor.shape.b, 1, 1, weights_width);
      RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
      reshape->operation.type = ToString(OperationType::RESHAPE);
      ReshapeAttributes attr;
      attr.new_shape = reshaped_value->tensor.shape;
      reshape->operation.attributes = attr;
      RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
    }

    conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
    conv->operation.attributes = std::move(attr);
    absl::Status result = reader->AddOutputs(conv);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv));

    return result;
  }
};

class HardSwishOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::HARD_SWISH);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    return reader->AddOutputs(node);
  }
};

// Basic LSTM Cell:
//
//  1name = name is at input  index 1
//  name1 = name is at output index 1
//
//  0input       1prev_activ
//       \       /
//        [[concat]]
//             \
//       concat_temp2  2weights  3biases
//             \      /         /
//             [[fully-connected]]
//               \
//         activ_temp3    4prev_state
//                 \      /
//                 [[LSTM]]
//                /      \
//      new_state1       activation0
//
// For full LSTM cells, see this blog post:
// https://colah.github.io/posts/2015-08-Understanding-LSTMs/
// In addition to Peephole connections and Combined Input Forget Gates (CIFG)
// described in that post, this code also adds the following optional features:
// - Configurable activations (sigmoid or TANH)
// - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
// - Output projection:
//   https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
// - Configurable clipping of cell state and output state.
class LSTMOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteLSTMParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    switch (tf_options->kernel_type) {
      case kTfLiteLSTMFullKernel:
        return ParseFull(tflite_node, registration, graph, reader, tf_options);
      case kTfLiteLSTMBasicKernel:
        return ParseBasic(tflite_node, registration, graph, reader,
                          tf_options);
    }
  }

  absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
      final {
    return new_variable_input_value_map_;
  }

 private:
  absl::Status ParseBasic(const TfLiteNode* tflite_node,
                          const TfLiteRegistration* registration,
                          GraphFloat32* graph, ObjectReader* reader,
                          const TfLiteLSTMParams* tf_options) {
    if (tflite_node->inputs->size != 5) {
      return absl::InvalidArgumentError("LSTM should have 5 input tensors");
    }
    if (tflite_node->outputs->size != 4) {
      return absl::InvalidArgumentError("LSTM should have 4 output tensors");
    }
    RETURN_IF_ERROR(CheckBasicParameters(tf_options));

    Node* concat_node = graph->NewNode();
    concat_node->operation.type = ToString(OperationType::CONCAT);
    ConcatAttributes concat_attr;
    concat_attr.axis = Axis::CHANNELS;
    concat_node->operation.attributes = concat_attr;

    Node* fc_node = graph->NewNode();
    fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
    fc_node->operation.attributes = std::move(fc_attr);

    Node* lstm_node = graph->NewNode();
    lstm_node->operation.type = ToString(OperationType::LSTM);
    LstmAttributes lstm_attr;
    lstm_attr.kernel_type = LstmKernelType::BASIC;
    lstm_node->operation.attributes = lstm_attr;

    Value* concat_temp;
    int concat_tensor_idx = tflite_node->outputs->data[2];
    RETURN_IF_ERROR(
        reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
    Value* activ_temp;
    int activ_tensor_idx = tflite_node->outputs->data[3];
    RETURN_IF_ERROR(
        reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));

    RETURN_IF_ERROR(reader->AddInput(concat_node, 0));  // input
    RETURN_IF_ERROR(reader->AddInput(concat_node, 1));  // prev_activ
    RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));

    RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
    RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));

    RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
    RETURN_IF_ERROR(reader->AddInput(lstm_node, 4));   // prev_state
    RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1));  // new_state
    RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0));  // activation

    return absl::OkStatus();
  }

  absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
    if (tf_options->activation != kTfLiteActTanh) {
      return absl::UnimplementedError("Only TANH activation is supported.");
    }
    if (tf_options->cell_clip != 0.0f) {
      return absl::UnimplementedError("cell_clip is not supported.");
    }
    if (tf_options->proj_clip != 0.0f) {
      return absl::UnimplementedError("proj_clip is not supported.");
    }
    return absl::OkStatus();
  }

  absl::Status ParseFull(const TfLiteNode* tflite_node,
                         const TfLiteRegistration* registration,
                         GraphFloat32* graph, ObjectReader* reader,
                         const TfLiteLSTMParams* tf_options) {
    // Invoke full LSTM parser
    RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
                                        reader, tf_options,
                                        &new_variable_input_value_map_));
    return absl::OkStatus();
  }

  absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
    if (tf_options->activation != kTfLiteActSigmoid &&
        tf_options->activation != kTfLiteActTanh) {
      return absl::UnimplementedError(
          "Only sigmoid or tanh activation is supported.");
    }

    return absl::OkStatus();
  }

  absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
};

class MulOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteTensor* input0 = reader->GetInputTensor(0);
    if (!input0) {
      return absl::InvalidArgumentError(
          "Couldn't get the 1st input tensor for MUL.");
    }
    const TfLiteTensor* input1 = reader->GetInputTensor(1);
    if (!input1) {
      return absl::InvalidArgumentError(
          "Couldn't get the 2nd input tensor for MUL.");
    }
    const bool constant_tensor0 = IsConstantTensor(input0);
    const bool constant_tensor1 = IsConstantTensor(input1);
    if (constant_tensor0 && constant_tensor1) {
      return absl::InvalidArgumentError("No runtime input tensors for MUL.");
    }
    const bool runtime_tensor0 = !constant_tensor0;
    const bool runtime_tensor1 = !constant_tensor1;

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Determine runtime/constant tensors.
    if (runtime_tensor0 && runtime_tensor1) {
      if (input0 == input1) {
        // Replace MUL(A, A) with POW(A, 2.0).
        // TODO(b/166831113): Support the same inputs for operations.
        node->operation.type = ToString(OperationType::POW);
        ElementwiseAttributes attr;
        attr.param = 2.0f;
        node->operation.attributes = std::move(attr);
        return reader->AddInput(node, 0);
      }

      // The "larger" input tensor must be bound to 1st input and the "smaller"
      // input tensor must be bound to 2nd input.
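      // Illustrative example (hypothetical shapes): with
      // shape0 = BHWC(1, 1, 1, 16) and shape1 = BHWC(1, 8, 8, 16), the swap
      // below binds the larger tensor 1 as the 1st input and the
      // broadcastable tensor 0 as the 2nd.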
      BHWC shape0;
      RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
      BHWC shape1;
      RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
      int input_tensor0 = 0;
      int input_tensor1 = 1;
      if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
          shape0.c == shape1.c) {
        input_tensor0 = 1;
        input_tensor1 = 0;
      }
      RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
      RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
    } else {
      ElementwiseAttributes attr;
      RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
      node->operation.attributes = std::move(attr);
    }

    const TfLiteMulParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    return MaybeFuseActivation(tf_options->activation, graph, node);
  }
};

class PackOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    if (tflite_node->inputs->size == 1) {
      // Pack with a single input can be replaced with Reshape.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::RESHAPE);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddOutputs(node));
      // New shape comes from output shape.
      ReshapeAttributes attr;
      attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
      node->operation.attributes = attr;
      return absl::OkStatus();
    } else {
      // Pack with multiple inputs can be replaced with Concat.
1233 const TfLitePackParams* tf_options;
1234 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1235
1236 // Read inputs first to make sure const node is added to a graph before
1237 // concat node to ensure topological order.
1238 std::vector<const Value*> inputs;
1239 for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
1240 Value* value;
1241 const auto status = reader->ReadValue(idx, &value);
1242 if (status.ok()) {
1243 inputs.push_back(value);
1244 } else {
1245 TensorFloat32 tensor;
1246 RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
1247 Value* value;
1248 RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
1249 inputs.push_back(value);
1250 }
1251 }
1252
1253 Node* node = graph->NewNode();
1254 node->operation.type = ToString(OperationType::CONCAT);
1255 RETURN_IF_ERROR(reader->AddOutputs(node));
1256 for (const Value* input : inputs) {
1257 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1258 }
1259 const TfLiteTensor* output = reader->GetOutputTensor(0);
1260 ConcatAttributes attr;
1261 RETURN_IF_ERROR(
1262 ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
1263 node->operation.attributes = attr;
1264 return absl::OkStatus();
1265 }
1266 }
1267 };
1268
1269 class PReLUOperationParser : public TFLiteOperationParser {
1270 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1271 absl::Status IsSupported(const TfLiteContext* context,
1272 const TfLiteNode* tflite_node,
1273 const TfLiteRegistration* registration) final {
1274 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1275 // TODO(eignasheva): add params check
1276 return absl::OkStatus();
1277 }
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1278 absl::Status Parse(const TfLiteNode* tflite_node,
1279 const TfLiteRegistration* registration,
1280 GraphFloat32* graph, ObjectReader* reader) final {
1281 Node* node = graph->NewNode();
1282 node->operation.type = ToString(OperationType::PRELU);
1283 RETURN_IF_ERROR(reader->AddInput(node, 0));
1284 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1285
1286 PReLUAttributes attr;
1287 Tensor<Linear, DataType::FLOAT32> linear_alpha;
1288 absl::Status status = reader->ReadTensor(1, &linear_alpha);
1289 if (status.ok()) {
1290 if (linear_alpha.shape.v != input_shape.c) {
1291 return absl::InvalidArgumentError(
1292 "Linear alpha shape does not match the number of input channels.");
1293 }
1294 attr.alpha = std::move(linear_alpha);
1295 } else {
1296 Tensor<HWC, DataType::FLOAT32> hwc_alpha;
1297 RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
1298 if (hwc_alpha.shape.h != input_shape.h ||
1299 hwc_alpha.shape.w != input_shape.w ||
1300 hwc_alpha.shape.c != input_shape.c) {
1301 return absl::InvalidArgumentError(
1302 "Alpha shape does not match input shape.");
1303 }
1304 attr.alpha = std::move(hwc_alpha);
1305 }
1306 node->operation.attributes = std::move(attr);
1307 return reader->AddOutputs(node);
1308 }
1309 };
1310
1311 class PadOperationParser : public TFLiteOperationParser {
1312 public:
PadOperationParser(bool mirror_pad)1313 explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
1314
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1315 absl::Status IsSupported(const TfLiteContext* context,
1316 const TfLiteNode* tflite_node,
1317 const TfLiteRegistration* registration) final {
1318 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1319 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1320 }
1321
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1322 absl::Status Parse(const TfLiteNode* tflite_node,
1323 const TfLiteRegistration* registration,
1324 GraphFloat32* graph, ObjectReader* reader) final {
1325 Node* node = graph->NewNode();
1326 node->operation.type = ToString(OperationType::PAD);
1327 RETURN_IF_ERROR(reader->AddInput(node, 0));
1328 RETURN_IF_ERROR(reader->AddOutputs(node));
1329
1330 PadAttributes attr;
1331 if (mirror_pad_) {
1332 attr.type = PaddingContentType::REFLECT;
1333 } else /*zero pad*/ {
1334 attr.type = PaddingContentType::ZEROS;
1335 }
1336
1337 Tensor<HW, DataType::INT32> paddings;
1338 RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
1339
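// The paddings tensor is row-major with shape [dims, 2]: data[2*i] holds
// the padding prepended to dimension i and data[2*i + 1] the padding
// appended to it, which is how the indices below are laid out.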
1340 if (paddings.shape.h == 4 && paddings.shape.w == 2) {
1341 // 4x2 tensor with paddings.
1342 attr.prepended = BHWC(paddings.data[0], paddings.data[2],
1343 paddings.data[4], paddings.data[6]);
1344 attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
1345 paddings.data[7]);
1346 } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
1347 // 3x2 tensor with paddings.
1348 attr.prepended =
1349 BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
1350 attr.appended =
1351 BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
1352 } else {
1353 // This should not fail here, since the shape is checked in IsSupported().
1354 return absl::InvalidArgumentError(
1355 "Paddings tensor has unexpected shape.");
1356 }
1357 node->operation.attributes = attr;
1358 return absl::OkStatus();
1359 }
1360
1361 private:
1362 bool mirror_pad_ = false;
1363 };
1364
1365 class Pooling2DOperationParser : public TFLiteOperationParser {
1366 public:
1367   absl::Status IsSupported(const TfLiteContext* context,
1368 const TfLiteNode* tflite_node,
1369 const TfLiteRegistration* registration) final {
1370 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1371 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1372 }
1373
1374 public:
1375   explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
1376
1377   absl::Status Parse(const TfLiteNode* tflite_node,
1378 const TfLiteRegistration* registration,
1379 GraphFloat32* graph, ObjectReader* reader) final {
1380 Node* node = graph->NewNode();
1381 node->operation.type = ToString(OperationType::POOLING_2D);
1382 RETURN_IF_ERROR(reader->AddInput(node, 0));
1383 RETURN_IF_ERROR(reader->AddOutput(node, 0));
1384
1385 Pooling2DAttributes attr;
1386 attr.type = type_;
1387
1388 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1389
1390 // Check whether there are custom options encoded. This happens when the
1391 // operation is MaxPoolingWithArgmax2D. There is no way to read
1392 // tflite_node->builtin_code, so we simply check whether custom data is
1393 // available.
1394 const TfLitePoolParams* tf_options;
1395 if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
1396 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1397 }
1398
1399 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1400 // Second output is optional. If present, it must be added after
1401 // MaybeFuseActivation is called.
1402 reader->AddOutput(node, 1).IgnoreError();
1403
1404 // First output is the result of the pooling operation, while the second
1405 // output holds the indices used for pooling.
1406 auto outputs = graph->FindOutputs(node->id);
1407 attr.output_indices = outputs.size() == 2;
1408 if (attr.output_indices) {
1409 // Fix data type for output indices. In the model it is set as float32.
1410 outputs[1]->tensor.type = DataType::INT32;
1411 }
1412 RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
1413 node->operation.attributes = attr;
1414 return absl::OkStatus();
1415 }
1416
1417 private:
1418 const PoolingType type_;
1419 };
1420
1421 class ReduceOperationParser : public TFLiteOperationParser {
1422 public:
1423   explicit ReduceOperationParser(OperationType operation_type)
1424 : operation_type_(operation_type) {}
1425
1426   absl::Status IsSupported(const TfLiteContext* context,
1427 const TfLiteNode* tflite_node,
1428 const TfLiteRegistration* registration) final {
1429 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1430 }
1431
1432   absl::Status Parse(const TfLiteNode* tflite_node,
1433 const TfLiteRegistration* registration,
1434 GraphFloat32* graph, ObjectReader* reader) final {
1435 Node* node = graph->NewNode();
1436 node->operation.type = ToString(operation_type_);
1437 RETURN_IF_ERROR(reader->AddInput(node, 0));
1438 RETURN_IF_ERROR(reader->AddOutputs(node));
1439
1440 const TfLiteReducerParams* tf_options;
1441 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1442
1443 ReduceAttributes attr;
1444 const TfLiteTensor* input = reader->GetInputTensor(0);
1445 const TfLiteTensor* axes = reader->GetInputTensor(1);
1446 for (int i = 0; i < NumElements(axes->dims); i++) {
1447 Axis axis;
1448 RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
1449 attr.dims.insert(axis);
1450 }
1451 node->operation.attributes = attr;
1452 return absl::OkStatus();
1453 }
1454
1455 private:
1456 const OperationType operation_type_;
1457 };
1458
1459 class QuantizeOperationParser : public TFLiteOperationParser {
1460 public:
1461   absl::Status IsSupported(const TfLiteContext* context,
1462 const TfLiteNode* tflite_node,
1463 const TfLiteRegistration* registration) final {
1464 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1465 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1466 }
1467
1468   absl::Status Parse(const TfLiteNode* tflite_node,
1469 const TfLiteRegistration* registration,
1470 GraphFloat32* graph, ObjectReader* reader) final {
1471 // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
1472 // with floating-point versions of the original tensors.
1473 Node* node = graph->NewNode();
1474 node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
1475 RETURN_IF_ERROR(reader->AddInput(node, 0));
1476 RETURN_IF_ERROR(reader->AddOutputs(node));
1477
1478 // Quantization attributes should already be present in the output tensor.
1479 auto output_value = graph->FindOutputs(node->id)[0];
1480 if (!output_value->quant_params) {
1481 return absl::InvalidArgumentError(
1482 "Encountered Quantize output with no quant params");
1483 }
1484 QuantizeAndDequantizeAttributes attr;
1485 attr.min = output_value->quant_params.value().min;
1486 attr.max = output_value->quant_params.value().max;
1487 attr.scale = output_value->quant_params.value().scale;
1488
1489 node->operation.attributes = attr;
1490 return absl::OkStatus();
1491 }
1492 };
1493
1494 class ReLUOperationParser : public TFLiteOperationParser {
1495 public:
1496   explicit ReLUOperationParser(int clip) : clip_(clip) {}
1497
1498   absl::Status IsSupported(const TfLiteContext* context,
1499 const TfLiteNode* tflite_node,
1500 const TfLiteRegistration* registration) final {
1501 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1502 return absl::OkStatus();
1503 }
1504
1505   absl::Status Parse(const TfLiteNode* tflite_node,
1506 const TfLiteRegistration* registration,
1507 GraphFloat32* graph, ObjectReader* reader) final {
1508 Node* node = graph->NewNode();
1509 node->operation.type = ToString(OperationType::RELU);
1510 RETURN_IF_ERROR(reader->AddInput(node, 0));
1511
1512 ReLUAttributes attr;
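// Only LEAKY_RELU carries builtin params; for plain RELU/RELU6 the lookup
// below fails and alpha falls back to 0. A clip of 0 means no upper bound.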
1513 const TfLiteLeakyReluParams* tf_options;
1514 auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1515 attr.alpha = status.ok() ? tf_options->alpha : 0;
1516 attr.clip = clip_;
1517 node->operation.attributes = attr;
1518 return reader->AddOutputs(node);
1519 }
1520
1521 private:
1522 const int clip_;
1523 };
1524
1525 class ResamplerOperationParser : public TFLiteOperationParser {
1526 public:
1527   absl::Status IsSupported(const TfLiteContext* context,
1528 const TfLiteNode* tflite_node,
1529 const TfLiteRegistration* registration) final {
1530 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1531 }
1532
1533   absl::Status Parse(const TfLiteNode* tflite_node,
1534 const TfLiteRegistration* registration,
1535 GraphFloat32* graph, ObjectReader* reader) final {
1536 Node* node = graph->NewNode();
1537 RETURN_IF_ERROR(reader->AddInput(node, 0)); // src
1538 RETURN_IF_ERROR(reader->AddInput(node, 1)); // warp
1539 RETURN_IF_ERROR(reader->AddOutputs(node));
1540
1541 node->operation.type = ToString(OperationType::RESAMPLER);
1542
1543 auto src_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1544 auto warp_shape = graph->FindInputs(node->id)[1]->tensor.shape;
1545
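// The output keeps the source's batch and channel dimensions but takes its
// spatial size from the warp grid.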
1546 auto output_value = graph->FindOutputs(node->id)[0];
1547 output_value->tensor.shape =
1548 BHWC(src_shape.b, warp_shape.h, warp_shape.w, src_shape.c);
1549 return absl::OkStatus();
1550 }
1551 };
1552
1553 class ReshapeOperationParser : public TFLiteOperationParser {
1554 public:
1555   absl::Status IsSupported(const TfLiteContext* context,
1556 const TfLiteNode* tflite_node,
1557 const TfLiteRegistration* registration) final {
1558 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1559 // TODO(eignasheva): add shape checking
1560 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1561 }
1562
1563   absl::Status Parse(const TfLiteNode* tflite_node,
1564 const TfLiteRegistration* registration,
1565 GraphFloat32* graph, ObjectReader* reader) final {
1566 Node* node = graph->NewNode();
1567 node->operation.type = ToString(OperationType::RESHAPE);
1568 RETURN_IF_ERROR(reader->AddInput(node, 0));
1569 RETURN_IF_ERROR(reader->AddOutputs(node));
1570 // Here we may have extra inputs. The other tensors were supposed to
1571 // define the new shape, but in TFLite they are ignored.
1572 // TODO(akulik): check that shapes match?
1573
1574 // New shape comes from output shape.
1575 ReshapeAttributes attr;
1576 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1577 node->operation.attributes = attr;
1578 return absl::OkStatus();
1579 }
1580 };
1581
1582 class Resize2DOperationParser : public TFLiteOperationParser {
1583 public:
1584   explicit Resize2DOperationParser(SamplingType sampling_type)
1585 : sampling_type_(sampling_type) {}
1586
1587   absl::Status IsSupported(const TfLiteContext* context,
1588 const TfLiteNode* tflite_node,
1589 const TfLiteRegistration* registration) final {
1590 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1591 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1592 }
1593
1594   absl::Status Parse(const TfLiteNode* tflite_node,
1595 const TfLiteRegistration* registration,
1596 GraphFloat32* graph, ObjectReader* reader) final {
1597 Node* node = graph->NewNode();
1598 node->operation.type = ToString(OperationType::RESIZE);
1599 RETURN_IF_ERROR(reader->AddInput(node, 0));
1600 RETURN_IF_ERROR(reader->AddOutputs(node));
1601 // Here we may have extra inputs. The other tensors were supposed to
1602 // define the new shape, but in TFLite they are ignored.
1603
1604 Resize2DAttributes attr;
1605 RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
1606 RETURN_IF_ERROR(
1607 GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
1608 attr.type = sampling_type_;
1609 attr.new_shape.CopyAllDefinedAxis(
1610 graph->FindOutputs(node->id)[0]->tensor.shape);
1611 node->operation.attributes = attr;
1612 return absl::OkStatus();
1613 }
1614
1615 private:
1616   absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
1617 bool* align_corners) {
1618 switch (sampling_type_) {
1619 case SamplingType::BILINEAR:
1620 return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
1621 tflite_node, align_corners);
1622 case SamplingType::NEAREST:
1623 return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
1624 tflite_node, align_corners);
1625 case SamplingType::UNKNOWN:
1626 return absl::InternalError("Sampling type is not specified");
1627 }
1628 return absl::OkStatus();
1629 }
1630
1631 template <class T>
1632   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
1633 bool* align_corners) {
1634 const T* tf_options;
1635 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1636 *align_corners = tf_options->align_corners;
1637 return absl::OkStatus();
1638 }
1639
1640   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
1641 bool* half_pixel_centers) {
1642 if (sampling_type_ == SamplingType::BILINEAR) {
1643 const TfLiteResizeBilinearParams* tf_options;
1644 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1645 if (tf_options->align_corners && tf_options->half_pixel_centers) {
1646 return absl::InternalError(
1647 "If half_pixel_centers is True, align_corners must be False.");
1648 }
1649 *half_pixel_centers = tf_options->half_pixel_centers;
1650 } else {
1651 const TfLiteResizeNearestNeighborParams* tf_options;
1652 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1653 *half_pixel_centers = tf_options->half_pixel_centers;
1654 }
1655 return absl::OkStatus();
1656 }
1657
1658 SamplingType sampling_type_ = SamplingType::UNKNOWN;
1659 };
1660
1661 class SliceOperationParser : public TFLiteOperationParser {
1662 public:
1663   absl::Status IsSupported(const TfLiteContext* context,
1664 const TfLiteNode* tflite_node,
1665 const TfLiteRegistration* registration) final {
1666 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1667 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1668 }
1669
1670   absl::Status Parse(const TfLiteNode* tflite_node,
1671 const TfLiteRegistration* registration,
1672 GraphFloat32* graph, ObjectReader* reader) final {
1673 Node* node = graph->NewNode();
1674 node->operation.type = ToString(OperationType::SLICE);
1675 RETURN_IF_ERROR(reader->AddOutputs(node));
1676 Value* input;
1677 RETURN_IF_ERROR(reader->ReadValue(0, &input));
1678 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1679
1680 const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
1681 const int input_dims = tfl_input->dims->size;
1682
1683 SliceAttributes attr;
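// TFLite's SLICE op has no stride arguments, so unit strides are used in
// every dimension.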
1684 attr.strides = BHWC(1, 1, 1, 1);
1685 Tensor<Linear, DataType::INT32> starts, sizes;
1686 RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
1687 RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
1688 if (starts.data.size() != sizes.data.size()) {
1689 return absl::InvalidArgumentError("Number of starts != number of sizes.");
1690 }
1691 BHWC bhwc_starts(0, 0, 0, 0);
1692 BHWC bhwc_sizes = input->tensor.shape;
1693 if (input_dims == 4) {
1694 // input in BHWC layout
1695 if (starts.data.size() == 4) {
1696 bhwc_starts.b = starts.data[0];
1697 bhwc_starts.h = starts.data[1];
1698 bhwc_starts.w = starts.data[2];
1699 bhwc_starts.c = starts.data[3];
1700 bhwc_sizes.b = sizes.data[0];
1701 bhwc_sizes.h = sizes.data[1];
1702 bhwc_sizes.w = sizes.data[2];
1703 bhwc_sizes.c = sizes.data[3];
1704 } else if (starts.data.size() == 3) {
1705 // If input is 4D (BHWC) and args are 3D, we assume the args use HWC layout.
1706 bhwc_starts.h = starts.data[0];
1707 bhwc_starts.w = starts.data[1];
1708 bhwc_starts.c = starts.data[2];
1709 bhwc_sizes.h = sizes.data[0];
1710 bhwc_sizes.w = sizes.data[1];
1711 bhwc_sizes.c = sizes.data[2];
1712 } else {
1713 return absl::UnimplementedError(
1714 "Slicing is supported for 3 or 4 dimensional tensors only.");
1715 }
1716 } else if (input_dims == 3) {
1717 // input in BWC layout
1718 if (starts.data.size() == 3) {
1719 bhwc_starts.b = starts.data[0];
1720 bhwc_starts.w = starts.data[1];
1721 bhwc_starts.c = starts.data[2];
1722 bhwc_sizes.b = sizes.data[0];
1723 bhwc_sizes.w = sizes.data[1];
1724 bhwc_sizes.c = sizes.data[2];
1725 } else {
1726 return absl::UnimplementedError(
1727 "Slicing is supported for 3 or 4 dimensional tensors only.");
1728 }
1729 } else {
1730 return absl::UnimplementedError(
1731 "Slicing is supported for 3 or 4 dimensional tensors only.");
1732 }
1733 const auto& in_shape = input->tensor.shape;
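// Per TFLite semantics, a size of -1 means "all remaining elements along
// this dimension, starting at the corresponding start index".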
1734 if (bhwc_sizes.b == -1) {
1735 bhwc_sizes.b = in_shape.b - bhwc_starts.b;
1736 }
1737 if (bhwc_sizes.h == -1) {
1738 bhwc_sizes.h = in_shape.h - bhwc_starts.h;
1739 }
1740 if (bhwc_sizes.w == -1) {
1741 bhwc_sizes.w = in_shape.w - bhwc_starts.w;
1742 }
1743 if (bhwc_sizes.c == -1) {
1744 bhwc_sizes.c = in_shape.c - bhwc_starts.c;
1745 }
1746 attr.starts = bhwc_starts;
1747 attr.ends =
1748 BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
1749 bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
1750 RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
1751
1752 auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1753 if ((attr.ends.b - attr.starts.b) != out_shape.b) {
1754 return absl::UnimplementedError("Output batch doesn't match");
1755 }
1756 if ((attr.ends.h - attr.starts.h) != out_shape.h) {
1757 return absl::UnimplementedError("Output height doesn't match");
1758 }
1759 if ((attr.ends.w - attr.starts.w) != out_shape.w) {
1760 return absl::UnimplementedError("Output width doesn't match");
1761 }
1762 if ((attr.ends.c - attr.starts.c) != out_shape.c) {
1763 return absl::UnimplementedError("Output channels don't match");
1764 }
1765 node->operation.attributes = attr;
1766 return absl::OkStatus();
1767 }
1768
1769 private:
1770   absl::Status UpdateIfNegative(const BHWC& input_shape,
1771 SliceAttributes* attr) {
1772 if (attr->ends.h < 0) {
1773 attr->ends.h = input_shape.h + attr->ends.h;
1774 }
1775 if (attr->ends.w < 0) {
1776 attr->ends.w = input_shape.w + attr->ends.w;
1777 }
1778 if (attr->ends.c < 0) {
1779 attr->ends.c = input_shape.c + attr->ends.c;
1780 }
1781 if (attr->ends.b < 0) {
1782 attr->ends.b = input_shape.b + attr->ends.b;
1783 }
1784 return absl::OkStatus();
1785 }
1786 };
1787
1788 class SoftmaxOperationParser : public TFLiteOperationParser {
1789 public:
1790   absl::Status IsSupported(const TfLiteContext* context,
1791 const TfLiteNode* tflite_node,
1792 const TfLiteRegistration* registration) final {
1793 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1794 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1795 }
1796
1797   absl::Status Parse(const TfLiteNode* tflite_node,
1798 const TfLiteRegistration* registration,
1799 GraphFloat32* graph, ObjectReader* reader) final {
1800 Node* node = graph->NewNode();
1801 node->operation.type = ToString(OperationType::SOFTMAX);
1802 RETURN_IF_ERROR(reader->AddInput(node, 0));
1803 RETURN_IF_ERROR(reader->AddOutputs(node));
1804
1805 const TfLiteSoftmaxParams* tf_options;
1806 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1807 if (tf_options->beta != 1) {
1808 // There is a multiply-by-scalar operation fused into this softmax. It
1809 // would need to be split out into a separate layer ahead of the softmax.
1810 return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
1811 // auto mul_node = reader->NewPassthroughNode(node);
1812 // mul_node->operation.type = ToString(OperationType::MUL);
1813 }
1814 SoftmaxAttributes attr;
1815 attr.axis = Axis::CHANNELS; // always by channels
1816 node->operation.attributes = attr;
1817 return absl::OkStatus();
1818 }
1819 };
1820
1821 class SpaceToDepthOperationParser : public TFLiteOperationParser {
1822 public:
1823   absl::Status IsSupported(const TfLiteContext* context,
1824 const TfLiteNode* tflite_node,
1825 const TfLiteRegistration* registration) final {
1826 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1827 // TODO(impjdi): Dims check.
1828 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1829 }
1830
1831   absl::Status Parse(const TfLiteNode* tflite_node,
1832 const TfLiteRegistration* registration,
1833 GraphFloat32* graph, ObjectReader* reader) final {
1834 Node* node = graph->NewNode();
1835 node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
1836 RETURN_IF_ERROR(reader->AddInput(node, 0));
1837 RETURN_IF_ERROR(reader->AddOutputs(node));
1838 const TfLiteSpaceToDepthParams* tf_options;
1839 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1840 SpaceToDepthAttributes attr;
1841 attr.block_size = tf_options->block_size;
1842 node->operation.attributes = attr;
1843 return absl::OkStatus();
1844 }
1845 };
1846
1847 class SplitOperationParser : public TFLiteOperationParser {
1848 public:
1849   absl::Status IsSupported(const TfLiteContext* context,
1850 const TfLiteNode* tflite_node,
1851 const TfLiteRegistration* registration) final {
1852 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1853 }
1854
1855   absl::Status Parse(const TfLiteNode* tflite_node,
1856 const TfLiteRegistration* registration,
1857 GraphFloat32* graph, ObjectReader* reader) final {
1858 const TfLiteSplitParams* split_params;
1859 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
1860 if (split_params->num_splits == 1) {
1861 // Add an identity reshape that will later be removed.
1862 Node* node = graph->NewNode();
1863 node->operation.type = ToString(OperationType::RESHAPE);
1864 RETURN_IF_ERROR(reader->AddInput(node, 1));
1865 RETURN_IF_ERROR(reader->AddOutputs(node));
1866 // New shape comes from output shape.
1867 ReshapeAttributes attr;
1868 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1869 node->operation.attributes = attr;
1870 return absl::OkStatus();
1871 }
1872 const TfLiteTensor* input = reader->GetInputTensor(1);
1873 const TfLiteTensor* axis_tensor = reader->GetInputTensor(0);
1874 SplitAttributes attr;
1875 RETURN_IF_ERROR(
1876 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
1877
1878 Node* node = graph->NewNode();
1879 node->operation.type = ToString(OperationType::SPLIT);
1880 node->operation.attributes = attr;
1881 RETURN_IF_ERROR(reader->AddInput(node, 1));
1882 for (int i = 0; i < tflite_node->outputs->size; ++i) {
1883 RETURN_IF_ERROR(reader->AddOutput(node, i));
1884 }
1885 return absl::OkStatus();
1886 }
1887 };
1888
1889 class SplitVOperationParser : public TFLiteOperationParser {
1890 public:
1891   absl::Status IsSupported(const TfLiteContext* context,
1892 const TfLiteNode* tflite_node,
1893 const TfLiteRegistration* registration) final {
1894 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1895 }
1896
1897   absl::Status Parse(const TfLiteNode* tflite_node,
1898 const TfLiteRegistration* registration,
1899 GraphFloat32* graph, ObjectReader* reader) final {
1900 const TfLiteSplitVParams* split_params;
1901 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
1902 if (split_params->num_splits == 1) {
1903 // Add an identity reshape that will later be removed.
1904 Node* node = graph->NewNode();
1905 node->operation.type = ToString(OperationType::RESHAPE);
1906 RETURN_IF_ERROR(reader->AddInput(node, 0));
1907 RETURN_IF_ERROR(reader->AddOutputs(node));
1908 // New shape comes from output shape.
1909 ReshapeAttributes attr;
1910 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1911 node->operation.attributes = attr;
1912 return absl::OkStatus();
1913 }
1914 const TfLiteTensor* input = reader->GetInputTensor(0);
1915 const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
1916 SplitAttributes attr;
1917 RETURN_IF_ERROR(
1918 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
1919
1920 Node* node = graph->NewNode();
1921 node->operation.type = ToString(OperationType::SPLIT);
1922 node->operation.attributes = attr;
1923 RETURN_IF_ERROR(reader->AddInput(node, 0));
1924 for (int i = 0; i < tflite_node->outputs->size; ++i) {
1925 RETURN_IF_ERROR(reader->AddOutput(node, i));
1926 }
1927 return absl::OkStatus();
1928 }
1929 };
1930
1931 class StridedSliceOperationParser : public TFLiteOperationParser {
1932 public:
1933   absl::Status IsSupported(const TfLiteContext* context,
1934 const TfLiteNode* tflite_node,
1935 const TfLiteRegistration* registration) final {
1936 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1937 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1938 }
1939
1940   absl::Status Parse(const TfLiteNode* tflite_node,
1941 const TfLiteRegistration* registration,
1942 GraphFloat32* graph, ObjectReader* reader) final {
1943 Node* node = graph->NewNode();
1944 node->operation.type = ToString(OperationType::SLICE);
1945 RETURN_IF_ERROR(reader->AddOutputs(node));
1946 Value* input;
1947 RETURN_IF_ERROR(reader->ReadValue(0, &input));
1948 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1949
1950 Tensor<Linear, DataType::INT32> tmp;
1951 RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
1952
1953 bool read_without_batch = tmp.data.size() == 3;
1954 bool read_with_batch = tmp.data.size() == 4;
1955 if (!read_without_batch && !read_with_batch) {
1956 // Error: this must be caught in IsSupported().
1957 return absl::UnimplementedError(
1958 "Slicing is supported for 3 or 4 dimensional tensors only.");
1959 }
1960
1961 const TfLiteStridedSliceParams* tf_options;
1962 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1963 RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
1964
1965 auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1966
1967 SliceAttributes attr;
1968 if (read_without_batch) {
1969 RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
1970 input->tensor.shape, &attr));
1971 }
1972 if (read_with_batch) {
1973 RETURN_IF_ERROR(
1974 ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
1975 }
1976 if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
1977 attr.strides.c == 0) {
1978 return absl::InvalidArgumentError("stride values must be non-zero");
1979 }
1980 if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
1981 attr.strides.c < 0) {
1982 return absl::UnimplementedError("Reverse slices are not supported.");
1983 }
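// Each check below computes ceil((end - start) / stride), i.e. the number
// of elements the slice produces along that axis, and compares it with the
// declared output shape.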
1984 if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
1985 out_shape.b) {
1986 return absl::UnimplementedError("Output batch doesn't match");
1987 }
1988 if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
1989 out_shape.h) {
1990 return absl::UnimplementedError("Output height doesn't match");
1991 }
1992 if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
1993 out_shape.w) {
1994 return absl::UnimplementedError("Output width doesn't match");
1995 }
1996 if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
1997 out_shape.c) {
1998 return absl::UnimplementedError("Output channels don't match");
1999 }
2000 node->operation.attributes = attr;
2001 return absl::OkStatus();
2002 }
2003
2004 private:
2005   absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
2006 const BHWC& input_shape, int ignore_b,
2007 int ignore_h, int ignore_w, int ignore_c,
2008 SliceAttributes* attr) {
2009 if (tf_options->begin_mask & ignore_h) {
2010 attr->starts.h = 0;
2011 }
2012 if (tf_options->begin_mask & ignore_w) {
2013 attr->starts.w = 0;
2014 }
2015 if (tf_options->begin_mask & ignore_c) {
2016 attr->starts.c = 0;
2017 }
2018 if (tf_options->begin_mask & ignore_b) {
2019 attr->starts.b = 0;
2020 }
2021
2022 if (tf_options->end_mask & ignore_h) {
2023 attr->ends.h = input_shape.h;
2024 }
2025 if (tf_options->end_mask & ignore_w) {
2026 attr->ends.w = input_shape.w;
2027 }
2028 if (tf_options->end_mask & ignore_c) {
2029 attr->ends.c = input_shape.c;
2030 }
2031 if (tf_options->end_mask & ignore_b) {
2032 attr->ends.b = input_shape.b;
2033 }
2034 return absl::OkStatus();
2035 }
2036
2037   absl::Status UpdateIfNegative(const BHWC& input_shape,
2038 SliceAttributes* attr) {
2039 if (attr->ends.h < 0) {
2040 attr->ends.h = input_shape.h + attr->ends.h;
2041 }
2042 if (attr->ends.w < 0) {
2043 attr->ends.w = input_shape.w + attr->ends.w;
2044 }
2045 if (attr->ends.c < 0) {
2046 attr->ends.c = input_shape.c + attr->ends.c;
2047 }
2048 if (attr->ends.b < 0) {
2049 attr->ends.b = input_shape.b + attr->ends.b;
2050 }
2051
2052 if (attr->starts.h < 0) {
2053 attr->starts.h = input_shape.h + attr->starts.h;
2054 }
2055 if (attr->starts.w < 0) {
2056 attr->starts.w = input_shape.w + attr->starts.w;
2057 }
2058 if (attr->starts.c < 0) {
2059 attr->starts.c = input_shape.c + attr->starts.c;
2060 }
2061 if (attr->starts.b < 0) {
2062 attr->starts.b = input_shape.b + attr->starts.b;
2063 }
2064
2065 return absl::OkStatus();
2066 }
2067
2068   absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
2069 const TfLiteStridedSliceParams* tf_options,
2070 const BHWC& input_shape,
2071 SliceAttributes* attr) {
2072 auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2073 Tensor<Linear, DataType::INT32> t;
2074 RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2075 *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
2076 return absl::OkStatus();
2077 };
2078
2079 RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
2080 RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
2081 RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
2082 RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
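// With a batch dimension present, begin_mask/end_mask use one bit per
// input dimension in order: bit 0 = batch, bit 1 = height, bit 2 = width,
// bit 3 = channels.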
2083 RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
2084 return absl::OkStatus();
2085 }
2086
2087   absl::Status ReadAttribsWithoutBatch(
2088 const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
2089 const BHWC& input_shape, SliceAttributes* attr) {
2090 auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2091 Tensor<Linear, DataType::INT32> t;
2092 RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2093 *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
2094 return absl::OkStatus();
2095 };
2096
2097 RETURN_IF_ERROR(read_hwc(1, &attr->starts));
2098 RETURN_IF_ERROR(read_hwc(2, &attr->ends));
2099 RETURN_IF_ERROR(read_hwc(3, &attr->strides));
2100 RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
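// Without a batch dimension the mask bits shift down: bit 0 = height,
// bit 1 = width, bit 2 = channels. Passing 0 as the batch bit makes the
// mask a no-op for batch, which is filled in explicitly below.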
2101 RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
2102 attr->starts.b = 0;
2103 attr->ends.b = input_shape.b;
2104 attr->strides.b = 1;
2105 return absl::OkStatus();
2106 }
2107   absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
2108 if (tf_options->ellipsis_mask) {
2109 return absl::UnimplementedError("Slice does not support ellipsis_mask.");
2110 }
2111 if (tf_options->new_axis_mask) {
2112 return absl::UnimplementedError("Slice does not support new_axis_mask.");
2113 }
2114 if (tf_options->shrink_axis_mask) {
2115 return absl::UnimplementedError(
2116 "Slice does not support shrink_axis_mask parameter. ");
2117 }
2118 return absl::OkStatus();
2119 }
2120 };
2121
2122 class TileOperationParser : public TFLiteOperationParser {
2123 public:
2124   absl::Status IsSupported(const TfLiteContext* context,
2125 const TfLiteNode* tflite_node,
2126 const TfLiteRegistration* registration) final {
2127 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2128 }
2129
2130   absl::Status Parse(const TfLiteNode* tflite_node,
2131 const TfLiteRegistration* registration,
2132 GraphFloat32* graph, ObjectReader* reader) final {
2133 Node* node = graph->NewNode();
2134 node->operation.type = ToString(OperationType::TILE);
2135 RETURN_IF_ERROR(reader->AddInput(node, 0));
2136 RETURN_IF_ERROR(reader->AddOutputs(node));
2137 return absl::OkStatus();
2138 }
2139 };
2140
2141 // Builtin op version of TRANSPOSE_CONV.
2142 class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
2143 public:
2144   absl::Status IsSupported(const TfLiteContext* context,
2145 const TfLiteNode* tflite_node,
2146 const TfLiteRegistration* registration) final {
2147 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
2148 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2149 }
2150
2151 // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights,
2152 // input, and an optional bias) and allows configurable padding & stride.
2153 // TODO(impjdi): Translate output_shape to attr.adjacent.
2154   absl::Status Parse(const TfLiteNode* tflite_node,
2155 const TfLiteRegistration* registration,
2156 GraphFloat32* graph, ObjectReader* reader) final {
2157 auto* node = graph->NewNode();
2158 node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2159 Value* input;
2160 RETURN_IF_ERROR(reader->ReadValue(2, &input));
2161 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2162 RETURN_IF_ERROR(reader->AddOutputs(node));
2163
2164 const TfLiteTransposeConvParams* tf_options;
2165 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2166
2167 ConvolutionTransposedAttributes attr;
2168 attr.stride = tf_options
2169 ? HW(tf_options->stride_height, tf_options->stride_width)
2170 : HW(1, 1);
2171 const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
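// Two runtime inputs means the weights arrive as a runtime tensor
// (input 1), so only their shape is recorded here; otherwise the constant
// weights tensor is read directly.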
2172 if (runtime_inputs == 2) {
2173 RETURN_IF_ERROR(reader->AddInput(node, 1));
2174 auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
2175 attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
2176 weights_shape.w, weights_shape.c);
2177 } else { // runtime_inputs == 1
2178 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2179 }
2180 reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional
2181
2182 UpdatePadding(tf_options->padding,
2183 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2184 node->operation.attributes = std::move(attr);
2185 return absl::OkStatus();
2186 }
2187 };
2188
2189 // Custom op version of TRANSPOSE_CONV.
2190 class TransposeConvCustomOperationParser : public TFLiteOperationParser {
2191 public:
2192   absl::Status IsSupported(const TfLiteContext* context,
2193 const TfLiteNode* tflite_node,
2194 const TfLiteRegistration* registration) final {
2195 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2196 }
2197
2198   absl::Status Parse(const TfLiteNode* tflite_node,
2199 const TfLiteRegistration* registration,
2200 GraphFloat32* graph, ObjectReader* reader) final {
2201 auto* node = graph->NewNode();
2202 node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2203 RETURN_IF_ERROR(reader->AddInput(node, 0));
2204 RETURN_IF_ERROR(reader->AddOutputs(node));
2205
2206 const TfLiteTransposeConvParams* tf_options;
2207 auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
2208
2209 ConvolutionTransposedAttributes attr;
2210 attr.stride = status.ok()
2211 ? HW(tf_options->stride_height, tf_options->stride_width)
2212 : HW(1, 1);
2213 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2214 reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
2215
2216 UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
2217 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2218 node->operation.attributes = std::move(attr);
2219 return absl::OkStatus();
2220 }
2221 };
2222
2223 class TransposeOperationParser : public TFLiteOperationParser {
2224 public:
2225   absl::Status IsSupported(const TfLiteContext* context,
2226 const TfLiteNode* tflite_node,
2227 const TfLiteRegistration* registration) final {
2228 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2229 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2230 }
2231
2232   absl::Status Parse(const TfLiteNode* tflite_node,
2233 const TfLiteRegistration* registration,
2234 GraphFloat32* graph, ObjectReader* reader) final {
2235 Node* node = graph->NewNode();
2236 node->operation.type = ToString(OperationType::TRANSPOSE);
2237 RETURN_IF_ERROR(reader->AddInput(node, 0));
2238 RETURN_IF_ERROR(reader->AddOutputs(node));
2239
2240 TransposeAttributes attr;
2241 Tensor<Linear, DataType::INT32> perm;
2242 RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
2243 std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0},
2244 {Axis::HEIGHT, 1},
2245 {Axis::WIDTH, 2},
2246 {Axis::CHANNELS, 3}};
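// Inputs of rank < 4 are embedded into BHWC with height fixed at 1, so a
// 3D or 2D permutation must be translated through index_to_axis and
// axis_to_index into an equivalent 4D permutation.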
2247 if (perm.data.size() == 4) {
2248 attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
2249 } else if (perm.data.size() == 3) {
2250 std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH,
2251 Axis::CHANNELS};
2252 attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2253 attr.perm.h = 1;
2254 attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]];
2255 attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]];
2256 } else if (perm.data.size() == 2) {
2257 std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS};
2258 attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2259 attr.perm.h = 1;
2260 attr.perm.w = 2;
2261 attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]];
2262 } else {
2263 return absl::InvalidArgumentError(
2264 "Permutation for transpose is invalid.");
2265 }
2266
2267 node->operation.attributes = attr;
2268 return absl::OkStatus();
2269 }
2270 };
2271
2272 class Unpooling2DOperationParser : public TFLiteOperationParser {
2273 public:
2274   absl::Status IsSupported(const TfLiteContext* context,
2275 const TfLiteNode* tflite_node,
2276 const TfLiteRegistration* registration) final {
2277 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2278 }
2279
2280   absl::Status Parse(const TfLiteNode* tflite_node,
2281 const TfLiteRegistration* registration,
2282 GraphFloat32* graph, ObjectReader* reader) final {
2283 Node* node = graph->NewNode();
2284 node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
2285 RETURN_IF_ERROR(reader->AddInput(node, 0));
2286 RETURN_IF_ERROR(reader->AddInput(node, 1));
2287 RETURN_IF_ERROR(reader->AddOutputs(node));
2288 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
2289 MaxUnpooling2DAttributes attr;
2290
2291 const TfLitePoolParams* tf_options;
2292 RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
2293
2294 attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
2295 attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
2296 UpdatePadding(tf_options->padding, input_shape, &attr);
2297
2298 node->operation.attributes = attr;
2299
2300 auto output_value = graph->FindOutputs(node->id)[0];
2301 output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
2302 return absl::OkStatus();
2303 }
2304 };
2305
2306 // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
2307 class BatchToSpaceOperationParser : public TFLiteOperationParser {
2308 public:
2309   absl::Status IsSupported(const TfLiteContext* context,
2310 const TfLiteNode* tflite_node,
2311 const TfLiteRegistration* registration) final {
2312 return absl::OkStatus();
2313 }
2314
2315   absl::Status Parse(const TfLiteNode* tflite_node,
2316 const TfLiteRegistration* registration,
2317 GraphFloat32* graph, ObjectReader* reader) final {
2318 auto* node = graph->NewNode();
2319 node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
2320 RETURN_IF_ERROR(reader->AddInput(node, 0));
2321 RETURN_IF_ERROR(reader->AddOutputs(node));
2322
2323 BatchToSpaceAttributes bs_attr;
2324 Tensor<Linear, DataType::INT32> block;
2325 RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2326 if (block.shape.v != 2) {
2327 return absl::InternalError("Space has to be HxW.");
2328 }
2329 bs_attr.block.h = block.data[0];
2330 bs_attr.block.w = block.data[1];
2331
2332 Tensor<HW, DataType::INT32> crop;
2333 RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
2334 auto crop_shape = crop.shape;
2335 if (crop_shape.h != 2 || crop_shape.w != 2) {
2336 return absl::InternalError("Space has to be HxW.");
2337 }
2338
2339 bs_attr.crop.prepended.h = crop.data[0];
2340 bs_attr.crop.prepended.w = crop.data[2];
2341
2342 bs_attr.crop.appended.h = crop.data[1];
2343 bs_attr.crop.appended.w = crop.data[3];
2344
2345 node->operation.attributes = std::move(bs_attr);
2346 return absl::OkStatus();
2347 }
2348 };
2349
2350 class SpaceToBatchOperationParser : public TFLiteOperationParser {
2351 public:
2352   absl::Status IsSupported(const TfLiteContext* context,
2353 const TfLiteNode* tflite_node,
2354 const TfLiteRegistration* registration) final {
2355 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2356 }
2357
2358   absl::Status Parse(const TfLiteNode* tflite_node,
2359 const TfLiteRegistration* registration,
2360 GraphFloat32* graph, ObjectReader* reader) final {
2361 auto* node = graph->NewNode();
2362 node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
2363 RETURN_IF_ERROR(reader->AddInput(node, 0));
2364 RETURN_IF_ERROR(reader->AddOutputs(node));
2365 SpaceToBatchAttributes sb_attr;
2366 Tensor<Linear, DataType::INT32> block;
2367 RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2368 if (block.shape.v != 2) {
2369 return absl::InternalError("Space has to be HxW.");
2370 }
2371 sb_attr.block.h = block.data[0];
2372 sb_attr.block.w = block.data[1];
2373
2374 Tensor<HW, DataType::INT32> padding;
2375 RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
2376 auto padding_shape = padding.shape;
2377
2378 if (padding_shape.h != 2 || padding_shape.w != 2) {
2379 return absl::InternalError("Space has to be HxW.");
2380 }
2381
2382 sb_attr.padding.prepended.h = padding.data[0];
2383 sb_attr.padding.prepended.w = padding.data[2];
2384
2385 sb_attr.padding.appended.h = padding.data[1];
2386 sb_attr.padding.appended.w = padding.data[3];
2387
2388 node->operation.attributes = std::move(sb_attr);
2389 return absl::OkStatus();
2390 }
2391 };
2392
2393 class MeanOperationParser : public TFLiteOperationParser {
2394 public:
2395   absl::Status IsSupported(const TfLiteContext* context,
2396 const TfLiteNode* tflite_node,
2397 const TfLiteRegistration* registration) final {
2398 return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2399 }
2400
2401   absl::Status Parse(const TfLiteNode* tflite_node,
2402 const TfLiteRegistration* registration,
2403 GraphFloat32* graph, ObjectReader* reader) final {
2404 auto* node = graph->NewNode();
2405 node->operation.type = ToString(OperationType::MEAN);
2406 RETURN_IF_ERROR(reader->AddInput(node, 0));
2407 RETURN_IF_ERROR(reader->AddOutputs(node));
2408
2409 MeanAttributes attr;
2410 const TfLiteTensor* input = reader->GetInputTensor(0);
2411 const TfLiteTensor* axes = reader->GetInputTensor(1);
2412 for (int i = 0; i < NumElements(axes->dims); i++) {
2413 Axis axis;
2414 RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
2415 attr.dims.insert(axis);
2416 }
2417 node->operation.attributes = attr;
2418 return absl::OkStatus();
2419 }
2420 };
2421
2422 class UnsupportedOperationParser : public TFLiteOperationParser {
2423 public:
2424   absl::Status IsSupported(const TfLiteContext* context,
2425 const TfLiteNode* tflite_node,
2426 const TfLiteRegistration* registration) final {
2427 return absl::UnimplementedError("Operation is not supported.");
2428 }
2429
2430   absl::Status Parse(const TfLiteNode* tflite_node,
2431 const TfLiteRegistration* registration,
2432 GraphFloat32* graph, ObjectReader* reader) final {
2433 return absl::UnimplementedError("Operation is not supported.");
2434 }
2435 };
2436
2437 absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
2438 const TfLiteRegistration* registration,
2439 bool allow_quant_ops = false) {
2440 return NewOperationParser(registration, allow_quant_ops)
2441 ->IsSupported(context, node, registration);
2442 }
2443
2444 bool IsAllAllowedTensors(TfLiteContext* context,
2445 const TfLiteIntArray* tensor_indices,
2446 const std::vector<TfLiteType>& allowed_types) {
2447 for (int i = 0; i < tensor_indices->size; ++i) {
2448 int tensor_idx = tensor_indices->data[i];
2449 if (tensor_idx == kTfLiteOptionalTensor) continue;
2450 const TfLiteTensor* t = &context->tensors[tensor_idx];
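// Tensors of rank 5 and above do not fit the delegate's BHWC layout and
// are rejected outright.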
2451 if (t->dims && t->dims->size >= 5) {
2452 return false;
2453 }
2454 bool type_supported = false;
2455 for (auto allowed_type : allowed_types) {
2456 if (t->type == allowed_type) {
2457 type_supported = true;
2458 break;
2459 }
2460 }
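// Only runtime (arena-allocated) tensors must match the allowed types;
// constant inputs are read and converted at graph-building time instead.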
2461 if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
2462 return false;
2463 }
2464 }
2465 return true;
2466 }
2467 } // namespace
2468
2469 std::unique_ptr<TFLiteOperationParser> NewOperationParser(
2470 const TfLiteRegistration* registration, bool allow_quant_ops) {
2471 const auto builtin_code = registration->builtin_code;
2472 switch (builtin_code) {
2473 case kTfLiteBuiltinAbs:
2474 return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
2475 case kTfLiteBuiltinAdd:
2476 return std::make_unique<AddOperationParser>();
2477 case kTfLiteBuiltinAveragePool2d:
2478 return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
2479 case kTfLiteBuiltinBatchMatmul:
2480 return std::make_unique<BatchedMatMulOperationParser>();
2481 case kTfLiteBuiltinCast:
2482 return std::make_unique<CastOperationParser>();
2483 case kTfLiteBuiltinConcatenation:
2484 return std::make_unique<ConcatenationOperationParser>();
2485 case kTfLiteBuiltinConv2d:
2486 return std::make_unique<Conv2DOperationParser>();
2487 case kTfLiteBuiltinCos:
2488 return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
2489 case kTfLiteBuiltinDensify:
2490 return std::make_unique<DensifyOperationParser>();
2491 case kTfLiteBuiltinDepthwiseConv2d:
2492 return std::make_unique<DepthwiseConvolutionOperationParser>();
2493 case kTfLiteBuiltinDepthToSpace:
2494 return std::make_unique<DepthToSpaceOperationParser>();
2495 case kTfLiteBuiltinDequantize:
2496 if (allow_quant_ops) {
2497 return std::make_unique<DequantizeOperationParser>();
2498 }
2499 break;
2500 case kTfLiteBuiltinDiv:
2501 return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
2502 case kTfLiteBuiltinEqual:
2503 return std::make_unique<ElementwiseOperationParser>(OperationType::EQUAL);
2504 case kTfLiteBuiltinElu:
2505 return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
2506 case kTfLiteBuiltinExp:
2507 return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
2508 case kTfLiteBuiltinFloor:
2509 return std::make_unique<ElementwiseOperationParser>(OperationType::FLOOR);
2510 case kTfLiteBuiltinFloorDiv:
2511 return std::make_unique<ElementwiseOperationParser>(
2512 OperationType::FLOOR_DIV);
2513 case kTfLiteBuiltinFloorMod:
2514 return std::make_unique<ElementwiseOperationParser>(
2515 OperationType::FLOOR_MOD);
2516 case kTfLiteBuiltinFullyConnected:
2517 return std::make_unique<FullyConnectedOperationParser>();
2518 case kTfLiteBuiltinGreater:
2519 return std::make_unique<ElementwiseOperationParser>(
2520 OperationType::GREATER);
2521 case kTfLiteBuiltinGreaterEqual:
2522 return std::make_unique<ElementwiseOperationParser>(
2523 OperationType::GREATER_EQUAL);
2524 case kTfLiteBuiltinHardSwish:
2525 return std::make_unique<HardSwishOperationParser>();
2526 case kTfLiteBuiltinLess:
2527 return std::make_unique<ElementwiseOperationParser>(OperationType::LESS);
2528 case kTfLiteBuiltinLessEqual:
2529 return std::make_unique<ElementwiseOperationParser>(
2530 OperationType::LESS_EQUAL);
2531 case kTfLiteBuiltinLogistic:
2532 return std::make_unique<ElementwiseOperationParser>(
2533 OperationType::SIGMOID);
2534 case kTfLiteBuiltinLog:
2535 return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
2536 case kTfLiteBuiltinLstm:
2537 return std::make_unique<LSTMOperationParser>();
2538 case kTfLiteBuiltinMaximum:
2539 return std::make_unique<ElementwiseOperationParser>(
2540 OperationType::MAXIMUM);
2541 case kTfLiteBuiltinMaxPool2d:
2542 return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
2543 case kTfLiteBuiltinMean:
2544 return std::make_unique<MeanOperationParser>();
    case kTfLiteBuiltinMinimum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MINIMUM);
    case kTfLiteBuiltinMirrorPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
    case kTfLiteBuiltinMul:
      return std::make_unique<MulOperationParser>();
    case kTfLiteBuiltinNeg:
      return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
    case kTfLiteBuiltinNotEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::NOT_EQUAL);
    case kTfLiteBuiltinPack:
      return std::make_unique<PackOperationParser>();
    case kTfLiteBuiltinPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
    case kTfLiteBuiltinPow:
      return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
    case kTfLiteBuiltinReduceMax:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MAXIMUM);
    case kTfLiteBuiltinReduceMin:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MINIMUM);
    case kTfLiteBuiltinReduceProd:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_PRODUCT);
    case kTfLiteBuiltinQuantize:
      if (allow_quant_ops) {
        return std::make_unique<QuantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinRelu6:
      return std::make_unique<ReLUOperationParser>(6);
    case kTfLiteBuiltinReluN1To1:
      return std::make_unique<ClampOperationsParser>(-1.0, 1.0);
    case kTfLiteBuiltinLeakyRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinPrelu:
      return std::make_unique<PReLUOperationParser>();
    case kTfLiteBuiltinReshape:
      return std::make_unique<ReshapeOperationParser>();
    case kTfLiteBuiltinResizeBilinear:
      return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
    case kTfLiteBuiltinResizeNearestNeighbor:
      return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
    case kTfLiteBuiltinRsqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
    case kTfLiteBuiltinSin:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
    case kTfLiteBuiltinSlice:
      return std::make_unique<SliceOperationParser>();
    case kTfLiteBuiltinSoftmax:
      return std::make_unique<SoftmaxOperationParser>();
    case kTfLiteBuiltinSpaceToDepth:
      return std::make_unique<SpaceToDepthOperationParser>();
    case kTfLiteBuiltinSplit:
      return std::make_unique<SplitOperationParser>();
    case kTfLiteBuiltinSplitV:
      return std::make_unique<SplitVOperationParser>();
    case kTfLiteBuiltinSqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
    case kTfLiteBuiltinSquare:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARE);
    case kTfLiteBuiltinSquaredDifference:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARED_DIFF);
    case kTfLiteBuiltinStridedSlice:
      return std::make_unique<StridedSliceOperationParser>();
    case kTfLiteBuiltinSub:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
    case kTfLiteBuiltinSum:
      return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
    case kTfLiteBuiltinTanh:
      return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
    case kTfLiteBuiltinTile:
      return std::make_unique<TileOperationParser>();
    case kTfLiteBuiltinTranspose:
      return std::make_unique<TransposeOperationParser>();
    case kTfLiteBuiltinTransposeConv:
      return std::make_unique<TransposeConvBuiltinOperationParser>();
    case kTfLiteBuiltinCustom: {
      const absl::string_view custom_name = registration->custom_name;
      if (custom_name == "Convolution2DTransposeBias") {
        return std::make_unique<TransposeConvCustomOperationParser>();
      }
      if (custom_name == "MaxPoolingWithArgmax2D") {
        return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
      }
      if (custom_name == "MaxUnpooling2D") {
        return std::make_unique<Unpooling2DOperationParser>();
      }
      if (custom_name == "Resampler") {
        return std::make_unique<ResamplerOperationParser>();
      }
      return NewCustomOperationParser(registration->custom_name);
    }
  }
  return std::make_unique<UnsupportedOperationParser>();
}
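
// A minimal sketch of probing a single node for GPU support with the factory
// above (illustrative only; `context`, `node`, and `registration` are assumed
// to have been obtained via GetNodeAndRegistration()):
//
//   auto parser = NewOperationParser(registration, /*allow_quant_ops=*/false);
//   const absl::Status support =
//       parser->IsSupported(context, node, registration);
//   const bool gpu_compatible = support.ok();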

// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
                                int max_delegated_partitions) {
  delegates::IsNodeSupportedFn node_supported_fn =
      [=](TfLiteContext* context, TfLiteNode* node,
          TfLiteRegistration* registration,
          std::string* unsupported_details) -> bool {
    const auto status =
        IsSupported(context, node, registration, allow_quant_ops);
    if (!status.ok()) {
      if (unsupported_details) {
        *unsupported_details = std::string(status.message());
      }
      return false;
    }

    std::vector<TfLiteType> allowed_in_types = {kTfLiteFloat32, kTfLiteFloat16};
    std::vector<TfLiteType> allowed_out_types = {kTfLiteFloat32,
                                                 kTfLiteFloat16};
    if (allow_quant_ops) {
      // Since we only check non-constant tensors, type cannot be Int32.
      allowed_in_types.push_back(kTfLiteInt8);
      allowed_in_types.push_back(kTfLiteUInt8);
      allowed_out_types.push_back(kTfLiteInt8);
      allowed_out_types.push_back(kTfLiteUInt8);
    }
    if (IsLogicalCode(registration->builtin_code)) {
      allowed_out_types.push_back(kTfLiteBool);
    }
    if (registration->builtin_code == kTfLiteBuiltinCast) {
      allowed_in_types.push_back(kTfLiteBool);
    }
    if (!IsAllAllowedTensors(context, node->inputs, allowed_in_types) ||
        !IsAllAllowedTensors(context, node->outputs, allowed_out_types)) {
      if (unsupported_details) {
        *unsupported_details =
2686 "OP is supported, but tensor type/shape doesn't supported.";
      }
      return false;
    }
    return true;
  };

  delegates::FP16GraphPartitionHelper partition_helper(context,
                                                       node_supported_fn);
  std::set<std::string> unsupported_nodes_info;
  if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
    return TfLiteIntArrayCreate(0);
  }

  // By default, we simply take the largest partition, since
  // 'max_delegated_partitions' defaults to 1.
  std::vector<int> ops_to_replace =
      partition_helper.GetNodesOfFirstNLargestPartitions(
          max_delegated_partitions);

  if (!unsupported_nodes_info.empty() &&
      partition_helper.num_total_nodes() > ops_to_replace.size()) {
    std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
    std::string error_message = absl::StrCat(
        "Following operations are not supported by GPU delegate:\n",
        unsupported, "\n");
    if (!ops_to_replace.empty()) {
      absl::StrAppend(
          &error_message, ops_to_replace.size(),
          " operations will run on the GPU, and the remaining ",
          partition_helper.num_total_nodes() - ops_to_replace.size());
    } else {
      absl::StrAppend(&error_message,
                      "No operations will run on the GPU, and all ",
                      partition_helper.num_total_nodes());
    }
    absl::StrAppend(&error_message, " operations will run on the CPU.");
    TF_LITE_KERNEL_LOG(context, error_message.c_str());
  }
  return ConvertVectorToTfLiteIntArray(ops_to_replace);
}
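
// The returned array is heap-allocated and owned by the caller, who must
// release it with TfLiteIntArrayFree(). A minimal usage sketch mirroring
// DelegatePrepare() below (`registration` and `delegate` are assumed to be
// already set up):
//
//   TfLiteIntArray* ops = GetOpsToReplace(context, /*allow_quant_ops=*/false,
//                                         /*max_delegated_partitions=*/1);
//   context->ReplaceNodeSubsetsWithDelegateKernels(context, registration,
//                                                  ops, delegate);
//   TfLiteIntArrayFree(ops);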

// Creates the inputs and outputs passed via io_ids in the resulting graph.
// We force this to make sure that the delegated subgraph has the same order
// of inputs and outputs as the original one. When the delegated model is
// built from the TFLite model representation, tensors are created lazily, so
// there is no guarantee that the order will match the source model's tensor
// order.
absl::Status PrecreateIOTensors(
    TfLiteContext* context, GraphFloat32* graph, const std::vector<int>& io_ids,
    absl::flat_hash_map<int, int>* quant_conversion_map,
    absl::flat_hash_map<int, Value*>* tensor_to_value) {
  for (const auto& id : io_ids) {
    const TfLiteTensor& tflite_tensor = context->tensors[id];
    if (tflite::IsConstantTensor(&tflite_tensor)) continue;
    RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
        context, tensor_to_value, quant_conversion_map, graph, id));
  }
  return absl::OkStatus();
}

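// Adds a COPY node for every value that an operation produced for one of its
// variable input tensors, so that the new value is written back to the
// original tensor. Fails if the operation did not provide a new value for
// each of its variable input tensors, or provided values that do not map to
// any variable input.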
absl::Status CopyVariableTensorOutputs(
    TfLiteNode* tflite_node, TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader& reader,
    const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
  absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
      new_variable_tensor_values);
  // Retrieve the final value id for the variable input tensors.
  for (int i = 0; i < tflite_node->inputs->size; i++) {
    int tensor_idx = tflite_node->inputs->data[i];
    Value* value;
    if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
    if (value->tensor.is_variable_input) {
      if (new_variable_tensor_values_copy.find(i) ==
          new_variable_tensor_values_copy.end()) {
        return absl::InvalidArgumentError(
            absl::StrCat(GetOpNameByRegistration(*registration),
                         " did not provide a new value for the variable input "
                         "tensor with index ",
                         tensor_idx));
      } else {
        Node* node = graph->NewNode();
        node->operation.type = ToString(OperationType::COPY);
        RETURN_IF_ERROR(graph->AddConsumer(
            node->id, new_variable_tensor_values_copy.at(i)));
        RETURN_IF_ERROR(reader.AddUpdate(node, i));
        new_variable_tensor_values_copy.erase(
            new_variable_tensor_values_copy.find(i));
      }
    }
  }
  if (!new_variable_tensor_values_copy.empty()) {
    return absl::InvalidArgumentError(
        "More variable input tensors were asked to be copied than are "
        "present on the node");
  }
  return absl::OkStatus();
}

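// Builds a GraphFloat32 from the subset of the model described by
// |delegate_params|, taking the input and output tensor ids from the delegate
// params themselves.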
absl::Status BuildModel(TfLiteContext* context,
                        const TfLiteDelegateParams* delegate_params,
                        GraphFloat32* graph,
                        absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<int> inputs(delegate_params->input_tensors->size);
  std::vector<int> outputs(delegate_params->output_tensors->size);
  for (int i = 0; i < delegate_params->input_tensors->size; i++) {
    inputs[i] = delegate_params->input_tensors->data[i];
  }
  for (int i = 0; i < delegate_params->output_tensors->size; i++) {
    outputs[i] = delegate_params->output_tensors->data[i];
  }
  return BuildModelEnforceIO(context, delegate_params, inputs, outputs, graph,
                             quant_conversion_map);
}

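// Same as BuildModel, but pre-creates the given input/output tensors so that
// the resulting graph exposes them in exactly the order of |input_ids| and
// |output_ids|.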
absl::Status BuildModelEnforceIO(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    const std::vector<int>& input_ids, const std::vector<int>& output_ids,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
  std::vector<int> tflite_nodes;
  for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
    TfLiteNode* tflite_node = nullptr;
    TfLiteRegistration* registration = nullptr;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[i], &tflite_node,
        &registration));
    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
        context->tensors[tflite_node->inputs->data[0]].type ==
            TfLiteType::kTfLiteFloat16 &&
        context->tensors[tflite_node->inputs->data[0]].allocation_type ==
            TfLiteAllocationType::kTfLiteMmapRo) {
      // Ignore Fp16 Dequantize nodes only if they are the final nodes before
      // weights, i.e. no other nodes (e.g. DENSIFY) precede them.
      continue;
    }
    auto op_parser = NewOperationParser(
        registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
    if (!op_parser) {
      return absl::UnimplementedError(
          absl::StrCat("Operation ", registration->builtin_code, "(",
                       registration->custom_name,
                       ") is not supported by TFLite GPU Delegate."));
    }
    operations.push_back(std::move(op_parser));
    tflite_nodes.push_back(i);
  }
  absl::flat_hash_map<int, Value*> tensor_to_value;
  std::vector<ValueId> variable_inputs_to_value_id;

  RETURN_IF_ERROR(PrecreateIOTensors(context, graph, input_ids,
                                     quant_conversion_map, &tensor_to_value));
  RETURN_IF_ERROR(PrecreateIOTensors(context, graph, output_ids,
                                     quant_conversion_map, &tensor_to_value));
  for (int i = 0; i < operations.size(); ++i) {
    TfLiteNode* tflite_node;
    TfLiteRegistration* registration;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
        &tflite_node, &registration));
    ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
                        quant_conversion_map);
    const auto status =
        operations[i]->Parse(tflite_node, registration, graph, &reader);
    if (!status.ok()) {
      return absl::InternalError(absl::StrCat(
          GetOpNameByRegistration(*registration), ": ", status.message()));
    }

    absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
        operations[i]->GetNewValueIdsForVariableInputNodes();

    RETURN_IF_ERROR(
        CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
                                  new_value_for_variable_input_tensors));
  }

  // Variable input tensors are expected to be unchanged throughout model
  // execution. They need to be graph outputs in order to stay unchanged.
  for (auto value_id : variable_inputs_to_value_id) {
    if (!graph->IsGraphOutput(value_id)) {
      return absl::InvalidArgumentError(
          absl::StrCat("Variable input tensors must be a graph output. Value ",
                       value_id, " is not a graph output"));
    }
  }
  return absl::OkStatus();
}

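// Same as BuildModel, but additionally runs the general graph transformations
// on the result before returning it.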
absl::Status BuildFinalModel(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  RETURN_IF_ERROR(
      BuildModel(context, delegate_params, graph, quant_conversion_map));

  // Apply general transformations on the graph.
  ModelTransformer transformer(graph);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }
  return absl::OkStatus();
}

namespace {

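// Helper used by BuildFromFlatBuffer below: a minimal delegate context whose
// only job is to capture the delegated subgraph as a GraphFloat32 while the
// interpreter applies the delegate.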
class DelegateContext {
 public:
  struct DelegateData {
    std::vector<int> input_ids;
    std::vector<int> output_ids;
    GraphFloat32* graph;
    std::unique_ptr<absl::flat_hash_map<int, int>> quant_conversion_map;
  };
  bool Init(TfLiteContext* context,
            const TfLiteDelegateParams* delegate_params) {
    const auto* delegate_data =
        reinterpret_cast<DelegateData*>(delegate_params->delegate->data_);
    return delegate_data->graph &&
           BuildModelEnforceIO(context, delegate_params,
                               delegate_data->input_ids,
                               delegate_data->output_ids, delegate_data->graph,
                               delegate_data->quant_conversion_map.get())
               .ok();
  }
};

TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  TfLiteRegistration registration;
  registration.init = [](TfLiteContext* context, const char* buffer,
                         size_t) -> void* {
    auto* delegate_context = new DelegateContext();
    if (!delegate_context->Init(
            context, reinterpret_cast<const TfLiteDelegateParams*>(buffer))) {
      delete delegate_context;
      return nullptr;
    }
    return delegate_context;
  };
  registration.free = [](TfLiteContext* context, void* buffer) -> void {
    delete reinterpret_cast<DelegateContext*>(buffer);
  };
  registration.prepare = [](TfLiteContext* context,
                            TfLiteNode* node) -> TfLiteStatus {
    return node->user_data ? kTfLiteOk : kTfLiteError;
  };
  registration.invoke = nullptr;
  registration.custom_name = nullptr;

  const auto* delegate_data =
      reinterpret_cast<const DelegateContext::DelegateData*>(delegate->data_);
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, static_cast<bool>(delegate_data->quant_conversion_map));
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, registration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace

absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
                                 const tflite::OpResolver& op_resolver,
                                 GraphFloat32* graph, bool allow_quant_ops) {
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
  if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
    return absl::InternalError("Unable to prepare TfLite interpreter.");
  }
  TfLiteDelegate delegate;

  DelegateContext::DelegateData delegate_data{interpreter->inputs(),
                                              interpreter->outputs(), graph};
  if (allow_quant_ops) {
    delegate_data.quant_conversion_map =
        std::make_unique<absl::flat_hash_map<int, int>>();
  }

  delegate.data_ = &delegate_data;
  delegate.flags = kTfLiteDelegateFlagsNone;
  delegate.Prepare = DelegatePrepare;
  delegate.CopyFromBufferHandle = nullptr;
  delegate.CopyToBufferHandle = nullptr;
  delegate.FreeBufferHandle = nullptr;

  if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
    return absl::InternalError("Conversion from TfLite model failed.");
  }

  ModelTransformer transformer(graph);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }

  return absl::OkStatus();
}
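
// A minimal usage sketch (assumptions: the model path is hypothetical, and
// the caller links in the builtin op resolver):
//
//   auto model = tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   GraphFloat32 graph;
//   absl::Status status =
//       BuildFromFlatBuffer(*model, resolver, &graph,
//                           /*allow_quant_ops=*/false);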

}  // namespace gpu
}  // namespace tflite