/* Copyright 2019-2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/model_builder.h"

#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/attributes.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_internal.h"
#include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/tools/versioning/gpu_compatibility.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace gpu {
namespace {

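// Reads 2-D fully-connected weights of shape [num_outputs, num_inputs] and
// repacks them as OHWI weights with unit spatial dimensions
// (o = num_outputs, i = num_inputs). The bias tensor is optional.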
absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
                                         int bias_tensor_id,
                                         ObjectReader* reader,
                                         FullyConnectedAttributes* attr) {
  Tensor<HW, DataType::FLOAT32> weights;
  RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
  attr->weights.data = std::move(weights.data);
  attr->weights.id = weights.id;
  attr->weights.shape.h = 1;
  attr->weights.shape.w = 1;
  attr->weights.shape.o = weights.shape.h;
  attr->weights.shape.i = weights.shape.w;
  reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();  // optional
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
                                 const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve builtin_data.");
  }
  return absl::OkStatus();
}

template <typename ParamsT>
absl::Status RetrieveCustomInitialData(const TfLiteNode* tflite_node,
                                       const ParamsT** tf_options) {
  *tf_options = static_cast<const ParamsT*>(tflite_node->custom_initial_data);
  if (!*tf_options) {
    return absl::InternalError("Unable to retrieve custom_initial_data.");
  }
  return absl::OkStatus();
}

// Creates a simple node that holds a tensor value.
absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
  ConstTensorAttributes attr;
  attr.tensor = std::move(t);
  Node* node = graph->NewNode();
  node->operation.attributes = attr;
  node->operation.type = ToString(OperationType::CONSTANT);
  *value = graph->NewValue();
  RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
  // Keep data inside this tensor.
  (*value)->tensor.ref = attr.tensor.id;
  (*value)->tensor.type = attr.tensor.kType;
  (*value)->tensor.shape = attr.tensor.shape;
  return absl::OkStatus();
}

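// Binds the runtime input(s) of a binary elementwise op to `node`. When one
// operand is constant, it is loaded into `tensor_or_scalar`: as a scalar for
// single-element tensors, as a Linear tensor when the shape is linearly
// convertible, and as a full HWC tensor otherwise.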
absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
                                        TensorOrScalar* tensor_or_scalar) {
  const std::string& opname = node->operation.type;

  // Determine runtime/constant tensors.
  const TfLiteTensor* input0 = reader->GetInputTensor(0);
  if (!input0) {
    return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
                                      opname);
  }
  const TfLiteTensor* input1 = reader->GetInputTensor(1);
  if (!input1) {
    return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
                                      opname);
  }
  const bool constant_tensor0 = IsConstantTensor(input0);
  const bool constant_tensor1 = IsConstantTensor(input1);
  if (constant_tensor0 && constant_tensor1) {
    return absl::InvalidArgumentError("No runtime input tensors for " + opname);
  }
  const bool runtime_tensor0 = !constant_tensor0;
  const bool runtime_tensor1 = !constant_tensor1;

  if (runtime_tensor0 && runtime_tensor1) {
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
  } else {
    int runtime_tensor = 0;
    int constant_tensor = 1;
    TfLiteIntArray* constant_dims = input1->dims;
    if (constant_tensor0 && runtime_tensor1) {
      runtime_tensor = 1;
      constant_tensor = 0;
      constant_dims = input0->dims;
    }
    RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
    if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
      Tensor<Scalar, DataType::FLOAT32> tensor;
      RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
      *tensor_or_scalar = tensor.data[0];
    } else {
      if (CheckIfLinearConvertible(constant_dims).ok()) {
        Tensor<Linear, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      } else {
        Tensor<HWC, DataType::FLOAT32> tensor;
        RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
        *tensor_or_scalar = std::move(tensor);
      }
    }
  }
  return absl::OkStatus();
}

struct TensorInfo {
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> producers;
  std::vector<std::pair<TfLiteNode*, TfLiteRegistration*>> consumers;
};

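// Scans the execution plan and records every node that consumes or produces
// the tensor with index `tensor_id`.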
absl::Status GetTensorInfo(const TfLiteContext* context, int tensor_id,
                           TensorInfo* result) {
  TfLiteIntArray* execution_plan = nullptr;
  if (context->GetExecutionPlan(const_cast<TfLiteContext*>(context),
                                &execution_plan) != kTfLiteOk) {
    return absl::UnavailableError("Unable to get graph execution plan.");
  }
  for (int i = 0; i < execution_plan->size; ++i) {
    const int node_index = execution_plan->data[i];
    TfLiteNode* node = nullptr;
    TfLiteRegistration* registration = nullptr;
    if (context->GetNodeAndRegistration(const_cast<TfLiteContext*>(context),
                                        node_index, &node,
                                        &registration) != kTfLiteOk) {
      return absl::UnavailableError(
          "Unable to get node and registration for node.");
    }
    for (int j = 0; j < node->inputs->size; ++j) {
      if (tensor_id == node->inputs->data[j]) {
        result->consumers.push_back({node, registration});
      }
    }
    for (int j = 0; j < node->outputs->size; ++j) {
      if (tensor_id == node->outputs->data[j]) {
        result->producers.push_back({node, registration});
      }
    }
  }
  return absl::OkStatus();
}

bool IsLogicalCode(int32_t builtin_code) {
  return builtin_code == kTfLiteBuiltinGreater ||
         builtin_code == kTfLiteBuiltinGreaterEqual ||
         builtin_code == kTfLiteBuiltinLess ||
         builtin_code == kTfLiteBuiltinLessEqual ||
         builtin_code == kTfLiteBuiltinEqual ||
         builtin_code == kTfLiteBuiltinNotEqual;
}

bool IsLogicalOp(tflite::gpu::OperationType op_type) {
  return op_type == tflite::gpu::OperationType::GREATER ||
         op_type == tflite::gpu::OperationType::GREATER_EQUAL ||
         op_type == tflite::gpu::OperationType::LESS ||
         op_type == tflite::gpu::OperationType::LESS_EQUAL ||
         op_type == tflite::gpu::OperationType::EQUAL ||
         op_type == tflite::gpu::OperationType::NOT_EQUAL;
}

class AddOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    // TODO(eignasheva): Add shapes check.
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // TFLite currently only supports 2 input ADDs.  Thus, the logic below only
    // considers 2 input cases.  The underlying GPU shader programs can accept
    // more inputs, but the logic below would have to be expanded.

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    ElementwiseAttributes attr;
    RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
    node->operation.attributes = std::move(attr);
    const TfLiteAddParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    return MaybeFuseActivation(tf_options->activation, graph, node);
  }
};

class BatchedMatMulOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::BATCHED_MATMUL);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddInput(node, 1));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    return absl::OkStatus();
  }
};

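// Cast is supported only in the narrow case where its input is produced by a
// single logical op (see IsLogicalCode above) and consumed exactly once; in
// that case Parse() lowers it to an identity RESHAPE that later graph
// transformations can remove.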
class CastOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    TensorInfo input_tensor_info;
    RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->inputs->data[0],
                                  &input_tensor_info));
    if (input_tensor_info.producers.size() != 1 ||
        input_tensor_info.consumers.size() != 1) {
      return absl::UnavailableError("Not supported cast case");
    }
    if (IsLogicalCode(input_tensor_info.producers[0].second->builtin_code)) {
      return CheckGpuDelegateCompatibility(context, tflite_node, registration);
    }
    return absl::UnimplementedError("Not supported Cast case.");
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    // Adding Identity reshape that will be removed.
    node->operation.type = ToString(OperationType::RESHAPE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    // New shape comes from output shape.
    ReshapeAttributes attr;
    attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class ClampOperationsParser : public TFLiteOperationParser {
 public:
  explicit ClampOperationsParser(float clamp_a, float clamp_b)
      : clamp_a_(clamp_a), clamp_b_(clamp_b) {}
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::OkStatus();
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // clamp(v, a, b) = clamp(v - a, 0.0, b - a) + a;
    // We replace clamp(...) with a sequence of elementwise ops:
    // subtraction -> ReLU with alpha = 0.0 -> addition.
    // node_sub = v0 = v - a // add op (add -a)
    // node_relu = v1 = clamp(v0, 0.0, clip); // relu op alpha = 0.0,
    // clip = b - a;
    // node_add = v2 = v1 + a // add op (add a)
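    // Worked example with a = -1, b = 1, v = 2:
    //   v - a = 3; clamp(3, 0.0, b - a = 2) = 2; 2 + a = 1 == clamp(2, -1, 1).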
    Node* node_sub = graph->NewNode();
    Node* node_relu = graph->NewNode();
    Node* node_add = graph->NewNode();

    ElementwiseAttributes sub_attr;
    sub_attr.param = -clamp_a_;
    node_sub->operation.type = ToString(OperationType::ADD);
    node_sub->operation.attributes = std::move(sub_attr);

    ReLUAttributes relu_attr;
    relu_attr.alpha = 0.0f;
    relu_attr.clip = clamp_b_ - clamp_a_;
    node_relu->operation.type = ToString(OperationType::RELU);
    node_relu->operation.attributes = relu_attr;

    ElementwiseAttributes add_attr;
    add_attr.param = clamp_a_;
    node_add->operation.type = ToString(OperationType::ADD);
    node_add->operation.attributes = std::move(add_attr);

    RETURN_IF_ERROR(reader->AddInput(node_sub, 0));
    auto input = graph->FindInputs(node_sub->id)[0];

    Value* v0 = graph->NewValue();
    Value* v1 = graph->NewValue();
    v0->tensor.type = input->tensor.type;
    v0->tensor.shape = input->tensor.shape;
    v1->tensor.type = input->tensor.type;
    v1->tensor.shape = input->tensor.shape;

    RETURN_IF_ERROR(graph->SetProducer(node_sub->id, v0->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_relu->id, v0->id));
    RETURN_IF_ERROR(graph->SetProducer(node_relu->id, v1->id));
    RETURN_IF_ERROR(graph->AddConsumer(node_add->id, v1->id));

    RETURN_IF_ERROR(reader->AddOutputs(node_add));
    return absl::OkStatus();
  }

 private:
  const float clamp_a_, clamp_b_;
};

class ConcatenationOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));

    // TODO(eignasheva): add proper tensor availability checking
    // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
    //   RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
    // }
    // TODO(eignasheva): add axis checking.
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    ConcatAttributes attr;
    // Read inputs first to make sure const node is added to a graph before
    // concat node to ensure topological order.
    std::vector<const Value*> inputs;
    for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
      Value* value;
      const auto status = reader->ReadValue(idx, &value);
      if (status.ok()) {
        inputs.push_back(value);
      } else {
        TensorFloat32 tensor;
        RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
        Value* value;
        RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
        inputs.push_back(value);
      }
    }

    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONCAT);
    RETURN_IF_ERROR(reader->AddOutputs(node));
    for (const Value* input : inputs) {
      RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
    }

    std::vector<BHWC> input_shapes;
    for (auto input : graph->FindInputs(node->id)) {
      input_shapes.push_back(input->tensor.shape);
    }
    RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));

    // Guess axis.
    BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
    for (auto input : graph->FindInputs(node->id)) {
      if (input->tensor.shape.h != output_shape.h) {
        attr.axis = Axis::HEIGHT;
        break;
      }
      if (input->tensor.shape.w != output_shape.w) {
        attr.axis = Axis::WIDTH;
        break;
      }
      if (input->tensor.shape.c != output_shape.c) {
        attr.axis = Axis::CHANNELS;
        break;
      }
    }
    const TfLiteConcatenationParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = attr;
    return absl::OkStatus();
  }

 private:
  absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
    *axis = Axis::BATCH;
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::HEIGHT;
        break;
      }
    }
    if (*axis == Axis::BATCH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::WIDTH;
        break;
      }
    }
    if (*axis == Axis::HEIGHT) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].h != input_shapes[i].h &&
          input_shapes[0].c != input_shapes[i].c) {
        *axis = Axis::CHANNELS;
        break;
      }
    }
    if (*axis == Axis::WIDTH) return absl::OkStatus();
    for (int i = 1; i < input_shapes.size(); i++) {
      if (input_shapes[0].b != input_shapes[i].b &&
          input_shapes[0].w != input_shapes[i].w &&
          input_shapes[0].h != input_shapes[i].h) {
        return absl::UnimplementedError(
            "Can concatenate tensors only by batch, height, width, or "
            "channels.");
      }
    }
    return absl::OkStatus();
  }
};

class Conv2DOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::CONVOLUTION_2D);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    Convolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
    } else {  // runtime_inputs == 1
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional

    const TfLiteConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(tf_options->dilation_height_factor,
                        tf_options->dilation_width_factor);
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }
};

// Doesn't have a kernel implementation.
class DensifyOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DENSIFY);
    const TfLiteTensor* const_tensor = reader->GetInputTensor(0);
    if (!const_tensor->sparsity) {
      return absl::InvalidArgumentError("Input tensor must be sparse.");
    }
    TensorFloat32 sparse_tensor;
    RETURN_IF_ERROR(reader->ReadTensor(0, &sparse_tensor));
    DensifyAttributes attributes;
    attributes.tensor = std::move(sparse_tensor);
    node->operation.attributes = attributes;
    return reader->AddOutputs(node);
  }
};

class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    DepthwiseConvolution2DAttributes attr;
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 2) {
      RETURN_IF_ERROR(reader->AddInput(node, 1));
    } else {  // runtime_inputs == 1
      RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
    }
    reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
    const TfLiteDepthwiseConvParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
    attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
                        std::max(1, tf_options->dilation_width_factor));
    UpdatePadding(tf_options->padding,
                  graph->FindInputs(node->id)[0]->tensor.shape, &attr);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
    const int depth_multiplier = tf_options->depth_multiplier;
    if (depth_multiplier != 1) {
      const TfLiteTensor* input = reader->GetInputTensor(0);
      const TfLiteTensor* filter = reader->GetInputTensor(1);
      const TfLiteTensor* output = reader->GetOutputTensor(0);
      TransposeWeights(input, filter, output, depth_multiplier, &attr);
    }
    node->operation.attributes = std::move(attr);
    return absl::OkStatus();
  }

 private:
  // TFLite CPU stores weights as:
  //   [1, kernel_height, kernel_width, input_depth * depth_multiplier]
  // TFLite GPU stores weights as:
  //   [depth_multiplier, kernel_height, kernel_width, input_depth]
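  // For example, with input_depth = 8 and depth_multiplier = 2 the filter has
  // 16 output channels; the CPU layout interleaves them along the last axis,
  // while the transpose below gathers each output channel's
  // kernel_height x kernel_width plane into contiguous memory.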
  static void TransposeWeights(const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* output, int depth_multiplier,
                               DepthwiseConvolution2DAttributes* attr) {
    const int input_depth = input->dims->data[3];
    const int filter_height = filter->dims->data[1];
    const int filter_width = filter->dims->data[2];
    const int output_depth = output->dims->data[3];
    Tensor<OHWI, DataType::FLOAT32> weights;
    weights.id = attr->weights.id;
    weights.shape =
        OHWI(output_depth, filter_height, filter_width, input_depth);
    weights.data.resize(weights.shape.DimensionsProduct());
    float* dst = &weights.data[0];
    for (int j = 0; j < output_depth; ++j) {
      const float* src = attr->weights.data.data() + j;
      for (int i = 0; i < filter_height * filter_width; ++i) {
        *dst = *src;
        dst++;
        src += output_depth;
      }
    }
    attr->weights = std::move(weights);
  }
};

class DepthToSpaceOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::DEPTH_TO_SPACE);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    const TfLiteDepthToSpaceParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
    SpaceToDepthAttributes attr;
    attr.block_size = tf_options->block_size;
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class DequantizeOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
    // with floating-point versions of the original tensors.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
    const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
    if (runtime_inputs == 1) {
      // Non-constant dequantization.
      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else {
      // TODO(b/181274192): Optimize out this constant dequantization from the
      // graph later.
      TensorFloat32 tensor;
      RETURN_IF_ERROR(reader->ReadTensor(0, &tensor));
      Value* value;
      RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
      // Need to retain the quant params from the original constant input.
      const TfLiteTensor* tflite_input = reader->GetInputTensor(0);
      value->quant_params.emplace();
      RETURN_IF_ERROR(
          PopulateQuantParams(*tflite_input, &value->quant_params.value()));
      RETURN_IF_ERROR(graph->AddConsumer(node->id, value->id));
    }
    RETURN_IF_ERROR(reader->AddOutputs(node));

    // Quantization attributes should already be present in the input tensor.
    auto input_value = graph->FindInputs(node->id)[0];
    if (!input_value->quant_params) {
      if (runtime_inputs == 1) {
        // The DEQUANTIZE op is preceded by a DENSIFY op and doesn't have any
        // quantization params. The DEQUANTIZE op will later be removed from
        // the graph in the `MergeDensify` graph transformation.
        return absl::OkStatus();
      }
      return absl::InvalidArgumentError(
          "Encountered Dequantize input with no quant params");
    }
    QuantizeAndDequantizeAttributes attr;
    attr.min = input_value->quant_params.value().min;
    attr.max = input_value->quant_params.value().max;
    attr.scale = input_value->quant_params.value().scale;

    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

class ElementwiseOperationParser : public TFLiteOperationParser {
 public:
  explicit ElementwiseOperationParser(OperationType operation_type)
      : operation_type_(operation_type) {}

  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
    if (IsLogicalOp(operation_type_)) {
      TensorInfo output_tensor_info;
      RETURN_IF_ERROR(GetTensorInfo(context, tflite_node->outputs->data[0],
                                    &output_tensor_info));
      if (output_tensor_info.producers.size() != 1 ||
          output_tensor_info.consumers.size() != 1) {
        return absl::UnavailableError("Not supported logical op case");
      }
      if (output_tensor_info.consumers[0].second->builtin_code ==
          kTfLiteBuiltinCast) {
        return absl::OkStatus();
      } else {
        return absl::UnimplementedError("Not supported logical op case.");
      }
    }
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

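  // Dispatches on arity: unary ops take exactly one runtime input; binary ops
  // take either two runtime inputs (with an optional fused activation for SUB
  // and DIV) or one runtime input plus one constant operand that is folded
  // into ElementwiseAttributes.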
  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(operation_type_);

    if (IsOneArgumentOperation()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/0,
                                                        /*outputs=*/1));

      RETURN_IF_ERROR(reader->AddInput(node, 0));
    } else if (IsTwoArgumentOperation() &&
               reader
                   ->VerifyInputsConstsOutputs(tflite_node,
                                               /*runtime_inputs=*/2,
                                               /*const_inputs=*/0,
                                               /*outputs=*/1)
                   .ok()) {
      if (tflite_node->inputs->size != 2) {
        return absl::InvalidArgumentError(
            "Op requires exactly two input tensors");
      }
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));

      TfLiteFusedActivation activation = kTfLiteActNone;
      switch (operation_type_) {
        case OperationType::SUB: {
          const TfLiteSubParams* tf_options;
          if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
            activation = tf_options->activation;
          }
          break;
        }
        case OperationType::DIV: {
          const TfLiteDivParams* tf_options;
          if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
            activation = tf_options->activation;
          }
          break;
        }
        default:
          // No activation expected.
          activation = kTfLiteActNone;
      }

      if (activation) {
        RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
      }
    } else if (IsTwoArgumentOperationWithConst()) {
      RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
                                                        /*runtime_inputs=*/1,
                                                        /*const_inputs=*/1,
                                                        /*outputs=*/1));
      ElementwiseAttributes attr;
      RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
      attr.runtime_tensor_is_second =
          IsConstantTensor(reader->GetInputTensor(0));
      node->operation.attributes = std::move(attr);
    } else {
      return absl::InvalidArgumentError("Incorrect operation type passed");
    }

    return reader->AddOutputs(node);
  }

 private:
  absl::Status GetActivation(const TfLiteNode* tflite_node,
                             TfLiteFusedActivation* activation) const {
    if (operation_type_ == OperationType::DIV) {
      const TfLiteDivParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }
    if (operation_type_ == OperationType::SUB) {
      const TfLiteSubParams* tf_options;
      auto status = RetrieveBuiltinData(tflite_node, &tf_options);
      *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
      return absl::OkStatus();
    }

    // Return kTfLiteActNone for other ops, which either have no
    // TfLiteXxxParams at all or no activation field in them.
    *activation = kTfLiteActNone;
    return absl::OkStatus();
  }

  bool IsOneArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::ABS:
      case OperationType::COPY:
      case OperationType::COS:
      case OperationType::ELU:
      case OperationType::EXP:
      case OperationType::FLOOR:
      case OperationType::LOG:
      case OperationType::NEG:
      case OperationType::RSQRT:
      case OperationType::SIGMOID:
      case OperationType::SIN:
      case OperationType::SQRT:
      case OperationType::SQUARE:
      case OperationType::TANH:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperation() const {
    switch (operation_type_) {
      case OperationType::DIV:
      case OperationType::EQUAL:
      case OperationType::FLOOR_DIV:
      case OperationType::FLOOR_MOD:
      case OperationType::GREATER:
      case OperationType::GREATER_EQUAL:
      case OperationType::LESS:
      case OperationType::LESS_EQUAL:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::NOT_EQUAL:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  bool IsTwoArgumentOperationWithConst() const {
    switch (operation_type_) {
      case OperationType::DIV:
      case OperationType::EQUAL:
      case OperationType::FLOOR_DIV:
      case OperationType::FLOOR_MOD:
      case OperationType::GREATER:
      case OperationType::GREATER_EQUAL:
      case OperationType::LESS:
      case OperationType::LESS_EQUAL:
      case OperationType::MAXIMUM:
      case OperationType::MINIMUM:
      case OperationType::NOT_EQUAL:
      case OperationType::POW:
      case OperationType::SQUARED_DIFF:
      case OperationType::SUB:
        return true;
      default:
        return false;
    }
  }

  OperationType operation_type_;
};

class FullyConnectedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
    // TODO(eignasheva): check input shape
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    const TfLiteFullyConnectedParams* tf_options;
    RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));

    if (reader->GetNumberOfRuntimeInputs() == 2) {
      // Create Convolution2D, as it supports runtime weights.
      Node* node = graph->NewNode();
      node->operation.type = ToString(OperationType::CONVOLUTION_2D);
      RETURN_IF_ERROR(reader->AddInput(node, 0));
      RETURN_IF_ERROR(reader->AddInput(node, 1));
      RETURN_IF_ERROR(reader->AddOutputs(node));

      Convolution2DAttributes attr;
      reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional

      attr.strides = HW(1, 1);
      attr.dilations = HW(1, 1);
      attr.padding.appended = HW(0, 0);
      attr.padding.prepended = HW(0, 0);
      RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
      node->operation.attributes = std::move(attr);
      return absl::OkStatus();
    }
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(reader->AddInput(node, 0));

    if (tf_options->weights_format !=
        kTfLiteFullyConnectedWeightsFormatDefault) {
      return absl::UnimplementedError(
          "Unsupported FullyConnected weights format.");
    }

    FullyConnectedAttributes attr;
    RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
    const int weights_width = attr.weights.shape.i;

    auto input = graph->FindInputs(node->id)[0];
    int batch_size = input->tensor.shape.b;
    if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) {
      return absl::UnimplementedError(
          "Amount of input data should match weights width");
    }

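    // If the input has a spatial extent, flatten it to [batch, 1, 1,
    // weights_width] with an explicit RESHAPE node that feeds the
    // fully-connected node.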
    Node* conv = node;
    if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
      auto& reshape = node;
      conv = graph->NewNode();  // reset conv pointer!
      Value* reshaped_value = graph->NewValue();
      reshaped_value->tensor.type = DataType::FLOAT32;
      reshaped_value->tensor.shape =
          BHWC(input->tensor.shape.b, 1, 1, weights_width);
      RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
      reshape->operation.type = ToString(OperationType::RESHAPE);
      ReshapeAttributes attr;
      attr.new_shape = reshaped_value->tensor.shape;
      reshape->operation.attributes = attr;
      RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
    }

    conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
    conv->operation.attributes = std::move(attr);
    absl::Status result = reader->AddOutputs(conv);
    RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv));

    return result;
  }
};

class HardSwishOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
                     GraphFloat32* graph, ObjectReader* reader) final {
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::HARD_SWISH);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    return reader->AddOutputs(node);
  }
};

// Basic LSTM Cell:
//
//  1name = name is at input index 1
//  name1 = name is at output index 1
//
//    0input     1prev_activ
//       \        /
//        [[concat]]
//             \
//       concat_temp2  2weights  3biases
//              \      /        /
//             [[fully-connected]]
//               \
//         activ_temp3    4prev_state
//                 \      /
//                 [[LSTM]]
//                 /      \
//           new_state1    activation0
//
// For full LSTM cells, see this blog post:
// https://colah.github.io/posts/2015-08-Understanding-LSTMs/
// In addition to Peephole connections and Combined Input Forget Gates (CIFG)
// described in that post, this code also adds the following optional features:
// - Configurable activations (sigmoid or TANH)
// - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
// - Output projection:
//     https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
// - Configurable clipping of cell state and output state.
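//
// For reference, the textbook cell update from the post above (peephole,
// CIFG, normalization, and projection variants omitted), where [x, h_prev]
// is the concatenated input:
//   i = sigmoid(W_i [x, h_prev] + b_i)   // input gate
//   f = sigmoid(W_f [x, h_prev] + b_f)   // forget gate
//   o = sigmoid(W_o [x, h_prev] + b_o)   // output gate
//   g = tanh(W_g [x, h_prev] + b_g)      // candidate state
//   c = f * c_prev + i * g               // new_state
//   h = o * tanh(c)                      // activation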
1017 class LSTMOperationParser : public TFLiteOperationParser {
1018  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1019   absl::Status IsSupported(const TfLiteContext* context,
1020                            const TfLiteNode* tflite_node,
1021                            const TfLiteRegistration* registration) final {
1022     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
1023     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1024   }
1025 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1026   absl::Status Parse(const TfLiteNode* tflite_node,
1027                      const TfLiteRegistration* registration,
1028                      GraphFloat32* graph, ObjectReader* reader) final {
1029     const TfLiteLSTMParams* tf_options;
1030     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1031     switch (tf_options->kernel_type) {
1032       case kTfLiteLSTMFullKernel:
1033         return ParseFull(tflite_node, registration, graph, reader, tf_options);
1034       case kTfLiteLSTMBasicKernel:
1035         return ParseBasic(tflite_node, registration, graph, reader, tf_options);
1036     }
1037   }
1038 
GetNewValueIdsForVariableInputNodes()1039   absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
1040       final {
1041     return new_variable_input_value_map_;
1042   }
1043 
1044  private:
ParseBasic(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * tf_options)1045   absl::Status ParseBasic(const TfLiteNode* tflite_node,
1046                           const TfLiteRegistration* registration,
1047                           GraphFloat32* graph, ObjectReader* reader,
1048                           const TfLiteLSTMParams* tf_options) {
1049     if (tflite_node->inputs->size != 5) {
1050       return absl::InvalidArgumentError("LSTM should have 5 input tensors");
1051     }
1052     if (tflite_node->outputs->size != 4) {
1053       return absl::InvalidArgumentError("LSTM should have 4 output tensors");
1054     }
1055     RETURN_IF_ERROR(CheckBasicParameters(tf_options));
1056 
1057     Node* concat_node = graph->NewNode();
1058     concat_node->operation.type = ToString(OperationType::CONCAT);
1059     ConcatAttributes concat_attr;
1060     concat_attr.axis = Axis::CHANNELS;
1061     concat_node->operation.attributes = concat_attr;
1062 
1063     Node* fc_node = graph->NewNode();
1064     fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
1065     FullyConnectedAttributes fc_attr;
1066     RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
1067     fc_node->operation.attributes = std::move(fc_attr);
1068 
1069     Node* lstm_node = graph->NewNode();
1070     lstm_node->operation.type = ToString(OperationType::LSTM);
1071     LstmAttributes lstm_attr;
1072     lstm_attr.kernel_type = LstmKernelType::BASIC;
1073     lstm_node->operation.attributes = lstm_attr;
1074 
1075     Value* concat_temp;
1076     int concat_tensor_idx = tflite_node->outputs->data[2];
1077     RETURN_IF_ERROR(
1078         reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
1079     Value* activ_temp;
1080     int activ_tensor_idx = tflite_node->outputs->data[3];
1081     RETURN_IF_ERROR(
1082         reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
1083 
1084     RETURN_IF_ERROR(reader->AddInput(concat_node, 0));  // input
1085     RETURN_IF_ERROR(reader->AddInput(concat_node, 1));  // prev_activ
1086     RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
1087 
1088     RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
1089     RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
1090 
1091     RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
1092     RETURN_IF_ERROR(reader->AddInput(lstm_node, 4));   // prev_state
1093     RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1));  // new_state
1094     RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0));  // activation
1095 
1096     return absl::OkStatus();
1097   }
1098 
CheckBasicParameters(const TfLiteLSTMParams * tf_options)1099   absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
1100     if (tf_options->activation != kTfLiteActTanh) {
1101       return absl::UnimplementedError("Only TANH activation is supported.");
1102     }
1103     if (tf_options->cell_clip != 0.0f) {
1104       return absl::UnimplementedError("cell_clip is not supported.");
1105     }
1106     if (tf_options->proj_clip != 0.0f) {
1107       return absl::UnimplementedError("proj_clip is not supported.");
1108     }
1109     return absl::OkStatus();
1110   }
1111 
ParseFull(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * tf_options)1112   absl::Status ParseFull(const TfLiteNode* tflite_node,
1113                          const TfLiteRegistration* registration,
1114                          GraphFloat32* graph, ObjectReader* reader,
1115                          const TfLiteLSTMParams* tf_options) {
1116     // Invoke full LSTM parser
1117     RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
1118                                         reader, tf_options,
1119                                         &new_variable_input_value_map_));
1120     return absl::OkStatus();
1121   }
1122 
CheckFullParameters(const TfLiteLSTMParams * tf_options)1123   absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
1124     if (tf_options->activation != kTfLiteActSigmoid &&
1125         tf_options->activation != kTfLiteActTanh) {
1126       return absl::UnimplementedError(
1127           "Only sigmoid or tanh activation is supported.");
1128     }
1129 
1130     return absl::OkStatus();
1131   }
1132 
1133   absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
1134 };
1135 
1136 class MulOperationParser : public TFLiteOperationParser {
1137  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1138   absl::Status IsSupported(const TfLiteContext* context,
1139                            const TfLiteNode* tflite_node,
1140                            const TfLiteRegistration* registration) final {
1141     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1142     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1143   }
1144 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1145   absl::Status Parse(const TfLiteNode* tflite_node,
1146                      const TfLiteRegistration* registration,
1147                      GraphFloat32* graph, ObjectReader* reader) final {
1148     const TfLiteTensor* input0 = reader->GetInputTensor(0);
1149     if (!input0) {
1150       return absl::InvalidArgumentError(
1151           "Couldn't get the 1st input tensor for MUL.");
1152     }
1153     const TfLiteTensor* input1 = reader->GetInputTensor(1);
1154     if (!input1) {
1155       return absl::InvalidArgumentError(
1156           "Couldn't get the 2nd input tensor for MUL.");
1157     }
1158     const bool constant_tensor0 = IsConstantTensor(input0);
1159     const bool constant_tensor1 = IsConstantTensor(input1);
1160     if (constant_tensor0 && constant_tensor1) {
1161       return absl::InvalidArgumentError("No runtime input tensors for MUL.");
1162     }
1163     const bool runtime_tensor0 = !constant_tensor0;
1164     const bool runtime_tensor1 = !constant_tensor1;
1165 
1166     Node* node = graph->NewNode();
1167     node->operation.type = ToString(OperationType::MUL);
1168     RETURN_IF_ERROR(reader->AddOutputs(node));
1169 
1170     // Determine runtime/constant tensors.
1171     if (runtime_tensor0 && runtime_tensor1) {
1172       if (input0 == input1) {
1173         // replace MUL(A, A) with POW(A, 2.0)
1174         // TODO(b/166831113): Support the same inputs for operations.
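             // Illustrative example (not in the original source): for A of shape
             // BHWC(1, 8, 8, 16), MUL(A, A) is rewritten as POW(A, 2.0f), which
             // computes the same element-wise square from a single graph input.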
1175         node->operation.type = ToString(OperationType::POW);
1176         ElementwiseAttributes attr;
1177         attr.param = 2.0f;
1178         node->operation.attributes = std::move(attr);
1179         return reader->AddInput(node, 0);
1180       }
1181 
1182       // The "larger" input tensor must be bound to 1st input and the "smaller"
1183       // input tensor must be bound to 2nd input.
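           // Illustrative example (not in the original source): if input0 is
           // BHWC(1, 1, 1, 16) and input1 is BHWC(1, 32, 32, 16), the indices are
           // swapped below so the full-size tensor feeds input 0 and the
           // broadcastable tensor feeds input 1.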
1184       BHWC shape0;
1185       RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
1186       BHWC shape1;
1187       RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
1188       int input_tensor0 = 0;
1189       int input_tensor1 = 1;
1190       if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
1191           shape0.c == shape1.c) {
1192         input_tensor0 = 1;
1193         input_tensor1 = 0;
1194       }
1195       RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
1196       RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
1197     } else {
1198       ElementwiseAttributes attr;
1199       RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
1200       node->operation.attributes = std::move(attr);
1201     }
1202 
1203     const TfLiteMulParams* tf_options;
1204     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1205     return MaybeFuseActivation(tf_options->activation, graph, node);
1206   }
1207 };
1208 
1209 class PackOperationParser : public TFLiteOperationParser {
1210  public:
1211   absl::Status IsSupported(const TfLiteContext* context,
1212                            const TfLiteNode* tflite_node,
1213                            const TfLiteRegistration* registration) final {
1214     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1215   }
1216 
1217   absl::Status Parse(const TfLiteNode* tflite_node,
1218                      const TfLiteRegistration* registration,
1219                      GraphFloat32* graph, ObjectReader* reader) final {
1220     if (tflite_node->inputs->size == 1) {
1221       // A Pack with a single input can be replaced with a Reshape.
1222       Node* node = graph->NewNode();
1223       node->operation.type = ToString(OperationType::RESHAPE);
1224       RETURN_IF_ERROR(reader->AddInput(node, 0));
1225       RETURN_IF_ERROR(reader->AddOutputs(node));
1226       // New shape comes from output shape.
1227       ReshapeAttributes attr;
1228       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1229       node->operation.attributes = attr;
1230       return absl::OkStatus();
1231     } else {
1232       // A Pack with multiple inputs can be replaced with a Concat.
1233       const TfLitePackParams* tf_options;
1234       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1235 
1236       // Read the inputs first to make sure any const nodes are added to the
1237       // graph before the concat node, preserving topological order.
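           // Illustrative example (not in the original source): for PACK(t0, w)
           // with a constant tensor w, NewConstNode() below creates w's value node
           // before the CONCAT node exists, so the producer precedes its consumer.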
1238       std::vector<const Value*> inputs;
1239       for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
1240         Value* value;
1241         const auto status = reader->ReadValue(idx, &value);
1242         if (status.ok()) {
1243           inputs.push_back(value);
1244         } else {
1245           TensorFloat32 tensor;
1246           RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
1247           Value* value;
1248           RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
1249           inputs.push_back(value);
1250         }
1251       }
1252 
1253       Node* node = graph->NewNode();
1254       node->operation.type = ToString(OperationType::CONCAT);
1255       RETURN_IF_ERROR(reader->AddOutputs(node));
1256       for (const Value* input : inputs) {
1257         RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1258       }
1259       const TfLiteTensor* output = reader->GetOutputTensor(0);
1260       ConcatAttributes attr;
1261       RETURN_IF_ERROR(
1262           ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
1263       node->operation.attributes = attr;
1264       return absl::OkStatus();
1265     }
1266   }
1267 };
1268 
1269 class PReLUOperationParser : public TFLiteOperationParser {
1270  public:
1271   absl::Status IsSupported(const TfLiteContext* context,
1272                            const TfLiteNode* tflite_node,
1273                            const TfLiteRegistration* registration) final {
1274     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1275     // TODO(eignasheva): add params check
1276     return absl::OkStatus();
1277   }
1278   absl::Status Parse(const TfLiteNode* tflite_node,
1279                      const TfLiteRegistration* registration,
1280                      GraphFloat32* graph, ObjectReader* reader) final {
1281     Node* node = graph->NewNode();
1282     node->operation.type = ToString(OperationType::PRELU);
1283     RETURN_IF_ERROR(reader->AddInput(node, 0));
1284     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1285 
1286     PReLUAttributes attr;
1287     Tensor<Linear, DataType::FLOAT32> linear_alpha;
1288     absl::Status status = reader->ReadTensor(1, &linear_alpha);
1289     if (status.ok()) {
1290       if (linear_alpha.shape.v != input_shape.c) {
1291         return absl::InvalidArgumentError(
1292             "Linear alpha shape does not match the number of input channels.");
1293       }
1294       attr.alpha = std::move(linear_alpha);
1295     } else {
1296       Tensor<HWC, DataType::FLOAT32> hwc_alpha;
1297       RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
1298       if (hwc_alpha.shape.h != input_shape.h ||
1299           hwc_alpha.shape.w != input_shape.w ||
1300           hwc_alpha.shape.c != input_shape.c) {
1301         return absl::InvalidArgumentError(
1302             "Alpha shape does not match input shape.");
1303       }
1304       attr.alpha = std::move(hwc_alpha);
1305     }
1306     node->operation.attributes = std::move(attr);
1307     return reader->AddOutputs(node);
1308   }
1309 };
1310 
1311 class PadOperationParser : public TFLiteOperationParser {
1312  public:
1313   explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
1314 
1315   absl::Status IsSupported(const TfLiteContext* context,
1316                            const TfLiteNode* tflite_node,
1317                            const TfLiteRegistration* registration) final {
1318     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1319     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1320   }
1321 
1322   absl::Status Parse(const TfLiteNode* tflite_node,
1323                      const TfLiteRegistration* registration,
1324                      GraphFloat32* graph, ObjectReader* reader) final {
1325     Node* node = graph->NewNode();
1326     node->operation.type = ToString(OperationType::PAD);
1327     RETURN_IF_ERROR(reader->AddInput(node, 0));
1328     RETURN_IF_ERROR(reader->AddOutputs(node));
1329 
1330     PadAttributes attr;
1331     if (mirror_pad_) {
1332       attr.type = PaddingContentType::REFLECT;
1333     } else /*zero pad*/ {
1334       attr.type = PaddingContentType::ZEROS;
1335     }
1336 
1337     Tensor<HW, DataType::INT32> paddings;
1338     RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
1339 
1340     if (paddings.shape.h == 4 && paddings.shape.w == 2) {
1341       // 4x2 tensor with paddings.
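           // The rows are [b, h, w, c] and the columns are [prepend, append], laid
           // out row-major in data[]. Illustrative example (not in the original
           // source): paddings = [[0,0],[1,1],[2,2],[0,0]] yields
           // attr.prepended = BHWC(0, 1, 2, 0) and attr.appended = BHWC(0, 1, 2, 0).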
1342       attr.prepended = BHWC(paddings.data[0], paddings.data[2],
1343                             paddings.data[4], paddings.data[6]);
1344       attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
1345                            paddings.data[7]);
1346     } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
1347       // 3x2 tensor with paddings.
1348       attr.prepended =
1349           BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
1350       attr.appended =
1351           BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
1352     } else {
1353       // This should not happen: the paddings shape is checked in IsSupported().
1354       return absl::InvalidArgumentError(
1355           "Paddings tensor has unexpected shape.");
1356     }
1357     node->operation.attributes = attr;
1358     return absl::OkStatus();
1359   }
1360 
1361  private:
1362   bool mirror_pad_ = false;
1363 };
1364 
1365 class Pooling2DOperationParser : public TFLiteOperationParser {
1366  public:
1367   absl::Status IsSupported(const TfLiteContext* context,
1368                            const TfLiteNode* tflite_node,
1369                            const TfLiteRegistration* registration) final {
1370     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1371     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1372   }
1373 
1374  public:
1375   explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
1376 
1377   absl::Status Parse(const TfLiteNode* tflite_node,
1378                      const TfLiteRegistration* registration,
1379                      GraphFloat32* graph, ObjectReader* reader) final {
1380     Node* node = graph->NewNode();
1381     node->operation.type = ToString(OperationType::POOLING_2D);
1382     RETURN_IF_ERROR(reader->AddInput(node, 0));
1383     RETURN_IF_ERROR(reader->AddOutput(node, 0));
1384 
1385     Pooling2DAttributes attr;
1386     attr.type = type_;
1387 
1388     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1389 
1390     // Check whether custom options are encoded. This happens if the operation
1391     // is MaxPoolingWithArgmax2D. There is no way to read
1392     // tflite_node->builtin_code, so simply check whether custom data is
1393     // available.
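         // Sketch of the fallback (assumption, matching the calls below): for the
         // custom MaxPoolingWithArgmax2D op, RetrieveCustomInitialData() succeeds;
         // for builtin pooling ops it fails and the builtin data is used instead.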
1394     const TfLitePoolParams* tf_options;
1395     if (!RetrieveCustomInitialData(tflite_node, &tf_options).ok()) {
1396       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1397     }
1398 
1399     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1400     // The second output is optional and not required, but it must be added
1401     // after MaybeFuseActivation is called.
1402     reader->AddOutput(node, 1).IgnoreError();
1403 
1404     // The first output is the result of the pooling operation; the second
1405     // output holds the indices used for pooling.
1406     auto outputs = graph->FindOutputs(node->id);
1407     attr.output_indices = outputs.size() == 2;
1408     if (attr.output_indices) {
1409       // Fix data type for output indices. In the model it is set as float32.
1410       outputs[1]->tensor.type = DataType::INT32;
1411     }
1412     RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
1413     node->operation.attributes = attr;
1414     return absl::OkStatus();
1415   }
1416 
1417  private:
1418   const PoolingType type_;
1419 };
1420 
1421 class ReduceOperationParser : public TFLiteOperationParser {
1422  public:
1423   explicit ReduceOperationParser(OperationType operation_type)
1424       : operation_type_(operation_type) {}
1425 
1426   absl::Status IsSupported(const TfLiteContext* context,
1427                            const TfLiteNode* tflite_node,
1428                            const TfLiteRegistration* registration) final {
1429     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1430   }
1431 
1432   absl::Status Parse(const TfLiteNode* tflite_node,
1433                      const TfLiteRegistration* registration,
1434                      GraphFloat32* graph, ObjectReader* reader) final {
1435     Node* node = graph->NewNode();
1436     node->operation.type = ToString(operation_type_);
1437     RETURN_IF_ERROR(reader->AddInput(node, 0));
1438     RETURN_IF_ERROR(reader->AddOutputs(node));
1439 
1440     const TfLiteReducerParams* tf_options;
1441     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1442 
1443     ReduceAttributes attr;
1444     const TfLiteTensor* input = reader->GetInputTensor(0);
1445     const TfLiteTensor* axes = reader->GetInputTensor(1);
1446     for (int i = 0; i < NumElements(axes->dims); i++) {
1447       Axis axis;
1448       RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
1449       attr.dims.insert(axis);
1450     }
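         // Illustrative example (not in the original source): for a BHWC input and
         // an axes tensor holding [1, 2], the loop above produces
         // attr.dims = {Axis::HEIGHT, Axis::WIDTH}, i.e. a reduction over the
         // spatial dimensions.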
1451     node->operation.attributes = attr;
1452     return absl::OkStatus();
1453   }
1454 
1455  private:
1456   const OperationType operation_type_;
1457 };
1458 
1459 class QuantizeOperationParser : public TFLiteOperationParser {
1460  public:
1461   absl::Status IsSupported(const TfLiteContext* context,
1462                            const TfLiteNode* tflite_node,
1463                            const TfLiteRegistration* registration) final {
1464     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1465     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1466   }
1467 
1468   absl::Status Parse(const TfLiteNode* tflite_node,
1469                      const TfLiteRegistration* registration,
1470                      GraphFloat32* graph, ObjectReader* reader) final {
1471     // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
1472     // with floating-point versions of the original tensors.
1473     Node* node = graph->NewNode();
1474     node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
1475     RETURN_IF_ERROR(reader->AddInput(node, 0));
1476     RETURN_IF_ERROR(reader->AddOutputs(node));
1477 
1478     // Quantization attributes should already be present in the output tensor.
1479     auto output_value = graph->FindOutputs(node->id)[0];
1480     if (!output_value->quant_params) {
1481       return absl::InvalidArgumentError(
1482           "Encountered Quantize output with no quant params");
1483     }
1484     QuantizeAndDequantizeAttributes attr;
1485     attr.min = output_value->quant_params.value().min;
1486     attr.max = output_value->quant_params.value().max;
1487     attr.scale = output_value->quant_params.value().scale;
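         // Illustrative example (not in the original source): an 8-bit tensor with
         // scale 0.5 and zero point 128 carries quant params min = -64.0 and
         // max = 63.5, since (0 - 128) * 0.5 = -64 and (255 - 128) * 0.5 = 63.5.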
1488 
1489     node->operation.attributes = attr;
1490     return absl::OkStatus();
1491   }
1492 };
1493 
1494 class ReLUOperationParser : public TFLiteOperationParser {
1495  public:
1496   explicit ReLUOperationParser(int clip) : clip_(clip) {}
1497 
1498   absl::Status IsSupported(const TfLiteContext* context,
1499                            const TfLiteNode* tflite_node,
1500                            const TfLiteRegistration* registration) final {
1501     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1502     return absl::OkStatus();
1503   }
1504 
1505   absl::Status Parse(const TfLiteNode* tflite_node,
1506                      const TfLiteRegistration* registration,
1507                      GraphFloat32* graph, ObjectReader* reader) final {
1508     Node* node = graph->NewNode();
1509     node->operation.type = ToString(OperationType::RELU);
1510     RETURN_IF_ERROR(reader->AddInput(node, 0));
1511 
1512     ReLUAttributes attr;
1513     const TfLiteLeakyReluParams* tf_options;
1514     auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1515     attr.alpha = status.ok() ? tf_options->alpha : 0;
1516     attr.clip = clip_;
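         // Semantics sketch (assumption, based on the attributes above): clip_ == 0
         // yields plain ReLU, clip_ == 6 yields ReLU6, and a non-zero alpha from
         // LEAKY_RELU's builtin params yields LeakyReLU with that negative slope.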
1517     node->operation.attributes = attr;
1518     return reader->AddOutputs(node);
1519   }
1520 
1521  private:
1522   const int clip_;
1523 };
1524 
1525 class ResamplerOperationParser : public TFLiteOperationParser {
1526  public:
1527   absl::Status IsSupported(const TfLiteContext* context,
1528                            const TfLiteNode* tflite_node,
1529                            const TfLiteRegistration* registration) final {
1530     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1531   }
1532 
1533   absl::Status Parse(const TfLiteNode* tflite_node,
1534                      const TfLiteRegistration* registration,
1535                      GraphFloat32* graph, ObjectReader* reader) final {
1536     Node* node = graph->NewNode();
1537     RETURN_IF_ERROR(reader->AddInput(node, 0));  // src
1538     RETURN_IF_ERROR(reader->AddInput(node, 1));  // warp
1539     RETURN_IF_ERROR(reader->AddOutputs(node));
1540 
1541     node->operation.type = ToString(OperationType::RESAMPLER);
1542 
1543     auto src_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1544     auto warp_shape = graph->FindInputs(node->id)[1]->tensor.shape;
1545 
1546     auto output_value = graph->FindOutputs(node->id)[0];
1547     output_value->tensor.shape =
1548         BHWC(src_shape.b, warp_shape.h, warp_shape.w, src_shape.c);
1549     return absl::OkStatus();
1550   }
1551 };
1552 
1553 class ReshapeOperationParser : public TFLiteOperationParser {
1554  public:
1555   absl::Status IsSupported(const TfLiteContext* context,
1556                            const TfLiteNode* tflite_node,
1557                            const TfLiteRegistration* registration) final {
1558     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1559     // TODO(eignasheva): add shape checking
1560     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1561   }
1562 
1563   absl::Status Parse(const TfLiteNode* tflite_node,
1564                      const TfLiteRegistration* registration,
1565                      GraphFloat32* graph, ObjectReader* reader) final {
1566     Node* node = graph->NewNode();
1567     node->operation.type = ToString(OperationType::RESHAPE);
1568     RETURN_IF_ERROR(reader->AddInput(node, 0));
1569     RETURN_IF_ERROR(reader->AddOutputs(node));
1570     // There may be extra inputs here. The other tensors were supposed to
1571     // define the new shape, but TFLite ignores them.
1572     // TODO(akulik): check that shapes match?
1573 
1574     // New shape comes from output shape.
1575     ReshapeAttributes attr;
1576     attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1577     node->operation.attributes = attr;
1578     return absl::OkStatus();
1579   }
1580 };
1581 
1582 class Resize2DOperationParser : public TFLiteOperationParser {
1583  public:
1584   explicit Resize2DOperationParser(SamplingType sampling_type)
1585       : sampling_type_(sampling_type) {}
1586 
1587   absl::Status IsSupported(const TfLiteContext* context,
1588                            const TfLiteNode* tflite_node,
1589                            const TfLiteRegistration* registration) final {
1590     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1591     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1592   }
1593 
1594   absl::Status Parse(const TfLiteNode* tflite_node,
1595                      const TfLiteRegistration* registration,
1596                      GraphFloat32* graph, ObjectReader* reader) final {
1597     Node* node = graph->NewNode();
1598     node->operation.type = ToString(OperationType::RESIZE);
1599     RETURN_IF_ERROR(reader->AddInput(node, 0));
1600     RETURN_IF_ERROR(reader->AddOutputs(node));
1601     // There may be extra inputs here. The other tensors were supposed to
1602     // define the new shape, but TFLite ignores them.
1603 
1604     Resize2DAttributes attr;
1605     RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
1606     RETURN_IF_ERROR(
1607         GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
1608     attr.type = sampling_type_;
1609     attr.new_shape.CopyAllDefinedAxis(
1610         graph->FindOutputs(node->id)[0]->tensor.shape);
1611     node->operation.attributes = attr;
1612     return absl::OkStatus();
1613   }
1614 
1615  private:
1616   absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
1617                                     bool* align_corners) {
1618     switch (sampling_type_) {
1619       case SamplingType::BILINEAR:
1620         return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
1621             tflite_node, align_corners);
1622       case SamplingType::NEAREST:
1623         return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
1624             tflite_node, align_corners);
1625       case SamplingType::UNKNOWN:
1626         return absl::InternalError("Sampling type is not specified");
1627     }
1628     return absl::OkStatus();
1629   }
1630 
1631   template <class T>
1632   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
1633                                            bool* align_corners) {
1634     const T* tf_options;
1635     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1636     *align_corners = tf_options->align_corners;
1637     return absl::OkStatus();
1638   }
1639 
1640   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
1641                                         bool* half_pixel_centers) {
1642     if (sampling_type_ == SamplingType::BILINEAR) {
1643       const TfLiteResizeBilinearParams* tf_options;
1644       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1645       if (tf_options->align_corners && tf_options->half_pixel_centers) {
1646         return absl::InternalError(
1647             "If half_pixel_centers is True, align_corners must be False.");
1648       }
1649       *half_pixel_centers = tf_options->half_pixel_centers;
1650     } else {
1651       const TfLiteResizeNearestNeighborParams* tf_options;
1652       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1653       *half_pixel_centers = tf_options->half_pixel_centers;
1654     }
1655     return absl::OkStatus();
1656   }
1657 
1658   SamplingType sampling_type_ = SamplingType::UNKNOWN;
1659 };
1660 
1661 class SliceOperationParser : public TFLiteOperationParser {
1662  public:
1663   absl::Status IsSupported(const TfLiteContext* context,
1664                            const TfLiteNode* tflite_node,
1665                            const TfLiteRegistration* registration) final {
1666     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1667     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1668   }
1669 
1670   absl::Status Parse(const TfLiteNode* tflite_node,
1671                      const TfLiteRegistration* registration,
1672                      GraphFloat32* graph, ObjectReader* reader) final {
1673     Node* node = graph->NewNode();
1674     node->operation.type = ToString(OperationType::SLICE);
1675     RETURN_IF_ERROR(reader->AddOutputs(node));
1676     Value* input;
1677     RETURN_IF_ERROR(reader->ReadValue(0, &input));
1678     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1679 
1680     const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
1681     const int input_dims = tfl_input->dims->size;
1682 
1683     SliceAttributes attr;
1684     attr.strides = BHWC(1, 1, 1, 1);
1685     Tensor<Linear, DataType::INT32> starts, sizes;
1686     RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
1687     RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
1688     if (starts.data.size() != sizes.data.size()) {
1689       return absl::InvalidArgumentError("Starts count does not match sizes count.");
1690     }
1691     BHWC bhwc_starts(0, 0, 0, 0);
1692     BHWC bhwc_sizes = input->tensor.shape;
1693     if (input_dims == 4) {
1694       // input in BHWC layout
1695       if (starts.data.size() == 4) {
1696         bhwc_starts.b = starts.data[0];
1697         bhwc_starts.h = starts.data[1];
1698         bhwc_starts.w = starts.data[2];
1699         bhwc_starts.c = starts.data[3];
1700         bhwc_sizes.b = sizes.data[0];
1701         bhwc_sizes.h = sizes.data[1];
1702         bhwc_sizes.w = sizes.data[2];
1703         bhwc_sizes.c = sizes.data[3];
1704       } else if (starts.data.size() == 3) {
1705         // If the input is 4D (BHWC) and the args are 3D, assume the args are in HWC layout.
1706         bhwc_starts.h = starts.data[0];
1707         bhwc_starts.w = starts.data[1];
1708         bhwc_starts.c = starts.data[2];
1709         bhwc_sizes.h = sizes.data[0];
1710         bhwc_sizes.w = sizes.data[1];
1711         bhwc_sizes.c = sizes.data[2];
1712       } else {
1713         return absl::UnimplementedError(
1714             "Slicing is supported for 3 or 4 dimensional tensors only.");
1715       }
1716     } else if (input_dims == 3) {
1717       // input in BWC layout
1718       if (starts.data.size() == 3) {
1719         bhwc_starts.b = starts.data[0];
1720         bhwc_starts.w = starts.data[1];
1721         bhwc_starts.c = starts.data[2];
1722         bhwc_sizes.b = sizes.data[0];
1723         bhwc_sizes.w = sizes.data[1];
1724         bhwc_sizes.c = sizes.data[2];
1725       } else {
1726         return absl::UnimplementedError(
1727             "Slicing is supported for 3 or 4 dimensional tensors only.");
1728       }
1729     } else {
1730       return absl::UnimplementedError(
1731           "Slicing is supported for 3 or 4 dimensional tensors only.");
1732     }
1733     const auto& in_shape = input->tensor.shape;
1734     if (bhwc_sizes.b == -1) {
1735       bhwc_sizes.b = in_shape.b - bhwc_starts.b;
1736     }
1737     if (bhwc_sizes.h == -1) {
1738       bhwc_sizes.h = in_shape.h - bhwc_starts.h;
1739     }
1740     if (bhwc_sizes.w == -1) {
1741       bhwc_sizes.w = in_shape.w - bhwc_starts.w;
1742     }
1743     if (bhwc_sizes.c == -1) {
1744       bhwc_sizes.c = in_shape.c - bhwc_starts.c;
1745     }
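         // Worked example (not in the original source): for in_shape
         // BHWC(1, 4, 4, 8) with starts = [0, 1, 1, 0] and sizes = [-1, 2, 2, -1],
         // the -1 sizes expand to 1 and 8, so attr.starts = BHWC(0, 1, 1, 0) and
         // attr.ends = BHWC(1, 3, 3, 8).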
1746     attr.starts = bhwc_starts;
1747     attr.ends =
1748         BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
1749              bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
1750     RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
1751 
1752     auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1753     if ((attr.ends.b - attr.starts.b) != out_shape.b) {
1754       return absl::UnimplementedError("Output batch doesn't match");
1755     }
1756     if ((attr.ends.h - attr.starts.h) != out_shape.h) {
1757       return absl::UnimplementedError("Output height doesn't match");
1758     }
1759     if ((attr.ends.w - attr.starts.w) != out_shape.w) {
1760       return absl::UnimplementedError("Output width doesn't match");
1761     }
1762     if ((attr.ends.c - attr.starts.c) != out_shape.c) {
1763       return absl::UnimplementedError("Output channels don't match");
1764     }
1765     node->operation.attributes = attr;
1766     return absl::OkStatus();
1767   }
1768 
1769  private:
1770   absl::Status UpdateIfNegative(const BHWC& input_shape,
1771                                 SliceAttributes* attr) {
1772     if (attr->ends.h < 0) {
1773       attr->ends.h = input_shape.h + attr->ends.h;
1774     }
1775     if (attr->ends.w < 0) {
1776       attr->ends.w = input_shape.w + attr->ends.w;
1777     }
1778     if (attr->ends.c < 0) {
1779       attr->ends.c = input_shape.c + attr->ends.c;
1780     }
1781     if (attr->ends.b < 0) {
1782       attr->ends.b = input_shape.b + attr->ends.b;
1783     }
1784     return absl::OkStatus();
1785   }
1786 };
1787 
1788 class SoftmaxOperationParser : public TFLiteOperationParser {
1789  public:
1790   absl::Status IsSupported(const TfLiteContext* context,
1791                            const TfLiteNode* tflite_node,
1792                            const TfLiteRegistration* registration) final {
1793     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1794     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1795   }
1796 
1797   absl::Status Parse(const TfLiteNode* tflite_node,
1798                      const TfLiteRegistration* registration,
1799                      GraphFloat32* graph, ObjectReader* reader) final {
1800     Node* node = graph->NewNode();
1801     node->operation.type = ToString(OperationType::SOFTMAX);
1802     RETURN_IF_ERROR(reader->AddInput(node, 0));
1803     RETURN_IF_ERROR(reader->AddOutputs(node));
1804 
1805     const TfLiteSoftmaxParams* tf_options;
1806     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1807     if (tf_options->beta != 1) {
1808       // There is a multiply-by-scalar operation fused into this softmax; it
1809       // would have to be split into a separate layer before the softmax.
1810       return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
1811       // auto mul_node = reader->NewPassthroughNode(node);
1812       // mul_node->operation.type = ToString(OperationType::MUL);
1813     }
1814     SoftmaxAttributes attr;
1815     attr.axis = Axis::CHANNELS;  // always by channels
1816     node->operation.attributes = attr;
1817     return absl::OkStatus();
1818   }
1819 };
1820 
1821 class SpaceToDepthOperationParser : public TFLiteOperationParser {
1822  public:
1823   absl::Status IsSupported(const TfLiteContext* context,
1824                            const TfLiteNode* tflite_node,
1825                            const TfLiteRegistration* registration) final {
1826     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1827     // TODO(impjdi): Dims check.
1828     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1829   }
1830 
1831   absl::Status Parse(const TfLiteNode* tflite_node,
1832                      const TfLiteRegistration* registration,
1833                      GraphFloat32* graph, ObjectReader* reader) final {
1834     Node* node = graph->NewNode();
1835     node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
1836     RETURN_IF_ERROR(reader->AddInput(node, 0));
1837     RETURN_IF_ERROR(reader->AddOutputs(node));
1838     const TfLiteSpaceToDepthParams* tf_options;
1839     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1840     SpaceToDepthAttributes attr;
1841     attr.block_size = tf_options->block_size;
1842     node->operation.attributes = attr;
1843     return absl::OkStatus();
1844   }
1845 };
1846 
1847 class SplitOperationParser : public TFLiteOperationParser {
1848  public:
1849   absl::Status IsSupported(const TfLiteContext* context,
1850                            const TfLiteNode* tflite_node,
1851                            const TfLiteRegistration* registration) final {
1852     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1853   }
1854 
1855   absl::Status Parse(const TfLiteNode* tflite_node,
1856                      const TfLiteRegistration* registration,
1857                      GraphFloat32* graph, ObjectReader* reader) final {
1858     const TfLiteSplitParams* split_params;
1859     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
1860     if (split_params->num_splits == 1) {
1861       // Add an identity Reshape that will be removed later.
1862       Node* node = graph->NewNode();
1863       node->operation.type = ToString(OperationType::RESHAPE);
1864       RETURN_IF_ERROR(reader->AddInput(node, 1));
1865       RETURN_IF_ERROR(reader->AddOutputs(node));
1866       // New shape comes from output shape.
1867       ReshapeAttributes attr;
1868       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1869       node->operation.attributes = attr;
1870       return absl::OkStatus();
1871     }
1872     const TfLiteTensor* input = reader->GetInputTensor(1);
1873     const TfLiteTensor* axis_tensor = reader->GetInputTensor(0);
1874     SplitAttributes attr;
1875     RETURN_IF_ERROR(
1876         ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
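         // Illustrative example (not in the original source): for a 4D input and
         // an axis tensor holding 3, attr.axis resolves to Axis::CHANNELS, so the
         // SPLIT partitions the tensor along its channel dimension.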
1877 
1878     Node* node = graph->NewNode();
1879     node->operation.type = ToString(OperationType::SPLIT);
1880     node->operation.attributes = attr;
1881     RETURN_IF_ERROR(reader->AddInput(node, 1));
1882     for (int i = 0; i < tflite_node->outputs->size; ++i) {
1883       RETURN_IF_ERROR(reader->AddOutput(node, i));
1884     }
1885     return absl::OkStatus();
1886   }
1887 };
1888 
1889 class SplitVOperationParser : public TFLiteOperationParser {
1890  public:
1891   absl::Status IsSupported(const TfLiteContext* context,
1892                            const TfLiteNode* tflite_node,
1893                            const TfLiteRegistration* registration) final {
1894     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1895   }
1896 
1897   absl::Status Parse(const TfLiteNode* tflite_node,
1898                      const TfLiteRegistration* registration,
1899                      GraphFloat32* graph, ObjectReader* reader) final {
1900     const TfLiteSplitVParams* split_params;
1901     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
1902     if (split_params->num_splits == 1) {
1903       // Add an identity Reshape that will be removed later.
1904       Node* node = graph->NewNode();
1905       node->operation.type = ToString(OperationType::RESHAPE);
1906       RETURN_IF_ERROR(reader->AddInput(node, 0));
1907       RETURN_IF_ERROR(reader->AddOutputs(node));
1908       // New shape comes from output shape.
1909       ReshapeAttributes attr;
1910       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1911       node->operation.attributes = attr;
1912       return absl::OkStatus();
1913     }
1914     const TfLiteTensor* input = reader->GetInputTensor(0);
1915     const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
1916     SplitAttributes attr;
1917     RETURN_IF_ERROR(
1918         ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
1919 
1920     Node* node = graph->NewNode();
1921     node->operation.type = ToString(OperationType::SPLIT);
1922     node->operation.attributes = attr;
1923     RETURN_IF_ERROR(reader->AddInput(node, 0));
1924     for (int i = 0; i < tflite_node->outputs->size; ++i) {
1925       RETURN_IF_ERROR(reader->AddOutput(node, i));
1926     }
1927     return absl::OkStatus();
1928   }
1929 };
1930 
1931 class StridedSliceOperationParser : public TFLiteOperationParser {
1932  public:
1933   absl::Status IsSupported(const TfLiteContext* context,
1934                            const TfLiteNode* tflite_node,
1935                            const TfLiteRegistration* registration) final {
1936     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1937     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
1938   }
1939 
1940   absl::Status Parse(const TfLiteNode* tflite_node,
1941                      const TfLiteRegistration* registration,
1942                      GraphFloat32* graph, ObjectReader* reader) final {
1943     Node* node = graph->NewNode();
1944     node->operation.type = ToString(OperationType::SLICE);
1945     RETURN_IF_ERROR(reader->AddOutputs(node));
1946     Value* input;
1947     RETURN_IF_ERROR(reader->ReadValue(0, &input));
1948     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1949 
1950     Tensor<Linear, DataType::INT32> tmp;
1951     RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
1952 
1953     bool read_without_batch = tmp.data.size() == 3;
1954     bool read_with_batch = tmp.data.size() == 4;
1955     if (!read_without_batch && !read_with_batch) {
1956       // Error: must be caught in IsSupported().
1957       return absl::UnimplementedError(
1958           "Slicing is supported for 3 or 4 dimensional tensors only.");
1959     }
1960 
1961     const TfLiteStridedSliceParams* tf_options;
1962     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1963     RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
1964 
1965     auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1966 
1967     SliceAttributes attr;
1968     if (read_without_batch) {
1969       RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
1970                                               input->tensor.shape, &attr));
1971     }
1972     if (read_with_batch) {
1973       RETURN_IF_ERROR(
1974           ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
1975     }
1976     if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
1977         attr.strides.c == 0) {
1978       return absl::InvalidArgumentError("Stride values must be non-zero.");
1979     }
1980     if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
1981         attr.strides.c < 0) {
1982       return absl::UnimplementedError("Reverse slices are not supported.");
1983     }
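         // The checks below use ceil division: the sliced extent along an axis is
         // ceil((end - start) / stride), computed in integer arithmetic as
         // (end - start + stride - 1) / stride. Illustrative example (not in the
         // original source): start 0, end 5, stride 2 covers indices {0, 2, 4},
         // i.e. (5 - 0 + 2 - 1) / 2 = 3 elements.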
1984     if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
1985         out_shape.b) {
1986       return absl::UnimplementedError("Output batch doesn't match");
1987     }
1988     if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
1989         out_shape.h) {
1990       return absl::UnimplementedError("Output height doesn't match");
1991     }
1992     if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
1993         out_shape.w) {
1994       return absl::UnimplementedError("Output width doesn't match");
1995     }
1996     if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
1997         out_shape.c) {
1998       return absl::UnimplementedError("Output channels don't match");
1999     }
2000     node->operation.attributes = attr;
2001     return absl::OkStatus();
2002   }
2003 
2004  private:
2005   absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
2006                               const BHWC& input_shape, int ignore_b,
2007                               int ignore_h, int ignore_w, int ignore_c,
2008                               SliceAttributes* attr) {
2009     if (tf_options->begin_mask & ignore_h) {
2010       attr->starts.h = 0;
2011     }
2012     if (tf_options->begin_mask & ignore_w) {
2013       attr->starts.w = 0;
2014     }
2015     if (tf_options->begin_mask & ignore_c) {
2016       attr->starts.c = 0;
2017     }
2018     if (tf_options->begin_mask & ignore_b) {
2019       attr->starts.b = 0;
2020     }
2021 
2022     if (tf_options->end_mask & ignore_h) {
2023       attr->ends.h = input_shape.h;
2024     }
2025     if (tf_options->end_mask & ignore_w) {
2026       attr->ends.w = input_shape.w;
2027     }
2028     if (tf_options->end_mask & ignore_c) {
2029       attr->ends.c = input_shape.c;
2030     }
2031     if (tf_options->end_mask & ignore_b) {
2032       attr->ends.b = input_shape.b;
2033     }
2034     return absl::OkStatus();
2035   }
2036 
2037   absl::Status UpdateIfNegative(const BHWC& input_shape,
2038                                 SliceAttributes* attr) {
2039     if (attr->ends.h < 0) {
2040       attr->ends.h = input_shape.h + attr->ends.h;
2041     }
2042     if (attr->ends.w < 0) {
2043       attr->ends.w = input_shape.w + attr->ends.w;
2044     }
2045     if (attr->ends.c < 0) {
2046       attr->ends.c = input_shape.c + attr->ends.c;
2047     }
2048     if (attr->ends.b < 0) {
2049       attr->ends.b = input_shape.b + attr->ends.b;
2050     }
2051 
2052     if (attr->starts.h < 0) {
2053       attr->starts.h = input_shape.h + attr->starts.h;
2054     }
2055     if (attr->starts.w < 0) {
2056       attr->starts.w = input_shape.w + attr->starts.w;
2057     }
2058     if (attr->starts.c < 0) {
2059       attr->starts.c = input_shape.c + attr->starts.c;
2060     }
2061     if (attr->starts.b < 0) {
2062       attr->starts.b = input_shape.b + attr->starts.b;
2063     }
2064 
2065     return absl::OkStatus();
2066   }
2067 
2068   absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
2069                                     const TfLiteStridedSliceParams* tf_options,
2070                                     const BHWC& input_shape,
2071                                     SliceAttributes* attr) {
2072     auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2073       Tensor<Linear, DataType::INT32> t;
2074       RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2075       *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
2076       return absl::OkStatus();
2077     };
2078 
2079     RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
2080     RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
2081     RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
2082     RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
2083     RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
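         // With batch present, the mask bits are batch = 1, height = 2, width = 4,
         // channels = 8. Illustrative example (not in the original source):
         // begin_mask == 0b0110 resets starts.h and starts.w to 0 while leaving
         // starts.b and starts.c untouched.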
2084     return absl::OkStatus();
2085   }
2086 
2087   absl::Status ReadAttribsWithoutBatch(
2088       const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
2089       const BHWC& input_shape, SliceAttributes* attr) {
2090     auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2091       Tensor<Linear, DataType::INT32> t;
2092       RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2093       *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
2094       return absl::OkStatus();
2095     };
2096 
2097     RETURN_IF_ERROR(read_hwc(1, &attr->starts));
2098     RETURN_IF_ERROR(read_hwc(2, &attr->ends));
2099     RETURN_IF_ERROR(read_hwc(3, &attr->strides));
2100     RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
2101     RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
2102     attr->starts.b = 0;
2103     attr->ends.b = input_shape.b;
2104     attr->strides.b = 1;
2105     return absl::OkStatus();
2106   }
2107   absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
2108     if (tf_options->ellipsis_mask) {
2109       return absl::UnimplementedError("Slice does not support ellipsis_mask.");
2110     }
2111     if (tf_options->new_axis_mask) {
2112       return absl::UnimplementedError("Slice does not support new_axis_mask.");
2113     }
2114     if (tf_options->shrink_axis_mask) {
2115       return absl::UnimplementedError(
2116           "Slice does not support the shrink_axis_mask parameter.");
2117     }
2118     return absl::OkStatus();
2119   }
2120 };
2121 
2122 class TileOperationParser : public TFLiteOperationParser {
2123  public:
2124   absl::Status IsSupported(const TfLiteContext* context,
2125                            const TfLiteNode* tflite_node,
2126                            const TfLiteRegistration* registration) final {
2127     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2128   }
2129 
2130   absl::Status Parse(const TfLiteNode* tflite_node,
2131                      const TfLiteRegistration* registration,
2132                      GraphFloat32* graph, ObjectReader* reader) final {
2133     Node* node = graph->NewNode();
2134     node->operation.type = ToString(OperationType::TILE);
2135     RETURN_IF_ERROR(reader->AddInput(node, 0));
2136     RETURN_IF_ERROR(reader->AddOutputs(node));
2137     return absl::OkStatus();
2138   }
2139 };
2140 
2141 // Builtin op version of TRANSPOSE_CONV.
2142 class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
2143  public:
2144   absl::Status IsSupported(const TfLiteContext* context,
2145                            const TfLiteNode* tflite_node,
2146                            const TfLiteRegistration* registration) final {
2147     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
2148     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2149   }
2150 
2151   // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights,
2152   // input, and an optional bias) and allows configurable padding & stride.
2153   // TODO(impjdi): Translate output_shape to attr.adjacent.
2154   absl::Status Parse(const TfLiteNode* tflite_node,
2155                      const TfLiteRegistration* registration,
2156                      GraphFloat32* graph, ObjectReader* reader) final {
2157     auto* node = graph->NewNode();
2158     node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2159     Value* input;
2160     RETURN_IF_ERROR(reader->ReadValue(2, &input));
2161     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2162     RETURN_IF_ERROR(reader->AddOutputs(node));
2163 
2164     const TfLiteTransposeConvParams* tf_options;
2165     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2166 
2167     ConvolutionTransposedAttributes attr;
2168     attr.stride = tf_options
2169                       ? HW(tf_options->stride_height, tf_options->stride_width)
2170                       : HW(1, 1);
2171     const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
2172     if (runtime_inputs == 2) {
2173       RETURN_IF_ERROR(reader->AddInput(node, 1));
2174       auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
2175       attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
2176                                 weights_shape.w, weights_shape.c);
2177     } else {  // runtime_inputs == 1;
2178       RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2179     }
2180     reader->ReadTensor(3, &attr.bias).IgnoreError();  // bias is optional
2181 
2182     UpdatePadding(tf_options->padding,
2183                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2184     node->operation.attributes = std::move(attr);
2185     return absl::OkStatus();
2186   }
2187 };
2188 
2189 // Custom op version of TRANSPOSE_CONV.
2190 class TransposeConvCustomOperationParser : public TFLiteOperationParser {
2191  public:
2192   absl::Status IsSupported(const TfLiteContext* context,
2193                            const TfLiteNode* tflite_node,
2194                            const TfLiteRegistration* registration) final {
2195     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2196   }
2197 
2198   absl::Status Parse(const TfLiteNode* tflite_node,
2199                      const TfLiteRegistration* registration,
2200                      GraphFloat32* graph, ObjectReader* reader) final {
2201     auto* node = graph->NewNode();
2202     node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2203     RETURN_IF_ERROR(reader->AddInput(node, 0));
2204     RETURN_IF_ERROR(reader->AddOutputs(node));
2205 
2206     const TfLiteTransposeConvParams* tf_options;
2207     auto status = RetrieveCustomInitialData(tflite_node, &tf_options);
2208 
2209     ConvolutionTransposedAttributes attr;
2210     attr.stride = status.ok()
2211                       ? HW(tf_options->stride_height, tf_options->stride_width)
2212                       : HW(1, 1);
2213     RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2214     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
2215 
2216     UpdatePadding(status.ok() ? tf_options->padding : kTfLitePaddingUnknown,
2217                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2218     node->operation.attributes = std::move(attr);
2219     return absl::OkStatus();
2220   }
2221 };
2222 
2223 class TransposeOperationParser : public TFLiteOperationParser {
2224  public:
2225   absl::Status IsSupported(const TfLiteContext* context,
2226                            const TfLiteNode* tflite_node,
2227                            const TfLiteRegistration* registration) final {
2228     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2229     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2230   }
2231 
2232   absl::Status Parse(const TfLiteNode* tflite_node,
2233                      const TfLiteRegistration* registration,
2234                      GraphFloat32* graph, ObjectReader* reader) final {
2235     Node* node = graph->NewNode();
2236     node->operation.type = ToString(OperationType::TRANSPOSE);
2237     RETURN_IF_ERROR(reader->AddInput(node, 0));
2238     RETURN_IF_ERROR(reader->AddOutputs(node));
2239 
2240     TransposeAttributes attr;
2241     Tensor<Linear, DataType::INT32> perm;
2242     RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
2243     std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0},
2244                                          {Axis::HEIGHT, 1},
2245                                          {Axis::WIDTH, 2},
2246                                          {Axis::CHANNELS, 3}};
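         // Illustrative example (not in the original source): a 3D TFLite
         // transpose with perm = [0, 2, 1] permutes the (B, W, C) layout to
         // (B, C, W), which maps below to the BHWC permutation
         // (b, h, w, c) = (0, 1, 3, 2).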
2247     if (perm.data.size() == 4) {
2248       attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
2249     } else if (perm.data.size() == 3) {
2250       std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH,
2251                                          Axis::CHANNELS};
2252       attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2253       attr.perm.h = 1;
2254       attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]];
2255       attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]];
2256     } else if (perm.data.size() == 2) {
2257       std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS};
2258       attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2259       attr.perm.h = 1;
2260       attr.perm.w = 2;
2261       attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]];
2262     } else {
2263       return absl::InvalidArgumentError(
2264           "Permutation for transpose is invalid.");
2265     }
2266 
2267     node->operation.attributes = attr;
2268     return absl::OkStatus();
2269   }
2270 };
2271 
2272 class Unpooling2DOperationParser : public TFLiteOperationParser {
2273  public:
2274   absl::Status IsSupported(const TfLiteContext* context,
2275                            const TfLiteNode* tflite_node,
2276                            const TfLiteRegistration* registration) final {
2277     return CheckGpuDelegateCompatibility(context, tflite_node, registration);
2278   }
2279 
2280   absl::Status Parse(const TfLiteNode* tflite_node,
2281                      const TfLiteRegistration* registration,
2282                      GraphFloat32* graph, ObjectReader* reader) final {
2283     Node* node = graph->NewNode();
2284     node->operation.type = ToString(OperationType::MAX_UNPOOLING_2D);
2285     RETURN_IF_ERROR(reader->AddInput(node, 0));
2286     RETURN_IF_ERROR(reader->AddInput(node, 1));
2287     RETURN_IF_ERROR(reader->AddOutputs(node));
2288     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
2289     MaxUnpooling2DAttributes attr;
2290 
2291     const TfLitePoolParams* tf_options;
2292     RETURN_IF_ERROR(RetrieveCustomInitialData(tflite_node, &tf_options));
2293 
2294     attr.kernel = ToHW(tf_options->filter_height, tf_options->filter_width);
2295     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
2296     UpdatePadding(tf_options->padding, input_shape, &attr);
2297 
2298     node->operation.attributes = attr;
2299 
2300     auto output_value = graph->FindOutputs(node->id)[0];
2301     output_value->tensor.shape = CalculateOutputShape(input_shape, attr);
2302     return absl::OkStatus();
2303   }
2304 };
2305 
2306 // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
2307 class BatchToSpaceOperationParser : public TFLiteOperationParser {
2308  public:
2309   absl::Status IsSupported(const TfLiteContext* context,
2310                            const TfLiteNode* tflite_node,
2311                            const TfLiteRegistration* registration) final {
2312     return absl::OkStatus();
2313   }
2314 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2315   absl::Status Parse(const TfLiteNode* tflite_node,
2316                      const TfLiteRegistration* registration,
2317                      GraphFloat32* graph, ObjectReader* reader) final {
2318     auto* node = graph->NewNode();
2319     node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
2320     RETURN_IF_ERROR(reader->AddInput(node, 0));
2321     RETURN_IF_ERROR(reader->AddOutputs(node));
2322 
2323     BatchToSpaceAttributes bs_attr;
2324     Tensor<Linear, DataType::INT32> block;
2325     RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2326     if (block.shape.v != 2) {
2327       return absl::InternalError("Space has to be HxW.");
2328     }
2329     bs_attr.block.h = block.data[0];
2330     bs_attr.block.w = block.data[1];
2331 
2332     Tensor<HW, DataType::INT32> crop;
2333     RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
2334     auto crop_shape = crop.shape;
    // crop.data is read as four elements below, so the shape must be 2x2.
    if (crop_shape.h != 2 || crop_shape.w != 2) {
      return absl::InternalError("Crop has to be 2x2.");
    }

    bs_attr.crop.prepended.h = crop.data[0];
    bs_attr.crop.prepended.w = crop.data[2];

    bs_attr.crop.appended.h = crop.data[1];
    bs_attr.crop.appended.w = crop.data[3];

    node->operation.attributes = std::move(bs_attr);
    return absl::OkStatus();
  }
};

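// Parses SPACE_TO_BATCH. Mirrors BatchToSpaceOperationParser: input 1 holds
// the 2-element HxW block sizes and input 2 the 2x2 padding amounts.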
class SpaceToBatchOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));
    SpaceToBatchAttributes sb_attr;
    Tensor<Linear, DataType::INT32> block;
    RETURN_IF_ERROR(reader->ReadTensor(1, &block));
    if (block.shape.v != 2) {
      return absl::InternalError("Space has to be HxW.");
    }
    sb_attr.block.h = block.data[0];
    sb_attr.block.w = block.data[1];

    Tensor<HW, DataType::INT32> padding;
    RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
    auto padding_shape = padding.shape;

    // padding.data is read as four elements below, so the shape must be 2x2.
    if (padding_shape.h != 2 || padding_shape.w != 2) {
      return absl::InternalError("Padding has to be 2x2.");
    }

    sb_attr.padding.prepended.h = padding.data[0];
    sb_attr.padding.prepended.w = padding.data[2];

    sb_attr.padding.appended.h = padding.data[1];
    sb_attr.padding.appended.w = padding.data[3];

    node->operation.attributes = std::move(sb_attr);
    return absl::OkStatus();
  }
};

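// Parses MEAN. The reduction axes come from input 1 and are translated from
// TFLite axis indices into BHWC axes via ExtractAxisFromIndex.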
class MeanOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return CheckGpuDelegateCompatibility(context, tflite_node, registration);
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    auto* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN);
    RETURN_IF_ERROR(reader->AddInput(node, 0));
    RETURN_IF_ERROR(reader->AddOutputs(node));

    MeanAttributes attr;
    const TfLiteTensor* input = reader->GetInputTensor(0);
    const TfLiteTensor* axes = reader->GetInputTensor(1);
    for (int i = 0; i < NumElements(axes->dims); i++) {
      Axis axis;
      RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
      attr.dims.insert(axis);
    }
    node->operation.attributes = attr;
    return absl::OkStatus();
  }
};

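// Fallback parser returned by NewOperationParser for operations that have no
// GPU implementation; both methods report an UnimplementedError.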
class UnsupportedOperationParser : public TFLiteOperationParser {
 public:
  absl::Status IsSupported(const TfLiteContext* context,
                           const TfLiteNode* tflite_node,
                           const TfLiteRegistration* registration) final {
    return absl::UnimplementedError("Operation is not supported.");
  }

  absl::Status Parse(const TfLiteNode* tflite_node,
                     const TfLiteRegistration* registration,
                     GraphFloat32* graph, ObjectReader* reader) final {
    return absl::UnimplementedError("Operation is not supported.");
  }
};

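// Convenience wrapper: instantiates the parser for `registration` and asks it
// whether the node is supported by the GPU delegate.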
absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
                         const TfLiteRegistration* registration,
                         bool allow_quant_ops = false) {
  return NewOperationParser(registration, allow_quant_ops)
      ->IsSupported(context, node, registration);
}

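// Returns true if every listed tensor is delegable: any tensor of rank >= 5
// is rejected, and runtime (kTfLiteArenaRw) tensors must additionally have
// one of `allowed_types`. Optional (absent) tensors are skipped.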
bool IsAllAllowedTensors(TfLiteContext* context,
                         const TfLiteIntArray* tensor_indices,
                         const std::vector<TfLiteType>& allowed_types) {
  for (int i = 0; i < tensor_indices->size; ++i) {
    int tensor_idx = tensor_indices->data[i];
    if (tensor_idx == kTfLiteOptionalTensor) continue;
    const TfLiteTensor* t = &context->tensors[tensor_idx];
    if (t->dims && t->dims->size >= 5) {
      return false;
    }
    bool type_supported = false;
    for (auto allowed_type : allowed_types) {
      if (t->type == allowed_type) {
        type_supported = true;
        break;
      }
    }
    if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
      return false;
    }
  }
  return true;
}
}  // namespace

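// Maps a TFLite registration to the parser that knows how to translate it
// into a GraphFloat32 node. Quantize/Dequantize are only parseable when
// `allow_quant_ops` is set; anything unrecognized gets the
// UnsupportedOperationParser above.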
std::unique_ptr<TFLiteOperationParser> NewOperationParser(
    const TfLiteRegistration* registration, bool allow_quant_ops) {
  const auto builtin_code = registration->builtin_code;
  switch (builtin_code) {
    case kTfLiteBuiltinAbs:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
    case kTfLiteBuiltinAdd:
      return std::make_unique<AddOperationParser>();
    case kTfLiteBuiltinAveragePool2d:
      return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
    case kTfLiteBuiltinBatchMatmul:
      return std::make_unique<BatchedMatMulOperationParser>();
    case kTfLiteBuiltinCast:
      return std::make_unique<CastOperationParser>();
    case kTfLiteBuiltinConcatenation:
      return std::make_unique<ConcatenationOperationParser>();
    case kTfLiteBuiltinConv2d:
      return std::make_unique<Conv2DOperationParser>();
    case kTfLiteBuiltinCos:
      return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
    case kTfLiteBuiltinDensify:
      return std::make_unique<DensifyOperationParser>();
    case kTfLiteBuiltinDepthwiseConv2d:
      return std::make_unique<DepthwiseConvolutionOperationParser>();
    case kTfLiteBuiltinDepthToSpace:
      return std::make_unique<DepthToSpaceOperationParser>();
    case kTfLiteBuiltinDequantize:
      if (allow_quant_ops) {
        return std::make_unique<DequantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinDiv:
      return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
    case kTfLiteBuiltinEqual:
      return std::make_unique<ElementwiseOperationParser>(OperationType::EQUAL);
    case kTfLiteBuiltinElu:
      return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
    case kTfLiteBuiltinExp:
      return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
    case kTfLiteBuiltinFloor:
      return std::make_unique<ElementwiseOperationParser>(OperationType::FLOOR);
    case kTfLiteBuiltinFloorDiv:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::FLOOR_DIV);
    case kTfLiteBuiltinFloorMod:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::FLOOR_MOD);
    case kTfLiteBuiltinFullyConnected:
      return std::make_unique<FullyConnectedOperationParser>();
    case kTfLiteBuiltinGreater:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::GREATER);
    case kTfLiteBuiltinGreaterEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::GREATER_EQUAL);
    case kTfLiteBuiltinHardSwish:
      return std::make_unique<HardSwishOperationParser>();
    case kTfLiteBuiltinLess:
      return std::make_unique<ElementwiseOperationParser>(OperationType::LESS);
    case kTfLiteBuiltinLessEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::LESS_EQUAL);
    case kTfLiteBuiltinLogistic:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SIGMOID);
    case kTfLiteBuiltinLog:
      return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
    case kTfLiteBuiltinLstm:
      return std::make_unique<LSTMOperationParser>();
    case kTfLiteBuiltinMaximum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MAXIMUM);
    case kTfLiteBuiltinMaxPool2d:
      return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
    case kTfLiteBuiltinMean:
      return std::make_unique<MeanOperationParser>();
    case kTfLiteBuiltinMinimum:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::MINIMUM);
    case kTfLiteBuiltinMirrorPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
    case kTfLiteBuiltinMul:
      return std::make_unique<MulOperationParser>();
    case kTfLiteBuiltinNeg:
      return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
    case kTfLiteBuiltinNotEqual:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::NOT_EQUAL);
    case kTfLiteBuiltinPack:
      return std::make_unique<PackOperationParser>();
    case kTfLiteBuiltinPad:
      return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
    case kTfLiteBuiltinPow:
      return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
    case kTfLiteBuiltinReduceMax:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MAXIMUM);
    case kTfLiteBuiltinReduceMin:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_MINIMUM);
    case kTfLiteBuiltinReduceProd:
      return std::make_unique<ReduceOperationParser>(
          OperationType::REDUCE_PRODUCT);
    case kTfLiteBuiltinQuantize:
      if (allow_quant_ops) {
        return std::make_unique<QuantizeOperationParser>();
      }
      break;
    case kTfLiteBuiltinRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinRelu6:
      return std::make_unique<ReLUOperationParser>(6);
    case kTfLiteBuiltinReluN1To1:
      return std::make_unique<ClampOperationsParser>(-1.0, 1.0);
    case kTfLiteBuiltinLeakyRelu:
      return std::make_unique<ReLUOperationParser>(0);
    case kTfLiteBuiltinPrelu:
      return std::make_unique<PReLUOperationParser>();
    case kTfLiteBuiltinReshape:
      return std::make_unique<ReshapeOperationParser>();
    case kTfLiteBuiltinResizeBilinear:
      return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
    case kTfLiteBuiltinResizeNearestNeighbor:
      return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
    case kTfLiteBuiltinRsqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
    case kTfLiteBuiltinSin:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
    case kTfLiteBuiltinSlice:
      return std::make_unique<SliceOperationParser>();
    case kTfLiteBuiltinSoftmax:
      return std::make_unique<SoftmaxOperationParser>();
    case kTfLiteBuiltinSpaceToDepth:
      return std::make_unique<SpaceToDepthOperationParser>();
    case kTfLiteBuiltinSplit:
      return std::make_unique<SplitOperationParser>();
    case kTfLiteBuiltinSplitV:
      return std::make_unique<SplitVOperationParser>();
    case kTfLiteBuiltinSqrt:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
    case kTfLiteBuiltinSquare:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARE);
    case kTfLiteBuiltinSquaredDifference:
      return std::make_unique<ElementwiseOperationParser>(
          OperationType::SQUARED_DIFF);
    case kTfLiteBuiltinStridedSlice:
      return std::make_unique<StridedSliceOperationParser>();
    case kTfLiteBuiltinSub:
      return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
    case kTfLiteBuiltinSum:
      return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
    case kTfLiteBuiltinTanh:
      return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
    case kTfLiteBuiltinTile:
      return std::make_unique<TileOperationParser>();
    case kTfLiteBuiltinTranspose:
      return std::make_unique<TransposeOperationParser>();
    case kTfLiteBuiltinTransposeConv:
      return std::make_unique<TransposeConvBuiltinOperationParser>();
    case kTfLiteBuiltinCustom: {
      const absl::string_view custom_name = registration->custom_name;
      if (custom_name == "Convolution2DTransposeBias") {
        return std::make_unique<TransposeConvCustomOperationParser>();
      }
      if (custom_name == "MaxPoolingWithArgmax2D") {
        return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
      }
      if (custom_name == "MaxUnpooling2D") {
        return std::make_unique<Unpooling2DOperationParser>();
      }
      if (custom_name == "Resampler") {
        return std::make_unique<ResamplerOperationParser>();
      }
      return NewCustomOperationParser(registration->custom_name);
    }
  }
  return std::make_unique<UnsupportedOperationParser>();
}

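// Returns the ids of the TFLite nodes that should be replaced by the GPU
// delegate: nodes whose parser accepts them and whose runtime I/O tensor
// types are delegable, grouped into at most `max_delegated_partitions`
// partitions. When some nodes stay on the CPU, a summary of the unsupported
// operations and the resulting CPU/GPU split is logged via
// TF_LITE_KERNEL_LOG. The caller owns the returned array (see DelegatePrepare
// below, which frees it with TfLiteIntArrayFree).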
// TODO(impjdi): Check number of input/output tensors and their dimensions.
// TODO(impjdi): Check ops' parameters.
TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
                                int max_delegated_partitions) {
  delegates::IsNodeSupportedFn node_supported_fn =
      [=](TfLiteContext* context, TfLiteNode* node,
          TfLiteRegistration* registration,
          std::string* unsupported_details) -> bool {
    const auto status =
        IsSupported(context, node, registration, allow_quant_ops);
    if (!status.ok()) {
      if (unsupported_details) {
        *unsupported_details = std::string(status.message());
      }
      return false;
    }

    std::vector<TfLiteType> allowed_in_types = {kTfLiteFloat32, kTfLiteFloat16};
    std::vector<TfLiteType> allowed_out_types = {kTfLiteFloat32,
                                                 kTfLiteFloat16};
    if (allow_quant_ops) {
      // Since we only check non-constant tensors, type cannot be Int32.
      allowed_in_types.push_back(kTfLiteInt8);
      allowed_in_types.push_back(kTfLiteUInt8);
      allowed_out_types.push_back(kTfLiteInt8);
      allowed_out_types.push_back(kTfLiteUInt8);
    }
    if (IsLogicalCode(registration->builtin_code)) {
      allowed_out_types.push_back(kTfLiteBool);
    }
    if (registration->builtin_code == kTfLiteBuiltinCast) {
      allowed_in_types.push_back(kTfLiteBool);
    }
    if (!IsAllAllowedTensors(context, node->inputs, allowed_in_types) ||
        !IsAllAllowedTensors(context, node->outputs, allowed_out_types)) {
      if (unsupported_details) {
        *unsupported_details =
            "Op is supported, but tensor type/shape isn't supported.";
      }
      return false;
    }
    return true;
  };

  delegates::FP16GraphPartitionHelper partition_helper(context,
                                                       node_supported_fn);
  std::set<std::string> unsupported_nodes_info;
  if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
    return TfLiteIntArrayCreate(0);
  }

  // By default, we simply get the first largest partition, as
  // 'max_delegated_partitions' is set to 1 by default.
  std::vector<int> ops_to_replace =
      partition_helper.GetNodesOfFirstNLargestPartitions(
          max_delegated_partitions);

  if (!unsupported_nodes_info.empty() &&
      partition_helper.num_total_nodes() > ops_to_replace.size()) {
    std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
    std::string error_message = absl::StrCat(
        "The following operations are not supported by the GPU delegate:\n",
        unsupported, "\n");
    if (!ops_to_replace.empty()) {
      absl::StrAppend(
          &error_message, ops_to_replace.size(),
          " operations will run on the GPU, and the remaining ",
          partition_helper.num_total_nodes() - ops_to_replace.size());
    } else {
      absl::StrAppend(&error_message,
                      "No operations will run on the GPU, and all ",
                      partition_helper.num_total_nodes());
    }
    absl::StrAppend(&error_message, " operations will run on the CPU.");
    TF_LITE_KERNEL_LOG(context, error_message.c_str());
  }
  return ConvertVectorToTfLiteIntArray(ops_to_replace);
}

// Creates the inputs and outputs passed via the io_ids parameter in the
// resulting graph. We force this to make sure that the delegated subgraph has
// the same order of inputs and outputs as the original one. When the delegated
// model is built from the TFLite model representation, tensors are created
// lazily, so there is no guarantee that the order will match the source
// model's tensor order.
absl::Status PrecreateIOTensors(
    TfLiteContext* context, GraphFloat32* graph, const std::vector<int>& io_ids,
    absl::flat_hash_map<int, int>* quant_conversion_map,
    absl::flat_hash_map<int, Value*>* tensor_to_value) {
  for (const auto& id : io_ids) {
    const TfLiteTensor& tflite_tensor = context->tensors[id];
    if (tflite::IsConstantTensor(&tflite_tensor)) continue;
    RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
        context, tensor_to_value, quant_conversion_map, graph, id));
  }
  return absl::OkStatus();
}

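// For every variable input tensor of `tflite_node`, inserts a COPY node that
// writes the op's new value (from `new_variable_tensor_values`, keyed by
// input index) back to that tensor. Fails if a variable input has no new
// value, or if new values were supplied for inputs that are not variable.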
absl::Status CopyVariableTensorOutputs(
    TfLiteNode* tflite_node, TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader& reader,
    const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
  absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
      new_variable_tensor_values);
  // Retrieve the final value id for the variable input tensors.
  for (int i = 0; i < tflite_node->inputs->size; i++) {
    int tensor_idx = tflite_node->inputs->data[i];
    Value* value;
    if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
    if (value->tensor.is_variable_input) {
      if (new_variable_tensor_values_copy.find(i) ==
          new_variable_tensor_values_copy.end()) {
        return absl::InvalidArgumentError(
            absl::StrCat(GetOpNameByRegistration(*registration),
                         " did not provide a new value for the variable input "
                         "tensor with index ",
                         tensor_idx));
      } else {
        Node* node = graph->NewNode();
        node->operation.type = ToString(OperationType::COPY);
        RETURN_IF_ERROR(graph->AddConsumer(
            node->id, new_variable_tensor_values_copy.at(i)));
        RETURN_IF_ERROR(reader.AddUpdate(node, i));
        new_variable_tensor_values_copy.erase(
            new_variable_tensor_values_copy.find(i));
      }
    }
  }
  if (!new_variable_tensor_values_copy.empty()) {
    return absl::InvalidArgumentError(
        "Asked to copy more variable input tensors than are present on the "
        "node");
  }
  return absl::OkStatus();
}

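// Convenience wrapper around BuildModelEnforceIO that keeps the input/output
// tensor order of `delegate_params`.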
absl::Status BuildModel(TfLiteContext* context,
                        const TfLiteDelegateParams* delegate_params,
                        GraphFloat32* graph,
                        absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<int> inputs(delegate_params->input_tensors->size);
  std::vector<int> outputs(delegate_params->output_tensors->size);
  for (int i = 0; i < delegate_params->input_tensors->size; i++) {
    inputs[i] = delegate_params->input_tensors->data[i];
  }
  for (int i = 0; i < delegate_params->output_tensors->size; i++) {
    outputs[i] = delegate_params->output_tensors->data[i];
  }
  return BuildModelEnforceIO(context, delegate_params, inputs, outputs, graph,
                             quant_conversion_map);
}

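// Translates the delegated TFLite partition into a GraphFloat32 whose inputs
// and outputs are forced to `input_ids`/`output_ids`. Works in two passes:
// parsers for all replaced nodes are instantiated first, so unsupported ops
// fail before the graph is touched, and then each node is parsed. Quantized
// ops are only allowed when `quant_conversion_map` is non-null.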
absl::Status BuildModelEnforceIO(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    const std::vector<int>& input_ids, const std::vector<int>& output_ids,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
  std::vector<int> tflite_nodes;
  for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
    TfLiteNode* tflite_node = nullptr;
    TfLiteRegistration* registration = nullptr;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[i], &tflite_node,
        &registration));
    if (registration->builtin_code == kTfLiteBuiltinDequantize &&
        context->tensors[tflite_node->inputs->data[0]].type ==
            TfLiteType::kTfLiteFloat16 &&
        context->tensors[tflite_node->inputs->data[0]].allocation_type ==
            TfLiteAllocationType::kTfLiteMmapRo) {
      // Ignore fp16 Dequantize nodes only if they read directly from
      // memory-mapped constant weights, i.e. no other node (e.g. DENSIFY)
      // precedes them.
      continue;
    }
    auto op_parser = NewOperationParser(
        registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
    if (!op_parser) {
      return absl::UnimplementedError(
          absl::StrCat("Operation ", registration->builtin_code, "(",
                       registration->custom_name,
                       ") is not supported by TFLite GPU Delegate."));
    }
    operations.push_back(std::move(op_parser));
    tflite_nodes.push_back(i);
  }
  absl::flat_hash_map<int, Value*> tensor_to_value;
  std::vector<ValueId> variable_inputs_to_value_id;

  RETURN_IF_ERROR(PrecreateIOTensors(context, graph, input_ids,
                                     quant_conversion_map, &tensor_to_value));
  RETURN_IF_ERROR(PrecreateIOTensors(context, graph, output_ids,
                                     quant_conversion_map, &tensor_to_value));
  for (int i = 0; i < operations.size(); ++i) {
    TfLiteNode* tflite_node;
    TfLiteRegistration* registration;
    RETURN_IF_ERROR(GetNodeAndRegistration(
        context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
        &tflite_node, &registration));
    ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
                        quant_conversion_map);
    const auto status =
        operations[i]->Parse(tflite_node, registration, graph, &reader);
    if (!status.ok()) {
      return absl::InternalError(absl::StrCat(
          GetOpNameByRegistration(*registration), ": ", status.message()));
    }

    absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
        operations[i]->GetNewValueIdsForVariableInputNodes();

    RETURN_IF_ERROR(
        CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
                                  new_value_for_variable_input_tensors));
  }

  // Variable input tensors are expected to be unchanged throughout model
  // execution. They need to be graph outputs in order to stay unchanged.
  for (auto value_id : variable_inputs_to_value_id) {
    if (!graph->IsGraphOutput(value_id)) {
      return absl::InvalidArgumentError(
          absl::StrCat("Variable input tensors must be a graph output. Value ",
                       value_id, " is not a graph output"));
    }
  }
  return absl::OkStatus();
}

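// Like BuildModel, but additionally runs the general model transformations on
// the resulting graph.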
absl::Status BuildFinalModel(
    TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
    GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
  RETURN_IF_ERROR(
      BuildModel(context, delegate_params, graph, quant_conversion_map));

  // Apply general transformations on the graph.
  ModelTransformer transformer(graph);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }
  return absl::OkStatus();
}

namespace {

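// Bridges BuildFromFlatBuffer and the conversion-only delegate it installs:
// carries the tensor ids and the target GraphFloat32 through
// TfLiteDelegate::data_ so the graph can be built inside the delegate
// kernel's init callback.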
class DelegateContext {
 public:
  struct DelegateData {
    std::vector<int> input_ids;
    std::vector<int> output_ids;
    GraphFloat32* graph;
    std::unique_ptr<absl::flat_hash_map<int, int>> quant_conversion_map;
  };
  bool Init(TfLiteContext* context,
            const TfLiteDelegateParams* delegate_params) {
    const auto* delegate_data =
        reinterpret_cast<DelegateData*>(delegate_params->delegate->data_);
    return delegate_data->graph &&
           BuildModelEnforceIO(context, delegate_params,
                               delegate_data->input_ids,
                               delegate_data->output_ids, delegate_data->graph,
                               delegate_data->quant_conversion_map.get())
               .ok();
  }
};

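// TfLiteDelegate::Prepare implementation for the conversion-only delegate:
// the kernel's init callback builds the graph via DelegateContext, and invoke
// is left null since the delegate exists purely for model conversion.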
TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
  TfLiteRegistration registration;
  registration.init = [](TfLiteContext* context, const char* buffer,
                         size_t) -> void* {
    auto* delegate_context = new DelegateContext();
    if (!delegate_context->Init(
            context, reinterpret_cast<const TfLiteDelegateParams*>(buffer))) {
      delete delegate_context;
      return nullptr;
    }
    return delegate_context;
  };
  registration.free = [](TfLiteContext* context, void* buffer) -> void {
    delete reinterpret_cast<DelegateContext*>(buffer);
  };
  registration.prepare = [](TfLiteContext* context,
                            TfLiteNode* node) -> TfLiteStatus {
    return node->user_data ? kTfLiteOk : kTfLiteError;
  };
  registration.invoke = nullptr;
  registration.custom_name = nullptr;

  const auto* delegate_data =
      reinterpret_cast<const DelegateContext::DelegateData*>(delegate->data_);
  TfLiteIntArray* ops_to_replace = GetOpsToReplace(
      context, static_cast<bool>(delegate_data->quant_conversion_map));
  const auto status = context->ReplaceNodeSubsetsWithDelegateKernels(
      context, registration, ops_to_replace, delegate);
  TfLiteIntArrayFree(ops_to_replace);
  return status;
}

}  // namespace

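// Builds a GraphFloat32 directly from a FlatBuffer model by running the
// conversion-only delegate above through a temporary interpreter. A minimal
// usage sketch (the model path and the resolver choice are illustrative, not
// part of this API):
//
//   std::unique_ptr<tflite::FlatBufferModel> model =
//       tflite::FlatBufferModel::BuildFromFile("model.tflite");
//   tflite::ops::builtin::BuiltinOpResolver op_resolver;
//   tflite::gpu::GraphFloat32 graph;
//   absl::Status status = tflite::gpu::BuildFromFlatBuffer(
//       *model, op_resolver, &graph, /*allow_quant_ops=*/false);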
absl::Status BuildFromFlatBuffer(const tflite::FlatBufferModel& flatbuffer,
                                 const tflite::OpResolver& op_resolver,
                                 GraphFloat32* graph, bool allow_quant_ops) {
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder interpreter_builder(flatbuffer, op_resolver);
  if (interpreter_builder(&interpreter) != kTfLiteOk || !interpreter) {
    return absl::InternalError("Unable to prepare TfLite interpreter.");
  }
  TfLiteDelegate delegate;

  DelegateContext::DelegateData delegate_data{interpreter->inputs(),
                                              interpreter->outputs(), graph};
  if (allow_quant_ops) {
    delegate_data.quant_conversion_map =
        std::make_unique<absl::flat_hash_map<int, int>>();
  }

  delegate.data_ = &delegate_data;
  delegate.flags = kTfLiteDelegateFlagsNone;
  delegate.Prepare = DelegatePrepare;
  delegate.CopyFromBufferHandle = nullptr;
  delegate.CopyToBufferHandle = nullptr;
  delegate.FreeBufferHandle = nullptr;

  if (interpreter->ModifyGraphWithDelegate(&delegate) != kTfLiteOk) {
    return absl::InternalError("Conversion from TfLite model failed.");
  }

  ModelTransformer transformer(graph);
  if (!ApplyModelTransformations(&transformer)) {
    return absl::InternalError("Graph transformations failed");
  }

  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite