• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/delegates/gpu/common/model_builder.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/base/attributes.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/status/status.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/string_view.h"
33 #include "tensorflow/lite/builtin_ops.h"
34 #include "tensorflow/lite/c/builtin_op_data.h"
35 #include "tensorflow/lite/c/common.h"
36 #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
37 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
38 #include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
39 #include "tensorflow/lite/delegates/gpu/common/model.h"
40 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
41 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
42 #include "tensorflow/lite/delegates/gpu/common/object_reader.h"
43 #include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
44 #include "tensorflow/lite/delegates/gpu/common/operations.h"
45 #include "tensorflow/lite/delegates/gpu/common/shape.h"
46 #include "tensorflow/lite/delegates/gpu/common/status.h"
47 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
48 #include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
49 #include "tensorflow/lite/delegates/utils.h"
50 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
51 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
52 #include "tensorflow/lite/kernels/kernel_util.h"
53 #include "tensorflow/lite/util.h"
54 
55 namespace tflite {
56 namespace gpu {
57 namespace {
58 
GetFullyConnectedAttributes(int weights_tensor_id,int bias_tensor_id,ObjectReader * reader,FullyConnectedAttributes * attr)59 absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
60                                          int bias_tensor_id,
61                                          ObjectReader* reader,
62                                          FullyConnectedAttributes* attr) {
63   Tensor<HW, DataType::FLOAT32> weights;
64   RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
65   attr->weights.data = std::move(weights.data);
66   attr->weights.id = weights.id;
67   attr->weights.shape.h = 1;
68   attr->weights.shape.w = 1;
69   attr->weights.shape.o = weights.shape.h;
70   attr->weights.shape.i = weights.shape.w;
71   reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError();  // optional
72   return absl::OkStatus();
73 }
74 
75 template <typename ParamsT>
RetrieveBuiltinData(const TfLiteNode * tflite_node,const ParamsT ** tf_options)76 absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
77                                  const ParamsT** tf_options) {
78   *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
79   if (!*tf_options) {
80     return absl::InternalError("Unable to retrieve builtin_data.");
81   }
82   return absl::OkStatus();
83 }
84 
CheckDilation(int dilation_h,int dilation_w)85 absl::Status CheckDilation(int dilation_h, int dilation_w) {
86   if (dilation_h <= 0 || dilation_w <= 0) {
87     return absl::InvalidArgumentError(absl::StrCat(
88         "Incorrect dilation values: dilation_factor = ", dilation_h,
89         ", dilation_factor = ", dilation_w));
90   }
91   return absl::OkStatus();
92 }
93 
CheckStridesAndDilation(int strides_h,int strides_w,int dilation_h,int dilation_w)94 absl::Status CheckStridesAndDilation(int strides_h, int strides_w,
95                                      int dilation_h, int dilation_w) {
96   RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
97   RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w));
98   return absl::OkStatus();
99 }
100 
101 // Creates a simple node that holds tensor value.
NewConstNode(TensorFloat32 t,GraphFloat32 * graph,Value ** value)102 absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
103   ConstTensorAttributes attr;
104   attr.tensor = std::move(t);
105   Node* node = graph->NewNode();
106   node->operation.attributes = attr;
107   node->operation.type = ToString(OperationType::CONSTANT);
108   *value = graph->NewValue();
109   RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
110   // Keep data inside this tensor.
111   (*value)->tensor.ref = attr.tensor.id;
112   (*value)->tensor.type = attr.tensor.kType;
113   (*value)->tensor.shape = attr.tensor.shape;
114   return absl::OkStatus();
115 }
116 
ParseInputsWithConstTensor(Node * node,ObjectReader * reader,TensorOrScalar * tensor_or_scalar)117 absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
118                                         TensorOrScalar* tensor_or_scalar) {
119   const std::string& opname = node->operation.type;
120 
121   // Determine runtime/constant tensors.
122   const TfLiteTensor* input0 = reader->GetInputTensor(0);
123   if (!input0) {
124     return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
125                                       opname);
126   }
127   const TfLiteTensor* input1 = reader->GetInputTensor(1);
128   if (!input1) {
129     return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
130                                       opname);
131   }
132   const bool constant_tensor0 = IsConstantTensor(input0);
133   const bool constant_tensor1 = IsConstantTensor(input1);
134   if (constant_tensor0 && constant_tensor1) {
135     return absl::InvalidArgumentError("No runtime input tensors for " + opname);
136   }
137   const bool runtime_tensor0 = !constant_tensor0;
138   const bool runtime_tensor1 = !constant_tensor1;
139 
140   if (runtime_tensor0 && runtime_tensor1) {
141     RETURN_IF_ERROR(reader->AddInput(node, 0));
142     RETURN_IF_ERROR(reader->AddInput(node, 1));
143   } else {
144     int runtime_tensor = 0;
145     int constant_tensor = 1;
146     TfLiteIntArray* constant_dims = input1->dims;
147     if (constant_tensor0 && runtime_tensor1) {
148       runtime_tensor = 1;
149       constant_tensor = 0;
150       constant_dims = input0->dims;
151     }
152     RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
153     if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
154       Tensor<Scalar, DataType::FLOAT32> tensor;
155       RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
156       *tensor_or_scalar = tensor.data[0];
157     } else {
158       if (CheckIfLinearConvertible(constant_dims).ok()) {
159         Tensor<Linear, DataType::FLOAT32> tensor;
160         RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
161         *tensor_or_scalar = std::move(tensor);
162       } else {
163         Tensor<HWC, DataType::FLOAT32> tensor;
164         RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
165         *tensor_or_scalar = std::move(tensor);
166       }
167     }
168   }
169   return absl::OkStatus();
170 }
171 
172 class AddOperationParser : public TFLiteOperationParser {
173  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)174   absl::Status IsSupported(const TfLiteContext* context,
175                            const TfLiteNode* tflite_node,
176                            const TfLiteRegistration* registration) final {
177     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
178     if (tflite_node->inputs->size != 2) {
179       return absl::UnimplementedError("ADD requires two input tensors.");
180     }
181     // TODO(eignasheva): Add shapes check.
182 
183     const TfLiteAddParams* tf_options;
184     return RetrieveBuiltinData(tflite_node, &tf_options);
185   }
186 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)187   absl::Status Parse(const TfLiteNode* tflite_node,
188                      const TfLiteRegistration* registration,
189                      GraphFloat32* graph, ObjectReader* reader) final {
190     // TFLite currently only supports 2 input ADDs.  Thus, the logic below only
191     // considers 2 input cases.  The underlying GPU shader programs can accept
192     // more inputs, but the logic below would have to be expanded.
193 
194     Node* node = graph->NewNode();
195     node->operation.type = ToString(OperationType::ADD);
196     RETURN_IF_ERROR(reader->AddOutputs(node));
197     ElementwiseAttributes attr;
198     RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
199     node->operation.attributes = std::move(attr);
200     const TfLiteAddParams* tf_options;
201     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
202     return MaybeFuseActivation(tf_options->activation, graph, node);
203   }
204 };
205 
206 class BatchedMatMulOperationParser : public TFLiteOperationParser {
207  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)208   absl::Status IsSupported(const TfLiteContext* context,
209                            const TfLiteNode* tflite_node,
210                            const TfLiteRegistration* registration) final {
211     return CheckInputsOutputs(context, tflite_node,
212                               /*runtime_inputs=*/2, /*outputs=*/1);
213   }
214 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)215   absl::Status Parse(const TfLiteNode* tflite_node,
216                      const TfLiteRegistration* registration,
217                      GraphFloat32* graph, ObjectReader* reader) final {
218     Node* node = graph->NewNode();
219     node->operation.type = ToString(OperationType::BATCHED_MATMUL);
220     RETURN_IF_ERROR(reader->AddInput(node, 0));
221     RETURN_IF_ERROR(reader->AddInput(node, 1));
222     RETURN_IF_ERROR(reader->AddOutputs(node));
223     return absl::OkStatus();
224   }
225 };
226 
227 class ConcatenationOperationParser : public TFLiteOperationParser {
228  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)229   absl::Status IsSupported(const TfLiteContext* context,
230                            const TfLiteNode* tflite_node,
231                            const TfLiteRegistration* registration) final {
232     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
233 
234     // TODO(eignasheva): add proper tensor availability checking
235     // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
236     //   RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
237     // }
238     // TODO(eignasheva): add axis checking.
239     const TfLiteConcatenationParams* tf_options;
240     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
241     return absl::OkStatus();
242   }
243 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)244   absl::Status Parse(const TfLiteNode* tflite_node,
245                      const TfLiteRegistration* registration,
246                      GraphFloat32* graph, ObjectReader* reader) final {
247     ConcatAttributes attr;
248     // Read inputs first to make sure const node is added to a graph before
249     // concat node to ensure topological order.
250     std::vector<const Value*> inputs;
251     for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
252       Value* value;
253       const auto status = reader->ReadValue(idx, &value);
254       if (status.ok()) {
255         inputs.push_back(value);
256       } else {
257         TensorFloat32 tensor;
258         RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
259         Value* value;
260         RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
261         inputs.push_back(value);
262       }
263     }
264 
265     Node* node = graph->NewNode();
266     node->operation.type = ToString(OperationType::CONCAT);
267     RETURN_IF_ERROR(reader->AddOutputs(node));
268     for (const Value* input : inputs) {
269       RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
270     }
271 
272     std::vector<BHWC> input_shapes;
273     for (auto input : graph->FindInputs(node->id)) {
274       input_shapes.push_back(input->tensor.shape);
275     }
276     RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));
277 
278     // Guess axis.
279     BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
280     for (auto input : graph->FindInputs(node->id)) {
281       if (input->tensor.shape.h != output_shape.h) {
282         attr.axis = Axis::HEIGHT;
283         break;
284       }
285       if (input->tensor.shape.w != output_shape.w) {
286         attr.axis = Axis::WIDTH;
287         break;
288       }
289       if (input->tensor.shape.c != output_shape.c) {
290         attr.axis = Axis::CHANNELS;
291         break;
292       }
293     }
294     const TfLiteConcatenationParams* tf_options;
295     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
296     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
297     node->operation.attributes = attr;
298     return absl::OkStatus();
299   }
300 
301  private:
SetAxis(const std::vector<BHWC> & input_shapes,Axis * axis)302   absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
303     *axis = Axis::BATCH;
304     for (int i = 1; i < input_shapes.size(); i++) {
305       if (input_shapes[0].h != input_shapes[i].h &&
306           input_shapes[0].w != input_shapes[i].w &&
307           input_shapes[0].c != input_shapes[i].c) {
308         *axis = Axis::HEIGHT;
309         break;
310       }
311     }
312     if (*axis == Axis::BATCH) return absl::OkStatus();
313     for (int i = 1; i < input_shapes.size(); i++) {
314       if (input_shapes[0].b != input_shapes[i].b &&
315           input_shapes[0].w != input_shapes[i].w &&
316           input_shapes[0].c != input_shapes[i].c) {
317         *axis = Axis::WIDTH;
318         break;
319       }
320     }
321     if (*axis == Axis::HEIGHT) return absl::OkStatus();
322     for (int i = 1; i < input_shapes.size(); i++) {
323       if (input_shapes[0].b != input_shapes[i].b &&
324           input_shapes[0].h != input_shapes[i].h &&
325           input_shapes[0].c != input_shapes[i].c) {
326         *axis = Axis::CHANNELS;
327         break;
328       }
329     }
330     if (*axis == Axis::WIDTH) return absl::OkStatus();
331     for (int i = 1; i < input_shapes.size(); i++) {
332       if (input_shapes[0].b != input_shapes[i].b &&
333           input_shapes[0].w != input_shapes[i].w &&
334           input_shapes[0].h != input_shapes[i].h) {
335         return absl::UnimplementedError(
336             "Can concatenate tensors only by batch, height, width, or "
337             "channels.");
338       }
339     }
340     return absl::OkStatus();
341   }
342 };
343 
344 class Conv2DOperationParser : public TFLiteOperationParser {
345  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)346   absl::Status IsSupported(const TfLiteContext* context,
347                            const TfLiteNode* tflite_node,
348                            const TfLiteRegistration* registration) final {
349     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5));
350     const int runtime_inputs =
351         GetNumberOfRuntimeInputsForNode(context, tflite_node);
352     if (runtime_inputs > 2) {
353       return absl::InternalError(
354           absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
355                        runtime_inputs, " runtime inputs."));
356     }
357     const int runtime_outputs = NumOutputs(tflite_node);
358     if (runtime_outputs != 1) {
359       return absl::InternalError(
360           absl::StrCat("Expected 1 output tensor(s), but node has ",
361                        runtime_outputs, " runtime outputs."));
362     }
363     if (runtime_inputs == 1) {
364       RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
365     }
366     const TfLiteConvParams* tf_options;
367     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
368     RETURN_IF_ERROR(CheckStridesAndDilation(
369         tf_options->stride_height, tf_options->stride_width,
370         tf_options->dilation_height_factor, tf_options->dilation_width_factor));
371     return IsActivationSupported(tf_options->activation);
372   }
373 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)374   absl::Status Parse(const TfLiteNode* tflite_node,
375                      const TfLiteRegistration* registration,
376                      GraphFloat32* graph, ObjectReader* reader) final {
377     Node* node = graph->NewNode();
378     node->operation.type = ToString(OperationType::CONVOLUTION_2D);
379     RETURN_IF_ERROR(reader->AddInput(node, 0));
380     RETURN_IF_ERROR(reader->AddOutputs(node));
381 
382     Convolution2DAttributes attr;
383     const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
384     if (runtime_inputs == 2) {
385       RETURN_IF_ERROR(reader->AddInput(node, 1));
386     } else {  // runtime_inputs == 1;
387       RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
388     }
389     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
390 
391     const TfLiteConvParams* tf_options;
392     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
393     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
394     attr.dilations = HW(tf_options->dilation_height_factor,
395                         tf_options->dilation_width_factor);
396     UpdatePadding(tf_options->padding,
397                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
398     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
399     node->operation.attributes = std::move(attr);
400     return absl::OkStatus();
401   }
402 };
403 
404 class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
405  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)406   absl::Status IsSupported(const TfLiteContext* context,
407                            const TfLiteNode* tflite_node,
408                            const TfLiteRegistration* registration) final {
409     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
410     const int runtime_inputs =
411         GetNumberOfRuntimeInputsForNode(context, tflite_node);
412     if (runtime_inputs > 2) {
413       return absl::InternalError(
414           absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
415                        runtime_inputs, " runtime inputs."));
416     }
417     const int runtime_outputs = NumOutputs(tflite_node);
418     if (runtime_outputs != 1) {
419       return absl::InternalError(
420           absl::StrCat("Expected 1 output tensor(s), but node has ",
421                        runtime_outputs, " runtime outputs."));
422     }
423     if (runtime_inputs == 1) {
424       RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
425     }
426     const TfLiteDepthwiseConvParams* tf_options;
427     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
428     RETURN_IF_ERROR(CheckStridesAndDilation(
429         tf_options->stride_height, tf_options->stride_width,
430         tf_options->dilation_height_factor, tf_options->dilation_width_factor));
431     RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));
432 
433     const int depth_multiplier = tf_options->depth_multiplier;
434     const auto* input = context->tensors + tflite_node->inputs->data[0];
435     const auto* filter = context->tensors + tflite_node->inputs->data[1];
436     const auto* bias = tflite_node->inputs->size > 2
437                            ? context->tensors + tflite_node->inputs->data[2]
438                            : nullptr;
439     const auto* output = context->tensors + tflite_node->outputs->data[0];
440     if (!input->dims || input->dims->size != 4) {
441       return absl::InvalidArgumentError("input.dims.size != 4");
442     }
443     if (!filter->dims || filter->dims->size != 4) {
444       return absl::InvalidArgumentError("filter.dims.size != 4");
445     }
446     if (!output->dims || output->dims->size != 4) {
447       return absl::InvalidArgumentError("output.dims.size != 4");
448     }
449     if (input->dims->data[0] != output->dims->data[0]) {
450       return absl::InvalidArgumentError("input.b != output.b");
451     }
452     const int input_depth = input->dims->data[3];
453     const int output_depth = output->dims->data[3];
454     if (filter->dims->data[3] != output_depth) {
455       return absl::InvalidArgumentError("filter.i != output.c");
456     }
457     if (output_depth != input_depth * depth_multiplier) {
458       return absl::InvalidArgumentError(
459           "output.c != input.c * depth_multiplier");
460     }
461     if (bias && NumElements(bias) != output_depth) {
462       return absl::InvalidArgumentError("bias.size != output.c");
463     }
464     if (depth_multiplier != 1 && input_depth != 1) {
465       return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1");
466     }
467     return absl::OkStatus();
468   }
469 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)470   absl::Status Parse(const TfLiteNode* tflite_node,
471                      const TfLiteRegistration* registration,
472                      GraphFloat32* graph, ObjectReader* reader) final {
473     Node* node = graph->NewNode();
474     node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
475     RETURN_IF_ERROR(reader->AddInput(node, 0));
476     RETURN_IF_ERROR(reader->AddOutputs(node));
477 
478     DepthwiseConvolution2DAttributes attr;
479     const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
480     if (runtime_inputs == 2) {
481       RETURN_IF_ERROR(reader->AddInput(node, 1));
482     } else {  // runtime_inputs == 1;
483       RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
484     }
485     reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
486     const TfLiteDepthwiseConvParams* tf_options;
487     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
488     attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
489     attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
490                         std::max(1, tf_options->dilation_width_factor));
491     UpdatePadding(tf_options->padding,
492                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
493     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
494     const int depth_multiplier = tf_options->depth_multiplier;
495     if (depth_multiplier != 1) {
496       const TfLiteTensor* input = reader->GetInputTensor(0);
497       const TfLiteTensor* filter = reader->GetInputTensor(1);
498       const TfLiteTensor* output = reader->GetOutputTensor(0);
499       TransposeWeights(input, filter, output, depth_multiplier, &attr);
500     }
501     node->operation.attributes = std::move(attr);
502     return absl::OkStatus();
503   }
504 
505  private:
506   // TFLite CPU stores weights as:
507   //   [1, kernel_height, kernel_width, input_depth * depth_multiplier]
508   // TFLite GPU stores weights as:
509   //   [depth_multiplier, kernel_height, kernel_width, input_depth]
TransposeWeights(const TfLiteTensor * input,const TfLiteTensor * filter,const TfLiteTensor * output,int depth_multiplier,DepthwiseConvolution2DAttributes * attr)510   static void TransposeWeights(const TfLiteTensor* input,
511                                const TfLiteTensor* filter,
512                                const TfLiteTensor* output, int depth_multiplier,
513                                DepthwiseConvolution2DAttributes* attr) {
514     const int input_depth = input->dims->data[3];
515     const int filter_height = filter->dims->data[1];
516     const int filter_width = filter->dims->data[2];
517     const int output_depth = output->dims->data[3];
518     Tensor<OHWI, DataType::FLOAT32> weights;
519     weights.id = attr->weights.id;
520     weights.shape =
521         OHWI(output_depth, filter_height, filter_width, input_depth);
522     weights.data.resize(weights.shape.DimensionsProduct());
523     float* dst = &weights.data[0];
524     for (int j = 0; j < output_depth; ++j) {
525       const float* src = attr->weights.data.data() + j;
526       for (int i = 0; i < filter_height * filter_width; ++i) {
527         *dst = *src;
528         dst++;
529         src += output_depth;
530       }
531     }
532     attr->weights = std::move(weights);
533   }
534 };
535 
536 class DequantizeOperationParser : public TFLiteOperationParser {
537  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)538   absl::Status IsSupported(const TfLiteContext* context,
539                            const TfLiteNode* tflite_node,
540                            const TfLiteRegistration* registration) final {
541     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
542     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
543                                        /*runtime_inputs=*/1, /*outputs=*/1));
544     return absl::OkStatus();
545   }
546 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)547   absl::Status Parse(const TfLiteNode* tflite_node,
548                      const TfLiteRegistration* registration,
549                      GraphFloat32* graph, ObjectReader* reader) final {
550     // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
551     // with floating-point versions of the original tensors.
552     Node* node = graph->NewNode();
553     node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
554     RETURN_IF_ERROR(reader->AddInput(node, 0));
555     RETURN_IF_ERROR(reader->AddOutputs(node));
556 
557     // Quantization attributes should already be present in the input tensor.
558     auto input_value = graph->FindInputs(node->id)[0];
559     if (!input_value->quant_params) {
560       return absl::InvalidArgumentError(
561           "Encountered Dequantize input with no quant params");
562     }
563     QuantizeAndDequantizeAttributes attr;
564     attr.min = input_value->quant_params.value().min;
565     attr.max = input_value->quant_params.value().max;
566     attr.scale = input_value->quant_params.value().scale;
567 
568     node->operation.attributes = attr;
569     return absl::OkStatus();
570   }
571 };
572 
573 class ElementwiseOperationParser : public TFLiteOperationParser {
574  public:
ElementwiseOperationParser(OperationType operation_type)575   explicit ElementwiseOperationParser(OperationType operation_type)
576       : operation_type_(operation_type) {}
577 
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)578   absl::Status IsSupported(const TfLiteContext* context,
579                            const TfLiteNode* tflite_node,
580                            const TfLiteRegistration* registration) final {
581     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
582     if (IsOneArgumentOperation()) {
583       RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
584                                                /*runtime_inputs=*/1,
585                                                /*const_inputs=*/0,
586                                                /*outputs=*/1));
587       // For some elementwise operations (currently only for SUB operation)
588       // second condition may be false. But it's worth checking the next case
589       // with const input, which may be supported.
590     } else if (IsTwoArgumentOperation() &&
591                CheckInputsConstsOutputs(context, tflite_node,
592                                         /*runtime_inputs=*/2,
593                                         /*const_inputs=*/0,
594                                         /*outputs=*/1)
595                    .ok()) {
596     } else if (IsTwoArgumentOperationWithConst()) {
597       RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
598                                                /*runtime_inputs=*/1,
599                                                /*const_inputs=*/1,
600                                                /*outputs=*/1));
601     } else {
602       return absl::InvalidArgumentError(
603           "Op can only handle 1 or 2 operand(s).");
604     }
605     TfLiteFusedActivation activation;
606     RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
607     return IsActivationSupported(activation);
608   }
609 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)610   absl::Status Parse(const TfLiteNode* tflite_node,
611                      const TfLiteRegistration* registration,
612                      GraphFloat32* graph, ObjectReader* reader) final {
613     Node* node = graph->NewNode();
614     node->operation.type = ToString(operation_type_);
615 
616     if (IsOneArgumentOperation()) {
617       RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
618                                                         /*runtime_inputs=*/1,
619                                                         /*const_inputs=*/0,
620                                                         /*outputs=*/1));
621 
622       RETURN_IF_ERROR(reader->AddInput(node, 0));
623     } else if (IsTwoArgumentOperation() &&
624                reader
625                    ->VerifyInputsConstsOutputs(tflite_node,
626                                                /*runtime_inputs=*/2,
627                                                /*const_inputs=*/0,
628                                                /*outputs=*/1)
629                    .ok()) {
630       if (tflite_node->inputs->size != 2) {
631         return absl::InvalidArgumentError("Applies only two input tensors");
632       }
633       RETURN_IF_ERROR(reader->AddInput(node, 0));
634       RETURN_IF_ERROR(reader->AddInput(node, 1));
635 
636       TfLiteFusedActivation activation = kTfLiteActNone;
637       switch (operation_type_) {
638         case OperationType::SUB: {
639           const TfLiteSubParams* tf_options;
640           if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
641             activation = tf_options->activation;
642           }
643           break;
644         }
645         case OperationType::DIV: {
646           const TfLiteDivParams* tf_options;
647           if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
648             activation = tf_options->activation;
649           }
650           break;
651         }
652         default:
653           // No activation expected.
654           activation = kTfLiteActNone;
655       }
656 
657       if (activation) {
658         RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
659       }
660     } else if (IsTwoArgumentOperationWithConst()) {
661       RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
662                                                         /*runtime_inputs=*/1,
663                                                         /*const_inputs=*/1,
664                                                         /*outputs=*/1));
665       ElementwiseAttributes attr;
666       RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
667       attr.runtime_tensor_is_second =
668           IsConstantTensor(reader->GetInputTensor(0));
669       node->operation.attributes = std::move(attr);
670     } else {
671       return absl::InvalidArgumentError("Incorrect operation type passed");
672     }
673 
674     return reader->AddOutputs(node);
675   }
676 
677  private:
GetActivation(const TfLiteNode * tflite_node,TfLiteFusedActivation * activation) const678   absl::Status GetActivation(const TfLiteNode* tflite_node,
679                              TfLiteFusedActivation* activation) const {
680     if (operation_type_ == OperationType::DIV) {
681       const TfLiteDivParams* tf_options;
682       auto status = RetrieveBuiltinData(tflite_node, &tf_options);
683       *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
684       return absl::OkStatus();
685     }
686     if (operation_type_ == OperationType::SUB) {
687       const TfLiteSubParams* tf_options;
688       auto status = RetrieveBuiltinData(tflite_node, &tf_options);
689       *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
690       return absl::OkStatus();
691     }
692 
693     // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
694     // TfLiteXxxParams.activation.
695     *activation = kTfLiteActNone;
696     return absl::OkStatus();
697   }
698 
IsOneArgumentOperation() const699   bool IsOneArgumentOperation() const {
700     switch (operation_type_) {
701       case OperationType::ABS:
702       case OperationType::COPY:
703       case OperationType::COS:
704       case OperationType::ELU:
705       case OperationType::EXP:
706       case OperationType::LOG:
707       case OperationType::NEG:
708       case OperationType::RSQRT:
709       case OperationType::SIGMOID:
710       case OperationType::SIN:
711       case OperationType::SQRT:
712       case OperationType::SQUARE:
713       case OperationType::TANH:
714         return true;
715       default:
716         return false;
717     }
718   }
719 
IsTwoArgumentOperation() const720   bool IsTwoArgumentOperation() const {
721     switch (operation_type_) {
722       case OperationType::DIV:
723       case OperationType::MAXIMUM:
724       case OperationType::MINIMUM:
725       case OperationType::POW:
726       case OperationType::SQUARED_DIFF:
727       case OperationType::SUB:
728         return true;
729       default:
730         return false;
731     }
732   }
733 
IsTwoArgumentOperationWithConst() const734   bool IsTwoArgumentOperationWithConst() const {
735     switch (operation_type_) {
736       case OperationType::DIV:
737       case OperationType::MAXIMUM:
738       case OperationType::MINIMUM:
739       case OperationType::POW:
740       case OperationType::SQUARED_DIFF:
741       case OperationType::SUB:
742         return true;
743       default:
744         return false;
745     }
746   }
747 
748   OperationType operation_type_;
749 };
750 
751 class FullyConnectedOperationParser : public TFLiteOperationParser {
752  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)753   absl::Status IsSupported(const TfLiteContext* context,
754                            const TfLiteNode* tflite_node,
755                            const TfLiteRegistration* registration) final {
756     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
757     const TfLiteFullyConnectedParams* tf_options;
758     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
759     if (tf_options->weights_format !=
760         kTfLiteFullyConnectedWeightsFormatDefault) {
761       return absl::UnimplementedError(
762           "Unsupported FullyConnected weights format.");
763     }
764     if (GetNumberOfRuntimeInputsForNode(context, tflite_node) > 2) {
765       return absl::UnimplementedError(
766           "FullyConnected doesn't support more than 2 runtime inputs.");
767     }
768     if (tf_options->keep_num_dims == true) {
769       const auto* input = context->tensors + tflite_node->inputs->data[0];
770       const auto* output = context->tensors + tflite_node->outputs->data[0];
771       if (input->dims->size != output->dims->size) {
772         return absl::UnimplementedError(
773             "Input and output dimensions different and FullyConnected doesn't "
774             "support keep_num_dims.");
775       }
776     }
777     // TODO(eignasheva): check input shape
778     return absl::OkStatus();
779   }
780 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)781   absl::Status Parse(const TfLiteNode* tflite_node,
782                      const TfLiteRegistration* registration,
783                      GraphFloat32* graph, ObjectReader* reader) final {
784     const TfLiteFullyConnectedParams* tf_options;
785     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
786 
787     if (reader->GetNumberOfRuntimeInputs() == 2) {
788       // Create Convolution2D, so as it supports runtime weights.
789       Node* node = graph->NewNode();
790       node->operation.type = ToString(OperationType::CONVOLUTION_2D);
791       RETURN_IF_ERROR(reader->AddInput(node, 0));
792       RETURN_IF_ERROR(reader->AddInput(node, 1));
793       RETURN_IF_ERROR(reader->AddOutputs(node));
794 
795       Convolution2DAttributes attr;
796       reader->ReadTensor(2, &attr.bias).IgnoreError();  // bias is optional
797 
798       attr.strides = HW(1, 1);
799       attr.dilations = HW(1, 1);
800       attr.padding.appended = HW(0, 0);
801       attr.padding.prepended = HW(0, 0);
802       RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
803       node->operation.attributes = std::move(attr);
804       return absl::OkStatus();
805     }
806     Node* node = graph->NewNode();
807     RETURN_IF_ERROR(reader->AddInput(node, 0));
808 
809     if (tf_options->weights_format !=
810         kTfLiteFullyConnectedWeightsFormatDefault) {
811       return absl::UnimplementedError(
812           "Unsupported FullyConnected weights format.");
813     }
814 
815     FullyConnectedAttributes attr;
816     RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
817     const int weights_width = attr.weights.shape.i;
818 
819     auto input = graph->FindInputs(node->id)[0];
820     int batch_size = input->tensor.shape.b;
821     if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) {
822       return absl::UnimplementedError(
823           "Amount of input data should match weights width");
824     }
825 
826     Node* conv = node;
827     if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
828       auto& reshape = node;
829       conv = graph->NewNode();  // reset conv pointer!
830       Value* reshaped_value = graph->NewValue();
831       reshaped_value->tensor.type = DataType::FLOAT32;
832       reshaped_value->tensor.shape =
833           BHWC(input->tensor.shape.b, 1, 1, weights_width);
834       RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
835       reshape->operation.type = ToString(OperationType::RESHAPE);
836       ReshapeAttributes attr;
837       attr.new_shape = reshaped_value->tensor.shape;
838       reshape->operation.attributes = attr;
839       RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
840     }
841 
842     conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
843     conv->operation.attributes = std::move(attr);
844     absl::Status result = reader->AddOutputs(conv);
845     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv));
846 
847     return result;
848   }
849 };
850 
851 class HardSwishOperationParser : public TFLiteOperationParser {
852  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration *)853   absl::Status IsSupported(const TfLiteContext* context,
854                            const TfLiteNode* tflite_node,
855                            const TfLiteRegistration*) final {
856     return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
857                               /*outputs=*/1);
858   }
859 
Parse(const TfLiteNode *,const TfLiteRegistration *,GraphFloat32 * graph,ObjectReader * reader)860   absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
861                      GraphFloat32* graph, ObjectReader* reader) final {
862     Node* node = graph->NewNode();
863     node->operation.type = ToString(OperationType::HARD_SWISH);
864     RETURN_IF_ERROR(reader->AddInput(node, 0));
865     return reader->AddOutputs(node);
866   }
867 };
868 
869 // Basic LSTM Cell:
870 //
871 //  1name = name is at input  index 1
872 //  name1 = name is at output index 1
873 //
874 //    0input     1prev_activ
875 //       \        /
876 //        [[concat]]
877 //             \
878 //       concat_temp2  2weights  3biases
879 //              \      /        /
880 //             [[fully-connected]]
881 //               \
882 //         activ_temp3    4prev_state
883 //                 \      /
884 //                 [[LSTM]]
885 //                 /      \
886 //           new_state1    activation0
887 //
888 // For full LSTM cells, see this blog post:
889 // https://colah.github.io/posts/2015-08-Understanding-LSTMs/
890 // In addition to Peephole connections and Combined Input Forget Gates (CIFG)
891 // described in that post, this code also adds the following optional features:
892 // - Configurable activations (sigmoid or TANH)
893 // - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
894 // - Output projection:
895 //     https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
896 // - Configurable clipping of cell state and output state.
897 class LSTMOperationParser : public TFLiteOperationParser {
898  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)899   absl::Status IsSupported(const TfLiteContext* context,
900                            const TfLiteNode* tflite_node,
901                            const TfLiteRegistration* registration) final {
902     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
903     const TfLiteLSTMParams* tf_options;
904     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
905     switch (tf_options->kernel_type) {
906       case kTfLiteLSTMFullKernel: {
907         const int inputs = NumInputs(tflite_node);
908         if (inputs != 20 && inputs != 24) {
909           return absl::InternalError(
910               absl::StrCat("Expected 20 or 24 input tensors, but node has ",
911                            inputs, " input(s)."));
912         }
913         const int runtime_outputs = NumOutputs(tflite_node);
914         if (runtime_outputs != 1) {
915           return absl::InternalError(
916               absl::StrCat("Expected 1 output tensor, but node has ",
917                            runtime_outputs, " output(s)."));
918         }
919         return CheckFullParameters(tf_options);
920       }
921       case kTfLiteLSTMBasicKernel:
922         RETURN_IF_ERROR(
923             CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/3,
924                                      /*const_inputs=*/2, /*outputs=*/4));
925         return CheckBasicParameters(tf_options);
926     }
927   }
928 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)929   absl::Status Parse(const TfLiteNode* tflite_node,
930                      const TfLiteRegistration* registration,
931                      GraphFloat32* graph, ObjectReader* reader) final {
932     const TfLiteLSTMParams* tf_options;
933     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
934     switch (tf_options->kernel_type) {
935       case kTfLiteLSTMFullKernel:
936         return ParseFull(tflite_node, registration, graph, reader, tf_options);
937       case kTfLiteLSTMBasicKernel:
938         return ParseBasic(tflite_node, registration, graph, reader, tf_options);
939     }
940   }
941 
GetNewValueIdsForVariableInputNodes()942   absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
943       final {
944     return new_variable_input_value_map_;
945   }
946 
947  private:
ParseBasic(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * tf_options)948   absl::Status ParseBasic(const TfLiteNode* tflite_node,
949                           const TfLiteRegistration* registration,
950                           GraphFloat32* graph, ObjectReader* reader,
951                           const TfLiteLSTMParams* tf_options) {
952     if (tflite_node->inputs->size != 5) {
953       return absl::InvalidArgumentError("LSTM should have 5 input tensors");
954     }
955     if (tflite_node->outputs->size != 4) {
956       return absl::InvalidArgumentError("LSTM should have 4 output tensors");
957     }
958     RETURN_IF_ERROR(CheckBasicParameters(tf_options));
959 
960     Node* concat_node = graph->NewNode();
961     concat_node->operation.type = ToString(OperationType::CONCAT);
962     ConcatAttributes concat_attr;
963     concat_attr.axis = Axis::CHANNELS;
964     concat_node->operation.attributes = concat_attr;
965 
966     Node* fc_node = graph->NewNode();
967     fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
968     FullyConnectedAttributes fc_attr;
969     RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
970     fc_node->operation.attributes = std::move(fc_attr);
971 
972     Node* lstm_node = graph->NewNode();
973     lstm_node->operation.type = ToString(OperationType::LSTM);
974     LstmAttributes lstm_attr;
975     lstm_attr.kernel_type = LstmKernelType::BASIC;
976     lstm_node->operation.attributes = lstm_attr;
977 
978     Value* concat_temp;
979     int concat_tensor_idx = tflite_node->outputs->data[2];
980     RETURN_IF_ERROR(
981         reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
982     Value* activ_temp;
983     int activ_tensor_idx = tflite_node->outputs->data[3];
984     RETURN_IF_ERROR(
985         reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
986 
987     RETURN_IF_ERROR(reader->AddInput(concat_node, 0));  // input
988     RETURN_IF_ERROR(reader->AddInput(concat_node, 1));  // prev_activ
989     RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
990 
991     RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
992     RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
993 
994     RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
995     RETURN_IF_ERROR(reader->AddInput(lstm_node, 4));   // prev_state
996     RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1));  // new_state
997     RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0));  // activation
998 
999     return absl::OkStatus();
1000   }
1001 
CheckBasicParameters(const TfLiteLSTMParams * tf_options)1002   absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
1003     if (tf_options->activation != kTfLiteActTanh) {
1004       return absl::UnimplementedError("Only TANH activation is supported.");
1005     }
1006     if (tf_options->cell_clip != 0.0f) {
1007       return absl::UnimplementedError("cell_clip is not supported.");
1008     }
1009     if (tf_options->proj_clip != 0.0f) {
1010       return absl::UnimplementedError("proj_clip is not supported.");
1011     }
1012     return absl::OkStatus();
1013   }
1014 
ParseFull(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * tf_options)1015   absl::Status ParseFull(const TfLiteNode* tflite_node,
1016                          const TfLiteRegistration* registration,
1017                          GraphFloat32* graph, ObjectReader* reader,
1018                          const TfLiteLSTMParams* tf_options) {
1019     // Invoke full LSTM parser
1020     RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
1021                                         reader, tf_options,
1022                                         &new_variable_input_value_map_));
1023     return absl::OkStatus();
1024   }
1025 
CheckFullParameters(const TfLiteLSTMParams * tf_options)1026   absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
1027     if (tf_options->activation != kTfLiteActSigmoid &&
1028         tf_options->activation != kTfLiteActTanh) {
1029       return absl::UnimplementedError(
1030           "Only sigmoid or tanh activation is supported.");
1031     }
1032 
1033     return absl::OkStatus();
1034   }
1035 
1036   absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
1037 };
1038 
1039 class MulOperationParser : public TFLiteOperationParser {
1040  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1041   absl::Status IsSupported(const TfLiteContext* context,
1042                            const TfLiteNode* tflite_node,
1043                            const TfLiteRegistration* registration) final {
1044     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1045     if (tflite_node->inputs->size != 2) {
1046       return absl::UnimplementedError("MUL requires two input tensors.");
1047     }
1048     const TfLiteTensor* input0 = GetInput(context, tflite_node, 0);
1049     const TfLiteTensor* input1 = GetInput(context, tflite_node, 1);
1050     if (input0 == nullptr || input1 == nullptr) {
1051       return absl::InvalidArgumentError("At least one input tensor is null");
1052     }
1053     if (input0->dims->size == input1->dims->size) {
1054       // this code checks that at least one input of Mul not smaller in all
1055       // dimensions. Sometimes Mul used for matrix-vector multiplication that we
1056       // currently don't support. For example input0 HWC(1, 256, 1), input1
1057       // HWC(1, 1, 256) -> output HWC (1, 256, 256). In this case it can be
1058       // replaced with Convolution operation.
1059       bool first_has_smaller_dim = false;
1060       bool second_has_smaller_dim = false;
1061       for (int i = 0; i < input0->dims->size; ++i) {
1062         if (input0->dims->data[i] < input1->dims->data[i]) {
1063           first_has_smaller_dim = true;
1064         }
1065         if (input1->dims->data[i] < input0->dims->data[i]) {
1066           second_has_smaller_dim = true;
1067         }
1068       }
1069       if (first_has_smaller_dim && second_has_smaller_dim) {
1070         return absl::UnimplementedError(
1071             "MUL requires one tensor that not less than second in all "
1072             "dimensions.");
1073       }
1074     }
1075     const TfLiteMulParams* tf_options;
1076     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1077     return IsActivationSupported(tf_options->activation);
1078   }
1079 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1080   absl::Status Parse(const TfLiteNode* tflite_node,
1081                      const TfLiteRegistration* registration,
1082                      GraphFloat32* graph, ObjectReader* reader) final {
1083     const TfLiteTensor* input0 = reader->GetInputTensor(0);
1084     if (!input0) {
1085       return absl::InvalidArgumentError(
1086           "Couldn't get the 1st input tensor for MUL.");
1087     }
1088     const TfLiteTensor* input1 = reader->GetInputTensor(1);
1089     if (!input1) {
1090       return absl::InvalidArgumentError(
1091           "Couldn't get the 2nd input tensor for MUL.");
1092     }
1093     const bool constant_tensor0 = IsConstantTensor(input0);
1094     const bool constant_tensor1 = IsConstantTensor(input1);
1095     if (constant_tensor0 && constant_tensor1) {
1096       return absl::InvalidArgumentError("No runtime input tensors for MUL.");
1097     }
1098     const bool runtime_tensor0 = !constant_tensor0;
1099     const bool runtime_tensor1 = !constant_tensor1;
1100 
1101     Node* node = graph->NewNode();
1102     node->operation.type = ToString(OperationType::MUL);
1103     RETURN_IF_ERROR(reader->AddOutputs(node));
1104 
1105     // Determine runtime/constant tensors.
1106     if (runtime_tensor0 && runtime_tensor1) {
1107       if (input0 == input1) {
1108         // replace MUL(A, A) with POW(A, 2.0)
1109         // TODO(b/166831113): Support the same inputs for operations.
1110         node->operation.type = ToString(OperationType::POW);
1111         ElementwiseAttributes attr;
1112         attr.param = 2.0f;
1113         node->operation.attributes = std::move(attr);
1114         return reader->AddInput(node, 0);
1115       }
1116 
1117       // The "larger" input tensor must be bound to 1st input and the "smaller"
1118       // input tensor must be bound to 2nd input.
1119       BHWC shape0;
1120       RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
1121       BHWC shape1;
1122       RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
1123       int input_tensor0 = 0;
1124       int input_tensor1 = 1;
1125       if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
1126           shape0.c == shape1.c) {
1127         input_tensor0 = 1;
1128         input_tensor1 = 0;
1129       }
1130       RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
1131       RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
1132     } else {
1133       ElementwiseAttributes attr;
1134       RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
1135       node->operation.attributes = std::move(attr);
1136     }
1137 
1138     const TfLiteMulParams* tf_options;
1139     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1140     return MaybeFuseActivation(tf_options->activation, graph, node);
1141   }
1142 };
1143 
1144 class PackOperationParser : public TFLiteOperationParser {
1145  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1146   absl::Status IsSupported(const TfLiteContext* context,
1147                            const TfLiteNode* tflite_node,
1148                            const TfLiteRegistration* registration) final {
1149     const TfLitePackParams* tf_options;
1150     return RetrieveBuiltinData(tflite_node, &tf_options);
1151   }
1152 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1153   absl::Status Parse(const TfLiteNode* tflite_node,
1154                      const TfLiteRegistration* registration,
1155                      GraphFloat32* graph, ObjectReader* reader) final {
1156     if (tflite_node->inputs->size == 1) {
1157       // Pack with single input can be replaced with Reshape
1158       Node* node = graph->NewNode();
1159       node->operation.type = ToString(OperationType::RESHAPE);
1160       RETURN_IF_ERROR(reader->AddInput(node, 0));
1161       RETURN_IF_ERROR(reader->AddOutputs(node));
1162       // New shape comes from output shape.
1163       ReshapeAttributes attr;
1164       attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1165       node->operation.attributes = attr;
1166       return absl::OkStatus();
1167     } else {
1168       // Pack with few inputs can be replaced with Concat
1169       const TfLitePackParams* tf_options;
1170       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1171 
1172       // Read inputs first to make sure const node is added to a graph before
1173       // concat node to ensure topological order.
1174       std::vector<const Value*> inputs;
1175       for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
1176         Value* value;
1177         const auto status = reader->ReadValue(idx, &value);
1178         if (status.ok()) {
1179           inputs.push_back(value);
1180         } else {
1181           TensorFloat32 tensor;
1182           RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
1183           Value* value;
1184           RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
1185           inputs.push_back(value);
1186         }
1187       }
1188 
1189       Node* node = graph->NewNode();
1190       node->operation.type = ToString(OperationType::CONCAT);
1191       RETURN_IF_ERROR(reader->AddOutputs(node));
1192       for (const Value* input : inputs) {
1193         RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1194       }
1195       const TfLiteTensor* output = reader->GetOutputTensor(0);
1196       ConcatAttributes attr;
1197       RETURN_IF_ERROR(
1198           ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
1199       node->operation.attributes = attr;
1200       return absl::OkStatus();
1201     }
1202   }
1203 };
1204 
1205 class PReLUOperationParser : public TFLiteOperationParser {
1206  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1207   absl::Status IsSupported(const TfLiteContext* context,
1208                            const TfLiteNode* tflite_node,
1209                            const TfLiteRegistration* registration) final {
1210     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1211     // TODO(eignasheva): add params check
1212     return absl::OkStatus();
1213   }
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1214   absl::Status Parse(const TfLiteNode* tflite_node,
1215                      const TfLiteRegistration* registration,
1216                      GraphFloat32* graph, ObjectReader* reader) final {
1217     Node* node = graph->NewNode();
1218     node->operation.type = ToString(OperationType::PRELU);
1219     RETURN_IF_ERROR(reader->AddInput(node, 0));
1220     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1221 
1222     PReLUAttributes attr;
1223     Tensor<Linear, DataType::FLOAT32> linear_alpha;
1224     absl::Status status = reader->ReadTensor(1, &linear_alpha);
1225     if (status.ok()) {
1226       if (linear_alpha.shape.v != input_shape.c) {
1227         return absl::InvalidArgumentError(
1228             "Linear alpha shape does not match the number of input channels.");
1229       }
1230       attr.alpha = std::move(linear_alpha);
1231     } else {
1232       Tensor<HWC, DataType::FLOAT32> hwc_alpha;
1233       RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
1234       if (hwc_alpha.shape.h != input_shape.h ||
1235           hwc_alpha.shape.w != input_shape.w ||
1236           hwc_alpha.shape.c != input_shape.c) {
1237         return absl::InvalidArgumentError(
1238             "Alpha shape does not match input shape.");
1239       }
1240       attr.alpha = std::move(hwc_alpha);
1241     }
1242     node->operation.attributes = std::move(attr);
1243     return reader->AddOutputs(node);
1244   }
1245 };
1246 
1247 class PadOperationParser : public TFLiteOperationParser {
1248  public:
PadOperationParser(bool mirror_pad)1249   explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
1250 
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1251   absl::Status IsSupported(const TfLiteContext* context,
1252                            const TfLiteNode* tflite_node,
1253                            const TfLiteRegistration* registration) final {
1254     if (mirror_pad_) {
1255       const TfLiteMirrorPaddingParams* tf_options;
1256       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1257       if (tf_options->mode !=
1258           TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
1259         return absl::InvalidArgumentError(
1260             "Only Reflective padding is supported for Mirror Pad operation.");
1261       }
1262     }
1263     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1264     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1265                                        /*runtime_inputs=*/1, /*outputs=*/1));
1266     RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
1267     const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1);
1268     if (pad_tensor == nullptr) {
1269       return absl::InvalidArgumentError("Padding tensor was null");
1270     }
1271     if (pad_tensor->dims->size != 2) {
1272       return absl::InvalidArgumentError(absl::StrCat(
1273           "Invalid paddings tensor dimension: expected 2 dim, got ",
1274           pad_tensor->dims->size, " dim"));
1275     }
1276     bool supported =
1277         pad_tensor->dims->data[0] == 3 || pad_tensor->dims->data[0] == 4;
1278     if (!supported || pad_tensor->dims->data[1] != 2) {
1279       return absl::InvalidArgumentError(absl::StrCat(
1280           "Invalid paddings tensor shape: expected 4x2 or 3x2, got ",
1281           pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1]));
1282     }
1283     return absl::OkStatus();
1284   }
1285 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1286   absl::Status Parse(const TfLiteNode* tflite_node,
1287                      const TfLiteRegistration* registration,
1288                      GraphFloat32* graph, ObjectReader* reader) final {
1289     Node* node = graph->NewNode();
1290     node->operation.type = ToString(OperationType::PAD);
1291     RETURN_IF_ERROR(reader->AddInput(node, 0));
1292     RETURN_IF_ERROR(reader->AddOutputs(node));
1293 
1294     PadAttributes attr;
1295     if (mirror_pad_) {
1296       attr.type = PaddingContentType::REFLECT;
1297     } else /*zero pad*/ {
1298       attr.type = PaddingContentType::ZEROS;
1299     }
1300 
1301     Tensor<HW, DataType::INT32> paddings;
1302     RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
1303 
1304     if (paddings.shape.h == 4 && paddings.shape.w == 2) {
1305       // 4x2 tensor with paddings.
1306       attr.prepended = BHWC(paddings.data[0], paddings.data[2],
1307                             paddings.data[4], paddings.data[6]);
1308       attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
1309                            paddings.data[7]);
1310     } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
1311       // 3x2 tensor with paddings.
1312       attr.prepended =
1313           BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
1314       attr.appended =
1315           BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
1316     } else {
1317       // It shouldn't fail here since it's checked at IsSupported().
1318       return absl::InvalidArgumentError(
1319           "Paddings tensor has unexpected shape.");
1320     }
1321     node->operation.attributes = attr;
1322     return absl::OkStatus();
1323   }
1324 
1325  private:
1326   bool mirror_pad_ = false;
1327 };
1328 
1329 class Pooling2DOperationParser : public TFLiteOperationParser {
1330  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1331   absl::Status IsSupported(const TfLiteContext* context,
1332                            const TfLiteNode* tflite_node,
1333                            const TfLiteRegistration* registration) final {
1334     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1335     const TfLitePoolParams* tf_options;
1336     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1337     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1338                                        /*runtime_inputs=*/1,
1339                                        /*outputs=*/1));
1340     RETURN_IF_ERROR(CheckKernelsAndStrides(
1341         tf_options->filter_height, tf_options->filter_width,
1342         tf_options->stride_height, tf_options->stride_width));
1343     return IsActivationSupported(tf_options->activation);
1344   }
1345 
1346  public:
Pooling2DOperationParser(PoolingType type)1347   explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
1348 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1349   absl::Status Parse(const TfLiteNode* tflite_node,
1350                      const TfLiteRegistration* registration,
1351                      GraphFloat32* graph, ObjectReader* reader) final {
1352     Node* node = graph->NewNode();
1353     node->operation.type = ToString(OperationType::POOLING_2D);
1354     RETURN_IF_ERROR(reader->AddInput(node, 0));
1355     RETURN_IF_ERROR(reader->AddOutput(node, 0));
1356 
1357     Pooling2DAttributes attr;
1358     attr.type = type_;
1359 
1360     auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1361 
1362     const TfLitePoolParams* tf_options;
1363     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1364 
1365     RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1366 
1367     attr.output_indices = false;
1368     RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
1369     node->operation.attributes = attr;
1370     return absl::OkStatus();
1371   }
1372 
1373  private:
1374   const PoolingType type_;
1375 };
1376 
1377 class ReduceOperationParser : public TFLiteOperationParser {
1378  public:
ReduceOperationParser(OperationType operation_type)1379   explicit ReduceOperationParser(OperationType operation_type)
1380       : operation_type_(operation_type) {}
1381 
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1382   absl::Status IsSupported(const TfLiteContext* context,
1383                            const TfLiteNode* tflite_node,
1384                            const TfLiteRegistration* registration) final {
1385     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1386                                        /*runtime_inputs=*/1, /*outputs=*/1));
1387     auto* axes = &context->tensors[tflite_node->inputs->data[1]];
1388     if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
1389       return absl::UnimplementedError(
1390           "Reduce has unsupported tensor for axes.");
1391     }
1392     return absl::OkStatus();
1393   }
1394 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1395   absl::Status Parse(const TfLiteNode* tflite_node,
1396                      const TfLiteRegistration* registration,
1397                      GraphFloat32* graph, ObjectReader* reader) final {
1398     Node* node = graph->NewNode();
1399     node->operation.type = ToString(operation_type_);
1400     RETURN_IF_ERROR(reader->AddInput(node, 0));
1401     RETURN_IF_ERROR(reader->AddOutputs(node));
1402 
1403     const TfLiteReducerParams* tf_options;
1404     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1405 
1406     ReduceAttributes attr;
1407     const TfLiteTensor* input = reader->GetInputTensor(0);
1408     const TfLiteTensor* axes = reader->GetInputTensor(1);
1409     for (int i = 0; i < NumElements(axes->dims); i++) {
1410       Axis axis;
1411       RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
1412       attr.dims.insert(axis);
1413     }
1414     node->operation.attributes = attr;
1415     return absl::OkStatus();
1416   }
1417 
1418  private:
1419   const OperationType operation_type_;
1420 };
1421 
1422 class QuantizeOperationParser : public TFLiteOperationParser {
1423  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1424   absl::Status IsSupported(const TfLiteContext* context,
1425                            const TfLiteNode* tflite_node,
1426                            const TfLiteRegistration* registration) final {
1427     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1428     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1429                                        /*runtime_inputs=*/1, /*outputs=*/1));
1430     return absl::OkStatus();
1431   }
1432 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1433   absl::Status Parse(const TfLiteNode* tflite_node,
1434                      const TfLiteRegistration* registration,
1435                      GraphFloat32* graph, ObjectReader* reader) final {
1436     // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
1437     // with floating-point versions of the original tensors.
1438     Node* node = graph->NewNode();
1439     node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
1440     RETURN_IF_ERROR(reader->AddInput(node, 0));
1441     RETURN_IF_ERROR(reader->AddOutputs(node));
1442 
1443     // Quantization attributes should already be present in the output tensor.
1444     auto output_value = graph->FindOutputs(node->id)[0];
1445     if (!output_value->quant_params) {
1446       return absl::InvalidArgumentError(
1447           "Encountered Quantize output with no quant params");
1448     }
1449     QuantizeAndDequantizeAttributes attr;
1450     attr.min = output_value->quant_params.value().min;
1451     attr.max = output_value->quant_params.value().max;
1452     attr.scale = output_value->quant_params.value().scale;
1453 
1454     node->operation.attributes = attr;
1455     return absl::OkStatus();
1456   }
1457 };
1458 
1459 class ReLUOperationParser : public TFLiteOperationParser {
1460  public:
ReLUOperationParser(int clip)1461   explicit ReLUOperationParser(int clip) : clip_(clip) {}
1462 
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1463   absl::Status IsSupported(const TfLiteContext* context,
1464                            const TfLiteNode* tflite_node,
1465                            const TfLiteRegistration* registration) final {
1466     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1467     return absl::OkStatus();
1468   }
1469 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1470   absl::Status Parse(const TfLiteNode* tflite_node,
1471                      const TfLiteRegistration* registration,
1472                      GraphFloat32* graph, ObjectReader* reader) final {
1473     Node* node = graph->NewNode();
1474     node->operation.type = ToString(OperationType::RELU);
1475     RETURN_IF_ERROR(reader->AddInput(node, 0));
1476 
1477     ReLUAttributes attr;
1478     const TfLiteLeakyReluParams* tf_options;
1479     auto status = RetrieveBuiltinData(tflite_node, &tf_options);
1480     attr.alpha = status.ok() ? tf_options->alpha : 0;
1481     attr.clip = clip_;
1482     node->operation.attributes = attr;
1483     return reader->AddOutputs(node);
1484   }
1485 
1486  private:
1487   const int clip_;
1488 };
1489 
1490 class ReshapeOperationParser : public TFLiteOperationParser {
1491  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1492   absl::Status IsSupported(const TfLiteContext* context,
1493                            const TfLiteNode* tflite_node,
1494                            const TfLiteRegistration* registration) final {
1495     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1496     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1497                                        /*runtime_inputs=*/1, /*outputs=*/1));
1498     // TODO(eignasheva): add shape checking
1499     return absl::OkStatus();
1500   }
1501 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1502   absl::Status Parse(const TfLiteNode* tflite_node,
1503                      const TfLiteRegistration* registration,
1504                      GraphFloat32* graph, ObjectReader* reader) final {
1505     Node* node = graph->NewNode();
1506     node->operation.type = ToString(OperationType::RESHAPE);
1507     RETURN_IF_ERROR(reader->AddInput(node, 0));
1508     RETURN_IF_ERROR(reader->AddOutputs(node));
1509     // Here we may have extra inputs. Other tensors were supposed to
1510     // define new shape, but in TFLite these are ignored.
1511     // TODO(akulik): check that shapes match?
1512 
1513     // New shape comes from output shape.
1514     ReshapeAttributes attr;
1515     attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1516     node->operation.attributes = attr;
1517     return absl::OkStatus();
1518   }
1519 };
1520 
1521 class Resize2DOperationParser : public TFLiteOperationParser {
1522  public:
Resize2DOperationParser(SamplingType sampling_type)1523   explicit Resize2DOperationParser(SamplingType sampling_type)
1524       : sampling_type_(sampling_type) {}
1525 
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1526   absl::Status IsSupported(const TfLiteContext* context,
1527                            const TfLiteNode* tflite_node,
1528                            const TfLiteRegistration* registration) final {
1529     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1530     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1531                                        /*runtime_inputs=*/1, /*outputs=*/1));
1532 
1533     bool align_corners;
1534     RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
1535     bool half_pixel_centers;
1536     RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers));
1537     return absl::OkStatus();
1538   }
1539 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1540   absl::Status Parse(const TfLiteNode* tflite_node,
1541                      const TfLiteRegistration* registration,
1542                      GraphFloat32* graph, ObjectReader* reader) final {
1543     Node* node = graph->NewNode();
1544     node->operation.type = ToString(OperationType::RESIZE);
1545     RETURN_IF_ERROR(reader->AddInput(node, 0));
1546     RETURN_IF_ERROR(reader->AddOutputs(node));
1547     // Here we may have extra inputs. Other tensors were supposed to
1548     // define new shape, but in TFLite these are ignored.
1549 
1550     Resize2DAttributes attr;
1551     RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
1552     RETURN_IF_ERROR(
1553         GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
1554     attr.type = sampling_type_;
1555     attr.new_shape.CopyAllDefinedAxis(
1556         graph->FindOutputs(node->id)[0]->tensor.shape);
1557     node->operation.attributes = attr;
1558     return absl::OkStatus();
1559   }
1560 
1561  private:
GetAlignCornersValue(const TfLiteNode * tflite_node,bool * align_corners)1562   absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
1563                                     bool* align_corners) {
1564     switch (sampling_type_) {
1565       case SamplingType::BILINEAR:
1566         return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
1567             tflite_node, align_corners);
1568       case SamplingType::NEAREST:
1569         return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
1570             tflite_node, align_corners);
1571       case SamplingType::UNKNOWN:
1572         return absl::InternalError("Sampling type is not specified");
1573     }
1574     return absl::OkStatus();
1575   }
1576 
1577   template <class T>
GetAlignCornersValueForType(const TfLiteNode * tflite_node,bool * align_corners)1578   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
1579                                            bool* align_corners) {
1580     const T* tf_options;
1581     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1582     *align_corners = tf_options->align_corners;
1583     return absl::OkStatus();
1584   }
1585 
GetHalfPixelCentersValue(const TfLiteNode * tflite_node,bool * half_pixel_centers)1586   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
1587                                         bool* half_pixel_centers) {
1588     if (sampling_type_ == SamplingType::BILINEAR) {
1589       const TfLiteResizeBilinearParams* tf_options;
1590       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1591       if (tf_options->align_corners && tf_options->half_pixel_centers) {
1592         return absl::InternalError(
1593             "If half_pixel_centers is True, align_corners must be False.");
1594       }
1595       *half_pixel_centers = tf_options->half_pixel_centers;
1596     } else {
1597       const TfLiteResizeNearestNeighborParams* tf_options;
1598       RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1599       *half_pixel_centers = tf_options->half_pixel_centers;
1600     }
1601     return absl::OkStatus();
1602   }
1603 
1604   SamplingType sampling_type_ = SamplingType::UNKNOWN;
1605 };
1606 
1607 class SliceOperationParser : public TFLiteOperationParser {
1608  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1609   absl::Status IsSupported(const TfLiteContext* context,
1610                            const TfLiteNode* tflite_node,
1611                            const TfLiteRegistration* registration) final {
1612     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1613     if (tflite_node->inputs->size < 3) {
1614       return absl::UnimplementedError("SLICE requires 3 inputs.");
1615     }
1616     const TfLiteTensor* input = GetInput(context, tflite_node, 0);
1617     if (input->dims->size != 3 && input->dims->size != 4) {
1618       return absl::UnimplementedError(
1619           "SLICE supports for 3 or 4 dimensional tensors only.");
1620     }
1621 
1622     return absl::OkStatus();
1623   }
1624 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1625   absl::Status Parse(const TfLiteNode* tflite_node,
1626                      const TfLiteRegistration* registration,
1627                      GraphFloat32* graph, ObjectReader* reader) final {
1628     Node* node = graph->NewNode();
1629     node->operation.type = ToString(OperationType::SLICE);
1630     RETURN_IF_ERROR(reader->AddOutputs(node));
1631     Value* input;
1632     RETURN_IF_ERROR(reader->ReadValue(0, &input));
1633     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1634 
1635     const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
1636     const int input_dims = tfl_input->dims->size;
1637 
1638     SliceAttributes attr;
1639     attr.strides = BHWC(1, 1, 1, 1);
1640     Tensor<Linear, DataType::INT32> starts, sizes;
1641     RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
1642     RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
1643     if (starts.data.size() != sizes.data.size()) {
1644       return absl::InvalidArgumentError("Starts amount != sizes amount.");
1645     }
1646     BHWC bhwc_starts(0, 0, 0, 0);
1647     BHWC bhwc_sizes = input->tensor.shape;
1648     if (input_dims == 4) {
1649       // input in BHWC layout
1650       if (starts.data.size() == 4) {
1651         bhwc_starts.b = starts.data[0];
1652         bhwc_starts.h = starts.data[1];
1653         bhwc_starts.w = starts.data[2];
1654         bhwc_starts.c = starts.data[3];
1655         bhwc_sizes.b = sizes.data[0];
1656         bhwc_sizes.h = sizes.data[1];
1657         bhwc_sizes.w = sizes.data[2];
1658         bhwc_sizes.c = sizes.data[3];
1659       } else if (starts.data.size() == 3) {
1660         // if input is 4D(BHWC) and args 3D, we assume that args in HWC layout
1661         bhwc_starts.h = starts.data[0];
1662         bhwc_starts.w = starts.data[1];
1663         bhwc_starts.c = starts.data[2];
1664         bhwc_sizes.h = sizes.data[0];
1665         bhwc_sizes.w = sizes.data[1];
1666         bhwc_sizes.c = sizes.data[2];
1667       } else {
1668         return absl::UnimplementedError(
1669             "Slicing is supported for 3 or 4 dimensional tensors only.");
1670       }
1671     } else if (input_dims == 3) {
1672       // input in BWC layout
1673       if (starts.data.size() == 3) {
1674         bhwc_starts.b = starts.data[0];
1675         bhwc_starts.w = starts.data[1];
1676         bhwc_starts.c = starts.data[2];
1677         bhwc_sizes.b = sizes.data[0];
1678         bhwc_sizes.w = sizes.data[1];
1679         bhwc_sizes.c = sizes.data[2];
1680       } else {
1681         return absl::UnimplementedError(
1682             "Slicing is supported for 3 or 4 dimensional tensors only.");
1683       }
1684     } else {
1685       return absl::UnimplementedError(
1686           "Slicing is supported for 3 or 4 dimensional tensors only.");
1687     }
1688     const auto& in_shape = input->tensor.shape;
1689     if (bhwc_sizes.b == -1) {
1690       bhwc_sizes.b = in_shape.b - bhwc_starts.b;
1691     }
1692     if (bhwc_sizes.h == -1) {
1693       bhwc_sizes.h = in_shape.h - bhwc_starts.h;
1694     }
1695     if (bhwc_sizes.w == -1) {
1696       bhwc_sizes.w = in_shape.w - bhwc_starts.w;
1697     }
1698     if (bhwc_sizes.c == -1) {
1699       bhwc_sizes.c = in_shape.c - bhwc_starts.c;
1700     }
1701     attr.starts = bhwc_starts;
1702     attr.ends =
1703         BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
1704              bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
1705     RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
1706 
1707     auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1708     if ((attr.ends.b - attr.starts.b) != out_shape.b) {
1709       return absl::UnimplementedError("Output batch don't match");
1710     }
1711     if ((attr.ends.h - attr.starts.h) != out_shape.h) {
1712       return absl::UnimplementedError("Output height doesn't match");
1713     }
1714     if ((attr.ends.w - attr.starts.w) != out_shape.w) {
1715       return absl::UnimplementedError("Output width doesn't match");
1716     }
1717     if ((attr.ends.c - attr.starts.c) != out_shape.c) {
1718       return absl::UnimplementedError("Output channels don't match");
1719     }
1720     node->operation.attributes = attr;
1721     return absl::OkStatus();
1722   }
1723 
1724  private:
UpdateIfNegative(const BHWC & input_shape,SliceAttributes * attr)1725   absl::Status UpdateIfNegative(const BHWC& input_shape,
1726                                 SliceAttributes* attr) {
1727     if (attr->ends.h < 0) {
1728       attr->ends.h = input_shape.h + attr->ends.h;
1729     }
1730     if (attr->ends.w < 0) {
1731       attr->ends.w = input_shape.w + attr->ends.w;
1732     }
1733     if (attr->ends.c < 0) {
1734       attr->ends.c = input_shape.c + attr->ends.c;
1735     }
1736     if (attr->ends.b < 0) {
1737       attr->ends.b = input_shape.b + attr->ends.b;
1738     }
1739     return absl::OkStatus();
1740   }
1741 };
1742 
1743 class SoftmaxOperationParser : public TFLiteOperationParser {
1744  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1745   absl::Status IsSupported(const TfLiteContext* context,
1746                            const TfLiteNode* tflite_node,
1747                            const TfLiteRegistration* registration) final {
1748     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1749     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1750                                        /*runtime_inputs=*/1, /*outputs=*/1));
1751     const TfLiteSoftmaxParams* tf_options;
1752     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1753     if (tf_options->beta != 1) {
1754       // TODO(eignasheva): figure out, what's wrong with softmax.
1755       return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
1756     }
1757     return absl::OkStatus();
1758   }
1759 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1760   absl::Status Parse(const TfLiteNode* tflite_node,
1761                      const TfLiteRegistration* registration,
1762                      GraphFloat32* graph, ObjectReader* reader) final {
1763     Node* node = graph->NewNode();
1764     node->operation.type = ToString(OperationType::SOFTMAX);
1765     RETURN_IF_ERROR(reader->AddInput(node, 0));
1766     RETURN_IF_ERROR(reader->AddOutputs(node));
1767 
1768     const TfLiteSoftmaxParams* tf_options;
1769     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1770     if (tf_options->beta != 1) {
1771       // there is multiply by scalar operation fused in softmax. Make a layer
1772       // out of it before softmax.
1773       return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
1774       // auto mul_node = reader->NewPassthroughNode(node);
1775       // mul_node->operation.type = ToString(OperationType::MUL);
1776     }
1777     SoftmaxAttributes attr;
1778     attr.axis = Axis::CHANNELS;  // always by channels
1779     node->operation.attributes = attr;
1780     return absl::OkStatus();
1781   }
1782 };
1783 
1784 class SpaceToDepthOperationParser : public TFLiteOperationParser {
1785  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1786   absl::Status IsSupported(const TfLiteContext* context,
1787                            const TfLiteNode* tflite_node,
1788                            const TfLiteRegistration* registration) final {
1789     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1790     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1791                                        /*runtime_inputs=*/1, /*outputs=*/1));
1792     // TODO(impjdi): Dims check.
1793     const TfLiteSpaceToDepthParams* s2d_params;
1794     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
1795     if (s2d_params->block_size == 1) {
1796       return absl::InvalidArgumentError(
1797           "SPACE_TO_DEPTH block_size = 1 is a no-op.");
1798     }
1799     if (s2d_params->block_size < 1) {
1800       return absl::InvalidArgumentError(
1801           "SPACE_TO_DEPTH block_size must be > 1.");
1802     }
1803     return absl::OkStatus();
1804   }
1805 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1806   absl::Status Parse(const TfLiteNode* tflite_node,
1807                      const TfLiteRegistration* registration,
1808                      GraphFloat32* graph, ObjectReader* reader) final {
1809     Node* node = graph->NewNode();
1810     node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
1811     RETURN_IF_ERROR(reader->AddInput(node, 0));
1812     RETURN_IF_ERROR(reader->AddOutputs(node));
1813     const TfLiteSpaceToDepthParams* tf_options;
1814     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1815     SpaceToDepthAttributes attr;
1816     attr.block_size = tf_options->block_size;
1817     node->operation.attributes = attr;
1818     return absl::OkStatus();
1819   }
1820 };
1821 
1822 class SplitVOperationParser : public TFLiteOperationParser {
1823  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1824   absl::Status IsSupported(const TfLiteContext* context,
1825                            const TfLiteNode* tflite_node,
1826                            const TfLiteRegistration* registration) final {
1827     const TfLiteSplitVParams* split_params;
1828     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
1829     if (split_params->num_splits == 1) {
1830       return absl::InvalidArgumentError(
1831           "SplitV with num_splits = 1 is a no-op.");
1832     }
1833     return absl::OkStatus();
1834   }
1835 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1836   absl::Status Parse(const TfLiteNode* tflite_node,
1837                      const TfLiteRegistration* registration,
1838                      GraphFloat32* graph, ObjectReader* reader) final {
1839     const TfLiteTensor* input = reader->GetInputTensor(0);
1840     const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
1841     SplitAttributes attr;
1842     RETURN_IF_ERROR(
1843         ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
1844 
1845     Node* node = graph->NewNode();
1846     node->operation.type = ToString(OperationType::SPLIT);
1847     node->operation.attributes = attr;
1848     RETURN_IF_ERROR(reader->AddInput(node, 0));
1849     for (int i = 0; i < tflite_node->outputs->size; ++i) {
1850       RETURN_IF_ERROR(reader->AddOutput(node, i));
1851     }
1852     return absl::OkStatus();
1853   }
1854 };
1855 
1856 class StridedSliceOperationParser : public TFLiteOperationParser {
1857  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1858   absl::Status IsSupported(const TfLiteContext* context,
1859                            const TfLiteNode* tflite_node,
1860                            const TfLiteRegistration* registration) final {
1861     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1862     const TfLiteStridedSliceParams* tf_options;
1863     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1864     RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
1865 
1866     if (tflite_node->inputs->size < 4) {
1867       return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs.");
1868     }
1869     const TfLiteTensor* input = GetInput(context, tflite_node, 0);
1870     if (input->dims->size != 3 && input->dims->size != 4) {
1871       return absl::UnimplementedError(
1872           "STRIDED_SLICE supports for 3 or 4 dimensional tensors only.");
1873     }
1874     return absl::OkStatus();
1875   }
1876 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1877   absl::Status Parse(const TfLiteNode* tflite_node,
1878                      const TfLiteRegistration* registration,
1879                      GraphFloat32* graph, ObjectReader* reader) final {
1880     Node* node = graph->NewNode();
1881     node->operation.type = ToString(OperationType::SLICE);
1882     RETURN_IF_ERROR(reader->AddOutputs(node));
1883     Value* input;
1884     RETURN_IF_ERROR(reader->ReadValue(0, &input));
1885     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1886 
1887     Tensor<Linear, DataType::INT32> tmp;
1888     RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
1889 
1890     bool read_without_batch = tmp.data.size() == 3;
1891     bool read_with_batch = tmp.data.size() == 4;
1892     if (!read_without_batch && !read_with_batch) {
1893       // Error: Must be catched in IsSupported()
1894       return absl::UnimplementedError(
1895           "Slicing is supported for 3 or 4 dimensional tensors only.");
1896     }
1897 
1898     const TfLiteStridedSliceParams* tf_options;
1899     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1900     RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
1901 
1902     auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1903 
1904     SliceAttributes attr;
1905     if (read_without_batch) {
1906       RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
1907                                               input->tensor.shape, &attr));
1908     }
1909     if (read_with_batch) {
1910       RETURN_IF_ERROR(
1911           ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
1912     }
1913     if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
1914         attr.strides.c == 0) {
1915       return absl::InvalidArgumentError("stride values must be non-zero");
1916     }
1917     if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
1918         attr.strides.c < 0) {
1919       return absl::UnimplementedError("Reverse slices are not supported.");
1920     }
1921     if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
1922         out_shape.b) {
1923       return absl::UnimplementedError("Output batch don't match");
1924     }
1925     if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
1926         out_shape.h) {
1927       return absl::UnimplementedError("Output height doesn't match");
1928     }
1929     if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
1930         out_shape.w) {
1931       return absl::UnimplementedError("Output width doesn't match");
1932     }
1933     if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
1934         out_shape.c) {
1935       return absl::UnimplementedError("Output channels don't match");
1936     }
1937     node->operation.attributes = attr;
1938     return absl::OkStatus();
1939   }
1940 
1941  private:
UpdateWithMask(const TfLiteStridedSliceParams * tf_options,const BHWC & input_shape,int ignore_b,int ignore_h,int ignore_w,int ignore_c,SliceAttributes * attr)1942   absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
1943                               const BHWC& input_shape, int ignore_b,
1944                               int ignore_h, int ignore_w, int ignore_c,
1945                               SliceAttributes* attr) {
1946     if (tf_options->begin_mask & ignore_h) {
1947       attr->starts.h = 0;
1948     }
1949     if (tf_options->begin_mask & ignore_w) {
1950       attr->starts.w = 0;
1951     }
1952     if (tf_options->begin_mask & ignore_c) {
1953       attr->starts.c = 0;
1954     }
1955     if (tf_options->begin_mask & ignore_b) {
1956       attr->starts.b = 0;
1957     }
1958 
1959     if (tf_options->end_mask & ignore_h) {
1960       attr->ends.h = input_shape.h;
1961     }
1962     if (tf_options->end_mask & ignore_w) {
1963       attr->ends.w = input_shape.w;
1964     }
1965     if (tf_options->end_mask & ignore_c) {
1966       attr->ends.c = input_shape.c;
1967     }
1968     if (tf_options->end_mask & ignore_b) {
1969       attr->ends.b = input_shape.b;
1970     }
1971     return absl::OkStatus();
1972   }
1973 
UpdateIfNegative(const BHWC & input_shape,SliceAttributes * attr)1974   absl::Status UpdateIfNegative(const BHWC& input_shape,
1975                                 SliceAttributes* attr) {
1976     if (attr->ends.h < 0) {
1977       attr->ends.h = input_shape.h + attr->ends.h;
1978     }
1979     if (attr->ends.w < 0) {
1980       attr->ends.w = input_shape.w + attr->ends.w;
1981     }
1982     if (attr->ends.c < 0) {
1983       attr->ends.c = input_shape.c + attr->ends.c;
1984     }
1985     if (attr->ends.b < 0) {
1986       attr->ends.b = input_shape.b + attr->ends.b;
1987     }
1988 
1989     if (attr->starts.h < 0) {
1990       attr->starts.h = input_shape.h + attr->starts.h;
1991     }
1992     if (attr->starts.w < 0) {
1993       attr->starts.w = input_shape.w + attr->starts.w;
1994     }
1995     if (attr->starts.c < 0) {
1996       attr->starts.c = input_shape.c + attr->starts.c;
1997     }
1998     if (attr->starts.b < 0) {
1999       attr->starts.b = input_shape.b + attr->starts.b;
2000     }
2001 
2002     return absl::OkStatus();
2003   }
2004 
ReadAttribsWithBatch(const ObjectReader * reader,const TfLiteStridedSliceParams * tf_options,const BHWC & input_shape,SliceAttributes * attr)2005   absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
2006                                     const TfLiteStridedSliceParams* tf_options,
2007                                     const BHWC& input_shape,
2008                                     SliceAttributes* attr) {
2009     auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2010       Tensor<Linear, DataType::INT32> t;
2011       RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2012       *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
2013       return absl::OkStatus();
2014     };
2015 
2016     RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
2017     RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
2018     RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
2019     RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
2020     RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
2021     return absl::OkStatus();
2022   }
2023 
ReadAttribsWithoutBatch(const ObjectReader * reader,const TfLiteStridedSliceParams * tf_options,const BHWC & input_shape,SliceAttributes * attr)2024   absl::Status ReadAttribsWithoutBatch(
2025       const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
2026       const BHWC& input_shape, SliceAttributes* attr) {
2027     auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2028       Tensor<Linear, DataType::INT32> t;
2029       RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2030       *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
2031       return absl::OkStatus();
2032     };
2033 
2034     RETURN_IF_ERROR(read_hwc(1, &attr->starts));
2035     RETURN_IF_ERROR(read_hwc(2, &attr->ends));
2036     RETURN_IF_ERROR(read_hwc(3, &attr->strides));
2037     RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
2038     RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
2039     attr->starts.b = 0;
2040     attr->ends.b = input_shape.b;
2041     attr->strides.b = 1;
2042     return absl::OkStatus();
2043   }
CheckOptionsSupport(const TfLiteStridedSliceParams * tf_options)2044   absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
2045     if (tf_options->ellipsis_mask) {
2046       return absl::UnimplementedError("Slice does not support ellipsis_mask.");
2047     }
2048     if (tf_options->new_axis_mask) {
2049       return absl::UnimplementedError("Slice does not support new_axis_mask.");
2050     }
2051     if (tf_options->shrink_axis_mask) {
2052       return absl::UnimplementedError(
2053           "Slice does not support shrink_axis_mask parameter. ");
2054     }
2055     return absl::OkStatus();
2056   }
2057 };
2058 
2059 // Builtin op version of TRANSPOSE_CONV.
2060 class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
2061  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)2062   absl::Status IsSupported(const TfLiteContext* context,
2063                            const TfLiteNode* tflite_node,
2064                            const TfLiteRegistration* registration) final {
2065     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
2066     const int runtime_inputs =
2067         GetNumberOfRuntimeInputsForNode(context, tflite_node);
2068     if (runtime_inputs > 2) {
2069       return absl::InternalError(
2070           absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
2071                        runtime_inputs, " runtime inputs."));
2072     }
2073     const int runtime_outputs = NumOutputs(tflite_node);
2074     if (runtime_outputs != 1) {
2075       return absl::InternalError(
2076           absl::StrCat("Expected 1 output tensor(s), but node has ",
2077                        runtime_outputs, " runtime outputs."));
2078     }
2079     if (runtime_inputs == 1) {
2080       RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
2081     }
2082     const TfLiteTransposeConvParams* tf_options;
2083     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2084     RETURN_IF_ERROR(
2085         CheckStrides(tf_options->stride_height, tf_options->stride_width));
2086     return absl::OkStatus();
2087   }
2088 
2089   // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights,
2090   // input, and an optional bias) and allows configurable padding & stride.
2091   // TODO(impjdi): Translate output_shape to attr.adjacent.
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2092   absl::Status Parse(const TfLiteNode* tflite_node,
2093                      const TfLiteRegistration* registration,
2094                      GraphFloat32* graph, ObjectReader* reader) final {
2095     auto* node = graph->NewNode();
2096     node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2097     Value* input;
2098     RETURN_IF_ERROR(reader->ReadValue(2, &input));
2099     RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2100     RETURN_IF_ERROR(reader->AddOutputs(node));
2101 
2102     const TfLiteTransposeConvParams* tf_options;
2103     RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2104 
2105     ConvolutionTransposedAttributes attr;
2106     attr.stride = tf_options
2107                       ? HW(tf_options->stride_height, tf_options->stride_width)
2108                       : HW(1, 1);
2109     const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
2110     if (runtime_inputs == 2) {
2111       RETURN_IF_ERROR(reader->AddInput(node, 1));
2112       auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
2113       attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
2114                                 weights_shape.w, weights_shape.c);
2115     } else {  // runtime_inputs == 1;
2116       RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2117     }
2118     reader->ReadTensor(3, &attr.bias).IgnoreError();  // bias is optional
2119 
2120     UpdatePadding(tf_options->padding,
2121                   graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2122     node->operation.attributes = std::move(attr);
2123     return absl::OkStatus();
2124   }
2125 };
2126 
2127 class TransposeOperationParser : public TFLiteOperationParser {
2128  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)2129   absl::Status IsSupported(const TfLiteContext* context,
2130                            const TfLiteNode* tflite_node,
2131                            const TfLiteRegistration* registration) final {
2132     RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2133     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
2134                                        /*runtime_inputs=*/1, /*outputs=*/1));
2135     return absl::OkStatus();
2136   }
2137 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2138   absl::Status Parse(const TfLiteNode* tflite_node,
2139                      const TfLiteRegistration* registration,
2140                      GraphFloat32* graph, ObjectReader* reader) final {
2141     Node* node = graph->NewNode();
2142     node->operation.type = ToString(OperationType::TRANSPOSE);
2143     RETURN_IF_ERROR(reader->AddInput(node, 0));
2144     RETURN_IF_ERROR(reader->AddOutputs(node));
2145 
2146     TransposeAttributes attr;
2147     Tensor<Linear, DataType::INT32> perm;
2148     RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
2149     std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0},
2150                                          {Axis::HEIGHT, 1},
2151                                          {Axis::WIDTH, 2},
2152                                          {Axis::CHANNELS, 3}};
2153     if (perm.data.size() == 4) {
2154       attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
2155     } else if (perm.data.size() == 3) {
2156       std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH,
2157                                          Axis::CHANNELS};
2158       attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2159       attr.perm.h = 1;
2160       attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]];
2161       attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]];
2162     } else if (perm.data.size() == 2) {
2163       std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS};
2164       attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2165       attr.perm.h = 1;
2166       attr.perm.w = 2;
2167       attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]];
2168     } else {
2169       return absl::InvalidArgumentError(
2170           "Permutation for transpose is invalid.");
2171     }
2172 
2173     node->operation.attributes = attr;
2174     return absl::OkStatus();
2175   }
2176 };
2177 
2178 // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
2179 class BatchToSpaceOperationParser : public TFLiteOperationParser {
2180  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)2181   absl::Status IsSupported(const TfLiteContext* context,
2182                            const TfLiteNode* tflite_node,
2183                            const TfLiteRegistration* registration) final {
2184     return absl::OkStatus();
2185   }
2186 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2187   absl::Status Parse(const TfLiteNode* tflite_node,
2188                      const TfLiteRegistration* registration,
2189                      GraphFloat32* graph, ObjectReader* reader) final {
2190     auto* node = graph->NewNode();
2191     node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
2192     RETURN_IF_ERROR(reader->AddInput(node, 0));
2193     RETURN_IF_ERROR(reader->AddOutputs(node));
2194 
2195     BatchToSpaceAttributes bs_attr;
2196     Tensor<Linear, DataType::INT32> block;
2197     RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2198     if (block.shape.v != 2) {
2199       return absl::InternalError("Space has to be HxW.");
2200     }
2201     bs_attr.block.h = block.data[0];
2202     bs_attr.block.w = block.data[1];
2203 
2204     Tensor<HW, DataType::INT32> crop;
2205     RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
2206     auto crop_shape = crop.shape;
2207     if (crop_shape.h != 2 && crop_shape.w != 2) {
2208       return absl::InternalError("Space has to be HxW.");
2209     }
2210 
2211     bs_attr.crop.prepended.h = crop.data[0];
2212     bs_attr.crop.prepended.w = crop.data[2];
2213 
2214     bs_attr.crop.appended.h = crop.data[1];
2215     bs_attr.crop.appended.w = crop.data[3];
2216 
2217     node->operation.attributes = std::move(bs_attr);
2218     return absl::OkStatus();
2219   }
2220 };
2221 
2222 class SpaceToBatchOperationParser : public TFLiteOperationParser {
2223  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)2224   absl::Status IsSupported(const TfLiteContext* context,
2225                            const TfLiteNode* tflite_node,
2226                            const TfLiteRegistration* registration) final {
2227     return absl::OkStatus();
2228   }
2229 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2230   absl::Status Parse(const TfLiteNode* tflite_node,
2231                      const TfLiteRegistration* registration,
2232                      GraphFloat32* graph, ObjectReader* reader) final {
2233     auto* node = graph->NewNode();
2234     node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
2235     RETURN_IF_ERROR(reader->AddInput(node, 0));
2236     RETURN_IF_ERROR(reader->AddOutputs(node));
2237     SpaceToBatchAttributes sb_attr;
2238     Tensor<Linear, DataType::INT32> block;
2239     RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2240     if (block.shape.v != 2) {
2241       return absl::InternalError("Space has to be HxW.");
2242     }
2243     sb_attr.block.h = block.data[0];
2244     sb_attr.block.w = block.data[1];
2245 
2246     Tensor<HW, DataType::INT32> padding;
2247     RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
2248     auto padding_shape = padding.shape;
2249 
2250     if (padding_shape.h != 2 && padding_shape.w != 2) {
2251       return absl::InternalError("Space has to be HxW.");
2252     }
2253 
2254     sb_attr.padding.prepended.h = padding.data[0];
2255     sb_attr.padding.prepended.w = padding.data[2];
2256 
2257     sb_attr.padding.appended.h = padding.data[1];
2258     sb_attr.padding.appended.w = padding.data[3];
2259 
2260     node->operation.attributes = std::move(sb_attr);
2261     return absl::OkStatus();
2262   }
2263 };
2264 
2265 class MeanOperationParser : public TFLiteOperationParser {
2266  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)2267   absl::Status IsSupported(const TfLiteContext* context,
2268                            const TfLiteNode* tflite_node,
2269                            const TfLiteRegistration* registration) final {
2270     RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
2271                                        /*runtime_inputs=*/1,
2272                                        /*outputs=*/1));
2273 
2274     auto* axes = &context->tensors[tflite_node->inputs->data[1]];
2275     if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
2276       return absl::UnimplementedError("Mean has unsupported tensor for axes");
2277     }
2278     return absl::OkStatus();
2279   }
2280 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2281   absl::Status Parse(const TfLiteNode* tflite_node,
2282                      const TfLiteRegistration* registration,
2283                      GraphFloat32* graph, ObjectReader* reader) final {
2284     auto* node = graph->NewNode();
2285     node->operation.type = ToString(OperationType::MEAN);
2286     RETURN_IF_ERROR(reader->AddInput(node, 0));
2287     RETURN_IF_ERROR(reader->AddOutputs(node));
2288 
2289     MeanAttributes attr;
2290     const TfLiteTensor* input = reader->GetInputTensor(0);
2291     const TfLiteTensor* axes = reader->GetInputTensor(1);
2292     for (int i = 0; i < NumElements(axes->dims); i++) {
2293       Axis axis;
2294       RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
2295       attr.dims.insert(axis);
2296     }
2297     node->operation.attributes = attr;
2298     return absl::OkStatus();
2299   }
2300 };
2301 
2302 class UnsupportedOperationParser : public TFLiteOperationParser {
2303  public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)2304   absl::Status IsSupported(const TfLiteContext* context,
2305                            const TfLiteNode* tflite_node,
2306                            const TfLiteRegistration* registration) final {
2307     return absl::UnimplementedError("Operation is not supported.");
2308   }
2309 
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)2310   absl::Status Parse(const TfLiteNode* tflite_node,
2311                      const TfLiteRegistration* registration,
2312                      GraphFloat32* graph, ObjectReader* reader) final {
2313     return absl::UnimplementedError("Operation is not supported.");
2314   }
2315 };
2316 
NewOperationParser(const TfLiteRegistration * registration,bool allow_quant_ops=false)2317 std::unique_ptr<TFLiteOperationParser> NewOperationParser(
2318     const TfLiteRegistration* registration, bool allow_quant_ops = false) {
2319   const auto builtin_code = registration->builtin_code;
2320   switch (builtin_code) {
2321     case kTfLiteBuiltinAbs:
2322       return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
2323     case kTfLiteBuiltinAdd:
2324       return std::make_unique<AddOperationParser>();
2325     case kTfLiteBuiltinAveragePool2d:
2326       return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
2327     case kTfLiteBuiltinBatchMatmul:
2328       return std::make_unique<BatchedMatMulOperationParser>();
2329     case kTfLiteBuiltinConcatenation:
2330       return std::make_unique<ConcatenationOperationParser>();
2331     case kTfLiteBuiltinConv2d:
2332       return std::make_unique<Conv2DOperationParser>();
2333     case kTfLiteBuiltinCos:
2334       return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
2335     case kTfLiteBuiltinDepthwiseConv2d:
2336       return std::make_unique<DepthwiseConvolutionOperationParser>();
2337     case kTfLiteBuiltinDequantize:
2338       if (allow_quant_ops) {
2339         return std::make_unique<DequantizeOperationParser>();
2340       }
2341       break;
2342     case kTfLiteBuiltinDiv:
2343       return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
2344     case kTfLiteBuiltinElu:
2345       return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
2346     case kTfLiteBuiltinExp:
2347       return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
2348     case kTfLiteBuiltinFullyConnected:
2349       return std::make_unique<FullyConnectedOperationParser>();
2350     case kTfLiteBuiltinHardSwish:
2351       return std::make_unique<HardSwishOperationParser>();
2352     case kTfLiteBuiltinLogistic:
2353       return std::make_unique<ElementwiseOperationParser>(
2354           OperationType::SIGMOID);
2355     case kTfLiteBuiltinLog:
2356       return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
2357     case kTfLiteBuiltinLstm:
2358       return std::make_unique<LSTMOperationParser>();
2359     case kTfLiteBuiltinMaximum:
2360       return std::make_unique<ElementwiseOperationParser>(
2361           OperationType::MAXIMUM);
2362     case kTfLiteBuiltinMaxPool2d:
2363       return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
2364     case kTfLiteBuiltinMean:
2365       return std::make_unique<MeanOperationParser>();
2366     case kTfLiteBuiltinMinimum:
2367       return std::make_unique<ElementwiseOperationParser>(
2368           OperationType::MINIMUM);
2369     case kTfLiteBuiltinMirrorPad:
2370       return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
2371     case kTfLiteBuiltinMul:
2372       return std::make_unique<MulOperationParser>();
2373     case kTfLiteBuiltinNeg:
2374       return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
2375     case kTfLiteBuiltinPack:
2376       return std::make_unique<PackOperationParser>();
2377     case kTfLiteBuiltinPad:
2378       return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
2379     case kTfLiteBuiltinPow:
2380       return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
2381     case kTfLiteBuiltinReduceMax:
2382       return std::make_unique<ReduceOperationParser>(
2383           OperationType::REDUCE_MAXIMUM);
2384     case kTfLiteBuiltinReduceMin:
2385       return std::make_unique<ReduceOperationParser>(
2386           OperationType::REDUCE_MINIMUM);
2387     case kTfLiteBuiltinReduceProd:
2388       return std::make_unique<ReduceOperationParser>(
2389           OperationType::REDUCE_PRODUCT);
2390     case kTfLiteBuiltinQuantize:
2391       if (allow_quant_ops) {
2392         return std::make_unique<QuantizeOperationParser>();
2393       }
2394       break;
2395     case kTfLiteBuiltinRelu:
2396       return std::make_unique<ReLUOperationParser>(0);
2397     case kTfLiteBuiltinRelu6:
2398       return std::make_unique<ReLUOperationParser>(6);
2399     case kTfLiteBuiltinLeakyRelu:
2400       return std::make_unique<ReLUOperationParser>(0);
2401     case kTfLiteBuiltinPrelu:
2402       return std::make_unique<PReLUOperationParser>();
2403     case kTfLiteBuiltinReshape:
2404       return std::make_unique<ReshapeOperationParser>();
2405     case kTfLiteBuiltinResizeBilinear:
2406       return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
2407     case kTfLiteBuiltinResizeNearestNeighbor:
2408       return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
2409     case kTfLiteBuiltinRsqrt:
2410       return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
2411     case kTfLiteBuiltinSin:
2412       return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
2413     case kTfLiteBuiltinSlice:
2414       return std::make_unique<SliceOperationParser>();
2415     case kTfLiteBuiltinSoftmax:
2416       return std::make_unique<SoftmaxOperationParser>();
2417     case kTfLiteBuiltinSpaceToDepth:
2418       return std::make_unique<SpaceToDepthOperationParser>();
2419     case kTfLiteBuiltinSplitV:
2420       return std::make_unique<SplitVOperationParser>();
2421     case kTfLiteBuiltinSqrt:
2422       return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
2423     case kTfLiteBuiltinSquare:
2424       return std::make_unique<ElementwiseOperationParser>(
2425           OperationType::SQUARE);
2426     case kTfLiteBuiltinSquaredDifference:
2427       return std::make_unique<ElementwiseOperationParser>(
2428           OperationType::SQUARED_DIFF);
2429     case kTfLiteBuiltinStridedSlice:
2430       return std::make_unique<StridedSliceOperationParser>();
2431     case kTfLiteBuiltinSub:
2432       return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
2433     case kTfLiteBuiltinSum:
2434       return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
2435     case kTfLiteBuiltinTanh:
2436       return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
2437     case kTfLiteBuiltinTranspose:
2438       return std::make_unique<TransposeOperationParser>();
2439     case kTfLiteBuiltinTransposeConv:
2440       return std::make_unique<TransposeConvBuiltinOperationParser>();
2441     case kTfLiteBuiltinCustom:
2442       return NewCustomOperationParser(registration->custom_name);
2443   }
2444   return std::make_unique<UnsupportedOperationParser>();
2445 }
2446 
IsSupported(const TfLiteContext * context,TfLiteNode * node,const TfLiteRegistration * registration,bool allow_quant_ops=false)2447 absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
2448                          const TfLiteRegistration* registration,
2449                          bool allow_quant_ops = false) {
2450   return NewOperationParser(registration, allow_quant_ops)
2451       ->IsSupported(context, node, registration);
2452 }
2453 
IsAllAllowedTensors(TfLiteContext * context,const TfLiteIntArray * tensor_indices,bool allow_quant_ops=false)2454 bool IsAllAllowedTensors(TfLiteContext* context,
2455                          const TfLiteIntArray* tensor_indices,
2456                          bool allow_quant_ops = false) {
2457   for (int i = 0; i < tensor_indices->size; ++i) {
2458     int tensor_idx = tensor_indices->data[i];
2459     if (tensor_idx == kTfLiteOptionalTensor) continue;
2460     const TfLiteTensor* t = &context->tensors[tensor_idx];
2461     bool type_supported =
2462         (t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16);
2463     if (allow_quant_ops) {
2464       // Since we only check non-constant tensors, type cannot be Int32.
2465       type_supported =
2466           type_supported || t->type == kTfLiteInt8 || t->type == kTfLiteUInt8;
2467     }
2468     if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
2469       return false;
2470     }
2471   }
2472   return true;
2473 }
2474 }  // namespace
2475 
2476 // TODO(impjdi): Check number of input/output tensors and their dimensions.
2477 // TODO(impjdi): Check ops' parameters.
GetOpsToReplace(TfLiteContext * context,bool allow_quant_ops,int max_delegated_partitions)2478 TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
2479                                 int max_delegated_partitions) {
2480   delegates::IsNodeSupportedFn node_supported_fn =
2481       [=](TfLiteContext* context, TfLiteNode* node,
2482           TfLiteRegistration* registration,
2483           std::string* unsupported_details) -> bool {
2484     const auto status =
2485         IsSupported(context, node, registration, allow_quant_ops);
2486     if (!status.ok()) {
2487       if (unsupported_details) {
2488         *unsupported_details = std::string(status.message());
2489       }
2490       return false;
2491     }
2492 
2493     if (!IsAllAllowedTensors(context, node->inputs, allow_quant_ops) ||
2494         !IsAllAllowedTensors(context, node->outputs, allow_quant_ops)) {
2495       if (unsupported_details) {
2496         *unsupported_details =
2497             "OP is supported, but tensor type isn't matched!";
2498       }
2499       return false;
2500     }
2501     return true;
2502   };
2503 
2504   delegates::FP16GraphPartitionHelper partition_helper(context,
2505                                                        node_supported_fn);
2506   std::set<std::string> unsupported_nodes_info;
2507   if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
2508     return TfLiteIntArrayCreate(0);
2509   }
2510 
2511   // By default, we simply get 1st largest partition as 'max_delegate_partions'
2512   // is set to 1 by default.
2513   std::vector<int> ops_to_replace =
2514       partition_helper.GetNodesOfFirstNLargestPartitions(
2515           max_delegated_partitions);
2516 
2517   if (!unsupported_nodes_info.empty()) {
2518     std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
2519     std::string error_message = absl::StrCat(
2520         "Following operations are not supported by GPU delegate:\n",
2521         unsupported, "\n");
2522     if (!ops_to_replace.empty()) {
2523       absl::StrAppend(
2524           &error_message, ops_to_replace.size(),
2525           " operations will run on the GPU, and the remaining ",
2526           partition_helper.num_total_nodes() - ops_to_replace.size());
2527     } else {
2528       absl::StrAppend(&error_message,
2529                       "No operations will run on the GPU, and all ",
2530                       partition_helper.num_total_nodes());
2531     }
2532     absl::StrAppend(&error_message, " operations will run on the CPU.");
2533     TF_LITE_KERNEL_LOG(context, error_message.c_str());
2534   }
2535   return ConvertVectorToTfLiteIntArray(ops_to_replace);
2536 }
2537 
2538 // Creates inputs and outputs passed by io_tensors parameters in the resulting
2539 // graph. We force it to make sure that delegated subgraph has same order of
2540 // inputs and outputs with the original one. When delegated model is built from
2541 // the tflite model representation tensors are created lazily, so there is no
2542 // guarantee that the order will match the source model tensors order.
PrecreateIOTensors(TfLiteContext * context,GraphFloat32 * graph,TfLiteIntArray * io_tensors,absl::flat_hash_map<int,int> * quant_conversion_map,absl::flat_hash_map<int,Value * > * tensor_to_value)2543 absl::Status PrecreateIOTensors(
2544     TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
2545     absl::flat_hash_map<int, int>* quant_conversion_map,
2546     absl::flat_hash_map<int, Value*>* tensor_to_value) {
2547   for (int i = 0; i < io_tensors->size; ++i) {
2548     const int tensor_index = io_tensors->data[i];
2549     const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
2550     if (tflite::IsConstantTensor(&tflite_tensor)) continue;
2551     RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
2552         context, tensor_to_value, quant_conversion_map, graph, tensor_index));
2553   }
2554   return absl::OkStatus();
2555 }
2556 
CopyVariableTensorOutputs(TfLiteNode * tflite_node,TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader & reader,const absl::flat_hash_map<int,ValueId> & new_variable_tensor_values)2557 absl::Status CopyVariableTensorOutputs(
2558     TfLiteNode* tflite_node, TfLiteRegistration* registration,
2559     GraphFloat32* graph, ObjectReader& reader,
2560     const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
2561   absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
2562       new_variable_tensor_values);
2563   // Retrieve the final value id for the variable input tensors.
2564   for (int i = 0; i < tflite_node->inputs->size; i++) {
2565     int tensor_idx = tflite_node->inputs->data[i];
2566     Value* value;
2567     if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
2568     if (value->tensor.is_variable_input) {
2569       if (new_variable_tensor_values_copy.find(i) ==
2570           new_variable_tensor_values_copy.end()) {
2571         return absl::InvalidArgumentError(
2572             absl::StrCat(GetOpNameByRegistration(*registration),
2573                          " did not provide a new value for the variable input "
2574                          "tensor with index ",
2575                          tensor_idx));
2576       } else {
2577         Node* node = graph->NewNode();
2578         node->operation.type = ToString(OperationType::COPY);
2579         RETURN_IF_ERROR(graph->AddConsumer(
2580             node->id, new_variable_tensor_values_copy.at(i)));
2581         RETURN_IF_ERROR(reader.AddUpdate(node, i));
2582         new_variable_tensor_values_copy.erase(
2583             new_variable_tensor_values_copy.find(i));
2584       }
2585     }
2586   }
2587   if (!new_variable_tensor_values_copy.empty()) {
2588     return absl::InvalidArgumentError(
2589         "More input variable tensors asked to be copied than present on the "
2590         "node");
2591   }
2592   return absl::OkStatus();
2593 }
2594 
BuildModel(TfLiteContext * context,const TfLiteDelegateParams * delegate_params,GraphFloat32 * graph,absl::flat_hash_map<int,int> * quant_conversion_map)2595 absl::Status BuildModel(TfLiteContext* context,
2596                         const TfLiteDelegateParams* delegate_params,
2597                         GraphFloat32* graph,
2598                         absl::flat_hash_map<int, int>* quant_conversion_map) {
2599   std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
2600   std::vector<int> tflite_nodes;
2601   for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
2602     TfLiteNode* tflite_node = nullptr;
2603     TfLiteRegistration* registration = nullptr;
2604     RETURN_IF_ERROR(GetNodeAndRegistration(
2605         context, delegate_params->nodes_to_replace->data[i], &tflite_node,
2606         &registration));
2607     if (registration->builtin_code == kTfLiteBuiltinDequantize &&
2608         context->tensors[tflite_node->inputs->data[0]].type ==
2609             TfLiteType::kTfLiteFloat16) {
2610       // Ignore Fp16 Dequantize nodes.
2611       continue;
2612     }
2613     auto op_parser = NewOperationParser(
2614         registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
2615     if (!op_parser) {
2616       return absl::UnimplementedError(
2617           absl::StrCat("Operation ", registration->builtin_code, "(",
2618                        registration->custom_name,
2619                        ") is not supported by TFLite GPU Delegate."));
2620     }
2621     operations.push_back(std::move(op_parser));
2622     tflite_nodes.push_back(i);
2623   }
2624   absl::flat_hash_map<int, Value*> tensor_to_value;
2625   std::vector<ValueId> variable_inputs_to_value_id;
2626   RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
2627                                      delegate_params->input_tensors,
2628                                      quant_conversion_map, &tensor_to_value));
2629   RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
2630                                      delegate_params->output_tensors,
2631                                      quant_conversion_map, &tensor_to_value));
2632   for (int i = 0; i < operations.size(); ++i) {
2633     TfLiteNode* tflite_node;
2634     TfLiteRegistration* registration;
2635     RETURN_IF_ERROR(GetNodeAndRegistration(
2636         context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
2637         &tflite_node, &registration));
2638     ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
2639                         quant_conversion_map);
2640     const auto status =
2641         operations[i]->Parse(tflite_node, registration, graph, &reader);
2642     if (!status.ok()) {
2643       return absl::InternalError(absl::StrCat(
2644           GetOpNameByRegistration(*registration), ": ", status.message()));
2645     }
2646 
2647     absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
2648         operations[i]->GetNewValueIdsForVariableInputNodes();
2649 
2650     RETURN_IF_ERROR(
2651         CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
2652                                   new_value_for_variable_input_tensors));
2653   }
2654 
2655   // Variable input tensors expect to be unchanged throughout model execution.
2656   // They need to be an output of the graph in order to have them unchanged.
2657   for (auto value_id : variable_inputs_to_value_id) {
2658     if (!graph->IsGraphOutput(value_id)) {
2659       return absl::InvalidArgumentError(
2660           absl::StrCat("Variable input tensors must be a graph output. Value ",
2661                        value_id, " is not a graph output"));
2662     }
2663   }
2664   return absl::OkStatus();
2665 }
2666 
BuildFinalModel(TfLiteContext * context,const TfLiteDelegateParams * delegate_params,GraphFloat32 * graph,absl::flat_hash_map<int,int> * quant_conversion_map)2667 absl::Status BuildFinalModel(
2668     TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
2669     GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
2670   RETURN_IF_ERROR(
2671       BuildModel(context, delegate_params, graph, quant_conversion_map));
2672 
2673   // Apply general transformations on the graph.
2674   NullTransformationReporter reporter;
2675   ModelTransformer transformer(graph, &reporter);
2676   if (!ApplyModelTransformations(&transformer)) {
2677     return absl::InternalError("Graph transformations failed");
2678   }
2679   return absl::OkStatus();
2680 }
2681 
2682 }  // namespace gpu
2683 }  // namespace tflite
2684