1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/model_builder.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <map>
21 #include <memory>
22 #include <set>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/base/attributes.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/status/status.h"
30 #include "absl/strings/str_cat.h"
31 #include "absl/strings/str_join.h"
32 #include "absl/strings/string_view.h"
33 #include "tensorflow/lite/builtin_ops.h"
34 #include "tensorflow/lite/c/builtin_op_data.h"
35 #include "tensorflow/lite/c/common.h"
36 #include "tensorflow/lite/delegates/gpu/common/custom_parsers.h"
37 #include "tensorflow/lite/delegates/gpu/common/data_type.h"
38 #include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"
39 #include "tensorflow/lite/delegates/gpu/common/model.h"
40 #include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
41 #include "tensorflow/lite/delegates/gpu/common/model_transformer.h"
42 #include "tensorflow/lite/delegates/gpu/common/object_reader.h"
43 #include "tensorflow/lite/delegates/gpu/common/operation_parser.h"
44 #include "tensorflow/lite/delegates/gpu/common/operations.h"
45 #include "tensorflow/lite/delegates/gpu/common/shape.h"
46 #include "tensorflow/lite/delegates/gpu/common/status.h"
47 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
48 #include "tensorflow/lite/delegates/gpu/common/transformations/model_transformations.h"
49 #include "tensorflow/lite/delegates/utils.h"
50 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
51 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
52 #include "tensorflow/lite/kernels/kernel_util.h"
53 #include "tensorflow/lite/util.h"
54
55 namespace tflite {
56 namespace gpu {
57 namespace {
58
GetFullyConnectedAttributes(int weights_tensor_id,int bias_tensor_id,ObjectReader * reader,FullyConnectedAttributes * attr)59 absl::Status GetFullyConnectedAttributes(int weights_tensor_id,
60 int bias_tensor_id,
61 ObjectReader* reader,
62 FullyConnectedAttributes* attr) {
63 Tensor<HW, DataType::FLOAT32> weights;
64 RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
65 attr->weights.data = std::move(weights.data);
66 attr->weights.id = weights.id;
67 attr->weights.shape.h = 1;
68 attr->weights.shape.w = 1;
69 attr->weights.shape.o = weights.shape.h;
70 attr->weights.shape.i = weights.shape.w;
71 reader->ReadTensor(bias_tensor_id, &attr->bias).IgnoreError(); // optional
72 return absl::OkStatus();
73 }
74
75 template <typename ParamsT>
RetrieveBuiltinData(const TfLiteNode * tflite_node,const ParamsT ** tf_options)76 absl::Status RetrieveBuiltinData(const TfLiteNode* tflite_node,
77 const ParamsT** tf_options) {
78 *tf_options = static_cast<const ParamsT*>(tflite_node->builtin_data);
79 if (!*tf_options) {
80 return absl::InternalError("Unable to retrieve builtin_data.");
81 }
82 return absl::OkStatus();
83 }
84
CheckDilation(int dilation_h,int dilation_w)85 absl::Status CheckDilation(int dilation_h, int dilation_w) {
86 if (dilation_h <= 0 || dilation_w <= 0) {
87 return absl::InvalidArgumentError(absl::StrCat(
88 "Incorrect dilation values: dilation_factor = ", dilation_h,
89 ", dilation_factor = ", dilation_w));
90 }
91 return absl::OkStatus();
92 }
93
CheckStridesAndDilation(int strides_h,int strides_w,int dilation_h,int dilation_w)94 absl::Status CheckStridesAndDilation(int strides_h, int strides_w,
95 int dilation_h, int dilation_w) {
96 RETURN_IF_ERROR(CheckStrides(strides_h, strides_w));
97 RETURN_IF_ERROR(CheckDilation(dilation_h, dilation_w));
98 return absl::OkStatus();
99 }
100
101 // Creates a simple node that holds tensor value.
NewConstNode(TensorFloat32 t,GraphFloat32 * graph,Value ** value)102 absl::Status NewConstNode(TensorFloat32 t, GraphFloat32* graph, Value** value) {
103 ConstTensorAttributes attr;
104 attr.tensor = std::move(t);
105 Node* node = graph->NewNode();
106 node->operation.attributes = attr;
107 node->operation.type = ToString(OperationType::CONSTANT);
108 *value = graph->NewValue();
109 RETURN_IF_ERROR(graph->SetProducer(node->id, (*value)->id));
110 // Keep data inside this tensor.
111 (*value)->tensor.ref = attr.tensor.id;
112 (*value)->tensor.type = attr.tensor.kType;
113 (*value)->tensor.shape = attr.tensor.shape;
114 return absl::OkStatus();
115 }
116
ParseInputsWithConstTensor(Node * node,ObjectReader * reader,TensorOrScalar * tensor_or_scalar)117 absl::Status ParseInputsWithConstTensor(Node* node, ObjectReader* reader,
118 TensorOrScalar* tensor_or_scalar) {
119 const std::string& opname = node->operation.type;
120
121 // Determine runtime/constant tensors.
122 const TfLiteTensor* input0 = reader->GetInputTensor(0);
123 if (!input0) {
124 return absl::InvalidArgumentError("Couldn't get the 1st input tensor for " +
125 opname);
126 }
127 const TfLiteTensor* input1 = reader->GetInputTensor(1);
128 if (!input1) {
129 return absl::InvalidArgumentError("Couldn't get the 2nd input tensor for " +
130 opname);
131 }
132 const bool constant_tensor0 = IsConstantTensor(input0);
133 const bool constant_tensor1 = IsConstantTensor(input1);
134 if (constant_tensor0 && constant_tensor1) {
135 return absl::InvalidArgumentError("No runtime input tensors for " + opname);
136 }
137 const bool runtime_tensor0 = !constant_tensor0;
138 const bool runtime_tensor1 = !constant_tensor1;
139
140 if (runtime_tensor0 && runtime_tensor1) {
141 RETURN_IF_ERROR(reader->AddInput(node, 0));
142 RETURN_IF_ERROR(reader->AddInput(node, 1));
143 } else {
144 int runtime_tensor = 0;
145 int constant_tensor = 1;
146 TfLiteIntArray* constant_dims = input1->dims;
147 if (constant_tensor0 && runtime_tensor1) {
148 runtime_tensor = 1;
149 constant_tensor = 0;
150 constant_dims = input0->dims;
151 }
152 RETURN_IF_ERROR(reader->AddInput(node, runtime_tensor));
153 if (constant_dims->size <= 0 || NumElements(constant_dims) == 1) {
154 Tensor<Scalar, DataType::FLOAT32> tensor;
155 RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
156 *tensor_or_scalar = tensor.data[0];
157 } else {
158 if (CheckIfLinearConvertible(constant_dims).ok()) {
159 Tensor<Linear, DataType::FLOAT32> tensor;
160 RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
161 *tensor_or_scalar = std::move(tensor);
162 } else {
163 Tensor<HWC, DataType::FLOAT32> tensor;
164 RETURN_IF_ERROR(reader->ReadTensor(constant_tensor, &tensor));
165 *tensor_or_scalar = std::move(tensor);
166 }
167 }
168 }
169 return absl::OkStatus();
170 }
171
172 class AddOperationParser : public TFLiteOperationParser {
173 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)174 absl::Status IsSupported(const TfLiteContext* context,
175 const TfLiteNode* tflite_node,
176 const TfLiteRegistration* registration) final {
177 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
178 if (tflite_node->inputs->size != 2) {
179 return absl::UnimplementedError("ADD requires two input tensors.");
180 }
181 // TODO(eignasheva): Add shapes check.
182
183 const TfLiteAddParams* tf_options;
184 return RetrieveBuiltinData(tflite_node, &tf_options);
185 }
186
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)187 absl::Status Parse(const TfLiteNode* tflite_node,
188 const TfLiteRegistration* registration,
189 GraphFloat32* graph, ObjectReader* reader) final {
190 // TFLite currently only supports 2 input ADDs. Thus, the logic below only
191 // considers 2 input cases. The underlying GPU shader programs can accept
192 // more inputs, but the logic below would have to be expanded.
193
194 Node* node = graph->NewNode();
195 node->operation.type = ToString(OperationType::ADD);
196 RETURN_IF_ERROR(reader->AddOutputs(node));
197 ElementwiseAttributes attr;
198 RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
199 node->operation.attributes = std::move(attr);
200 const TfLiteAddParams* tf_options;
201 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
202 return MaybeFuseActivation(tf_options->activation, graph, node);
203 }
204 };
205
206 class BatchedMatMulOperationParser : public TFLiteOperationParser {
207 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)208 absl::Status IsSupported(const TfLiteContext* context,
209 const TfLiteNode* tflite_node,
210 const TfLiteRegistration* registration) final {
211 return CheckInputsOutputs(context, tflite_node,
212 /*runtime_inputs=*/2, /*outputs=*/1);
213 }
214
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)215 absl::Status Parse(const TfLiteNode* tflite_node,
216 const TfLiteRegistration* registration,
217 GraphFloat32* graph, ObjectReader* reader) final {
218 Node* node = graph->NewNode();
219 node->operation.type = ToString(OperationType::BATCHED_MATMUL);
220 RETURN_IF_ERROR(reader->AddInput(node, 0));
221 RETURN_IF_ERROR(reader->AddInput(node, 1));
222 RETURN_IF_ERROR(reader->AddOutputs(node));
223 return absl::OkStatus();
224 }
225 };
226
227 class ConcatenationOperationParser : public TFLiteOperationParser {
228 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)229 absl::Status IsSupported(const TfLiteContext* context,
230 const TfLiteNode* tflite_node,
231 const TfLiteRegistration* registration) final {
232 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
233
234 // TODO(eignasheva): add proper tensor availability checking
235 // for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
236 // RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, idx));
237 // }
238 // TODO(eignasheva): add axis checking.
239 const TfLiteConcatenationParams* tf_options;
240 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
241 return absl::OkStatus();
242 }
243
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)244 absl::Status Parse(const TfLiteNode* tflite_node,
245 const TfLiteRegistration* registration,
246 GraphFloat32* graph, ObjectReader* reader) final {
247 ConcatAttributes attr;
248 // Read inputs first to make sure const node is added to a graph before
249 // concat node to ensure topological order.
250 std::vector<const Value*> inputs;
251 for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
252 Value* value;
253 const auto status = reader->ReadValue(idx, &value);
254 if (status.ok()) {
255 inputs.push_back(value);
256 } else {
257 TensorFloat32 tensor;
258 RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
259 Value* value;
260 RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
261 inputs.push_back(value);
262 }
263 }
264
265 Node* node = graph->NewNode();
266 node->operation.type = ToString(OperationType::CONCAT);
267 RETURN_IF_ERROR(reader->AddOutputs(node));
268 for (const Value* input : inputs) {
269 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
270 }
271
272 std::vector<BHWC> input_shapes;
273 for (auto input : graph->FindInputs(node->id)) {
274 input_shapes.push_back(input->tensor.shape);
275 }
276 RETURN_IF_ERROR(SetAxis(input_shapes, &attr.axis));
277
278 // Guess axis.
279 BHWC output_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
280 for (auto input : graph->FindInputs(node->id)) {
281 if (input->tensor.shape.h != output_shape.h) {
282 attr.axis = Axis::HEIGHT;
283 break;
284 }
285 if (input->tensor.shape.w != output_shape.w) {
286 attr.axis = Axis::WIDTH;
287 break;
288 }
289 if (input->tensor.shape.c != output_shape.c) {
290 attr.axis = Axis::CHANNELS;
291 break;
292 }
293 }
294 const TfLiteConcatenationParams* tf_options;
295 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
296 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
297 node->operation.attributes = attr;
298 return absl::OkStatus();
299 }
300
301 private:
SetAxis(const std::vector<BHWC> & input_shapes,Axis * axis)302 absl::Status SetAxis(const std::vector<BHWC>& input_shapes, Axis* axis) {
303 *axis = Axis::BATCH;
304 for (int i = 1; i < input_shapes.size(); i++) {
305 if (input_shapes[0].h != input_shapes[i].h &&
306 input_shapes[0].w != input_shapes[i].w &&
307 input_shapes[0].c != input_shapes[i].c) {
308 *axis = Axis::HEIGHT;
309 break;
310 }
311 }
312 if (*axis == Axis::BATCH) return absl::OkStatus();
313 for (int i = 1; i < input_shapes.size(); i++) {
314 if (input_shapes[0].b != input_shapes[i].b &&
315 input_shapes[0].w != input_shapes[i].w &&
316 input_shapes[0].c != input_shapes[i].c) {
317 *axis = Axis::WIDTH;
318 break;
319 }
320 }
321 if (*axis == Axis::HEIGHT) return absl::OkStatus();
322 for (int i = 1; i < input_shapes.size(); i++) {
323 if (input_shapes[0].b != input_shapes[i].b &&
324 input_shapes[0].h != input_shapes[i].h &&
325 input_shapes[0].c != input_shapes[i].c) {
326 *axis = Axis::CHANNELS;
327 break;
328 }
329 }
330 if (*axis == Axis::WIDTH) return absl::OkStatus();
331 for (int i = 1; i < input_shapes.size(); i++) {
332 if (input_shapes[0].b != input_shapes[i].b &&
333 input_shapes[0].w != input_shapes[i].w &&
334 input_shapes[0].h != input_shapes[i].h) {
335 return absl::UnimplementedError(
336 "Can concatenate tensors only by batch, height, width, or "
337 "channels.");
338 }
339 }
340 return absl::OkStatus();
341 }
342 };
343
344 class Conv2DOperationParser : public TFLiteOperationParser {
345 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)346 absl::Status IsSupported(const TfLiteContext* context,
347 const TfLiteNode* tflite_node,
348 const TfLiteRegistration* registration) final {
349 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 5));
350 const int runtime_inputs =
351 GetNumberOfRuntimeInputsForNode(context, tflite_node);
352 if (runtime_inputs > 2) {
353 return absl::InternalError(
354 absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
355 runtime_inputs, " runtime inputs."));
356 }
357 const int runtime_outputs = NumOutputs(tflite_node);
358 if (runtime_outputs != 1) {
359 return absl::InternalError(
360 absl::StrCat("Expected 1 output tensor(s), but node has ",
361 runtime_outputs, " runtime outputs."));
362 }
363 if (runtime_inputs == 1) {
364 RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
365 }
366 const TfLiteConvParams* tf_options;
367 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
368 RETURN_IF_ERROR(CheckStridesAndDilation(
369 tf_options->stride_height, tf_options->stride_width,
370 tf_options->dilation_height_factor, tf_options->dilation_width_factor));
371 return IsActivationSupported(tf_options->activation);
372 }
373
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)374 absl::Status Parse(const TfLiteNode* tflite_node,
375 const TfLiteRegistration* registration,
376 GraphFloat32* graph, ObjectReader* reader) final {
377 Node* node = graph->NewNode();
378 node->operation.type = ToString(OperationType::CONVOLUTION_2D);
379 RETURN_IF_ERROR(reader->AddInput(node, 0));
380 RETURN_IF_ERROR(reader->AddOutputs(node));
381
382 Convolution2DAttributes attr;
383 const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
384 if (runtime_inputs == 2) {
385 RETURN_IF_ERROR(reader->AddInput(node, 1));
386 } else { // runtime_inputs == 1;
387 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
388 }
389 reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
390
391 const TfLiteConvParams* tf_options;
392 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
393 attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
394 attr.dilations = HW(tf_options->dilation_height_factor,
395 tf_options->dilation_width_factor);
396 UpdatePadding(tf_options->padding,
397 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
398 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
399 node->operation.attributes = std::move(attr);
400 return absl::OkStatus();
401 }
402 };
403
404 class DepthwiseConvolutionOperationParser : public TFLiteOperationParser {
405 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)406 absl::Status IsSupported(const TfLiteContext* context,
407 const TfLiteNode* tflite_node,
408 const TfLiteRegistration* registration) final {
409 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 6));
410 const int runtime_inputs =
411 GetNumberOfRuntimeInputsForNode(context, tflite_node);
412 if (runtime_inputs > 2) {
413 return absl::InternalError(
414 absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
415 runtime_inputs, " runtime inputs."));
416 }
417 const int runtime_outputs = NumOutputs(tflite_node);
418 if (runtime_outputs != 1) {
419 return absl::InternalError(
420 absl::StrCat("Expected 1 output tensor(s), but node has ",
421 runtime_outputs, " runtime outputs."));
422 }
423 if (runtime_inputs == 1) {
424 RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
425 }
426 const TfLiteDepthwiseConvParams* tf_options;
427 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
428 RETURN_IF_ERROR(CheckStridesAndDilation(
429 tf_options->stride_height, tf_options->stride_width,
430 tf_options->dilation_height_factor, tf_options->dilation_width_factor));
431 RETURN_IF_ERROR(IsActivationSupported(tf_options->activation));
432
433 const int depth_multiplier = tf_options->depth_multiplier;
434 const auto* input = context->tensors + tflite_node->inputs->data[0];
435 const auto* filter = context->tensors + tflite_node->inputs->data[1];
436 const auto* bias = tflite_node->inputs->size > 2
437 ? context->tensors + tflite_node->inputs->data[2]
438 : nullptr;
439 const auto* output = context->tensors + tflite_node->outputs->data[0];
440 if (!input->dims || input->dims->size != 4) {
441 return absl::InvalidArgumentError("input.dims.size != 4");
442 }
443 if (!filter->dims || filter->dims->size != 4) {
444 return absl::InvalidArgumentError("filter.dims.size != 4");
445 }
446 if (!output->dims || output->dims->size != 4) {
447 return absl::InvalidArgumentError("output.dims.size != 4");
448 }
449 if (input->dims->data[0] != output->dims->data[0]) {
450 return absl::InvalidArgumentError("input.b != output.b");
451 }
452 const int input_depth = input->dims->data[3];
453 const int output_depth = output->dims->data[3];
454 if (filter->dims->data[3] != output_depth) {
455 return absl::InvalidArgumentError("filter.i != output.c");
456 }
457 if (output_depth != input_depth * depth_multiplier) {
458 return absl::InvalidArgumentError(
459 "output.c != input.c * depth_multiplier");
460 }
461 if (bias && NumElements(bias) != output_depth) {
462 return absl::InvalidArgumentError("bias.size != output.c");
463 }
464 if (depth_multiplier != 1 && input_depth != 1) {
465 return absl::UnimplementedError("depth_multiplier != 1 && input.c != 1");
466 }
467 return absl::OkStatus();
468 }
469
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)470 absl::Status Parse(const TfLiteNode* tflite_node,
471 const TfLiteRegistration* registration,
472 GraphFloat32* graph, ObjectReader* reader) final {
473 Node* node = graph->NewNode();
474 node->operation.type = ToString(OperationType::DEPTHWISE_CONVOLUTION);
475 RETURN_IF_ERROR(reader->AddInput(node, 0));
476 RETURN_IF_ERROR(reader->AddOutputs(node));
477
478 DepthwiseConvolution2DAttributes attr;
479 const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
480 if (runtime_inputs == 2) {
481 RETURN_IF_ERROR(reader->AddInput(node, 1));
482 } else { // runtime_inputs == 1;
483 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
484 }
485 reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
486 const TfLiteDepthwiseConvParams* tf_options;
487 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
488 attr.strides = ToHW(tf_options->stride_height, tf_options->stride_width);
489 attr.dilations = HW(std::max(1, tf_options->dilation_height_factor),
490 std::max(1, tf_options->dilation_width_factor));
491 UpdatePadding(tf_options->padding,
492 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
493 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
494 const int depth_multiplier = tf_options->depth_multiplier;
495 if (depth_multiplier != 1) {
496 const TfLiteTensor* input = reader->GetInputTensor(0);
497 const TfLiteTensor* filter = reader->GetInputTensor(1);
498 const TfLiteTensor* output = reader->GetOutputTensor(0);
499 TransposeWeights(input, filter, output, depth_multiplier, &attr);
500 }
501 node->operation.attributes = std::move(attr);
502 return absl::OkStatus();
503 }
504
505 private:
506 // TFLite CPU stores weights as:
507 // [1, kernel_height, kernel_width, input_depth * depth_multiplier]
508 // TFLite GPU stores weights as:
509 // [depth_multiplier, kernel_height, kernel_width, input_depth]
TransposeWeights(const TfLiteTensor * input,const TfLiteTensor * filter,const TfLiteTensor * output,int depth_multiplier,DepthwiseConvolution2DAttributes * attr)510 static void TransposeWeights(const TfLiteTensor* input,
511 const TfLiteTensor* filter,
512 const TfLiteTensor* output, int depth_multiplier,
513 DepthwiseConvolution2DAttributes* attr) {
514 const int input_depth = input->dims->data[3];
515 const int filter_height = filter->dims->data[1];
516 const int filter_width = filter->dims->data[2];
517 const int output_depth = output->dims->data[3];
518 Tensor<OHWI, DataType::FLOAT32> weights;
519 weights.id = attr->weights.id;
520 weights.shape =
521 OHWI(output_depth, filter_height, filter_width, input_depth);
522 weights.data.resize(weights.shape.DimensionsProduct());
523 float* dst = &weights.data[0];
524 for (int j = 0; j < output_depth; ++j) {
525 const float* src = attr->weights.data.data() + j;
526 for (int i = 0; i < filter_height * filter_width; ++i) {
527 *dst = *src;
528 dst++;
529 src += output_depth;
530 }
531 }
532 attr->weights = std::move(weights);
533 }
534 };
535
536 class DequantizeOperationParser : public TFLiteOperationParser {
537 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)538 absl::Status IsSupported(const TfLiteContext* context,
539 const TfLiteNode* tflite_node,
540 const TfLiteRegistration* registration) final {
541 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
542 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
543 /*runtime_inputs=*/1, /*outputs=*/1));
544 return absl::OkStatus();
545 }
546
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)547 absl::Status Parse(const TfLiteNode* tflite_node,
548 const TfLiteRegistration* registration,
549 GraphFloat32* graph, ObjectReader* reader) final {
550 // 'Dequantize' is rewritten as QuantizeAndDequantize since we are dealing
551 // with floating-point versions of the original tensors.
552 Node* node = graph->NewNode();
553 node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
554 RETURN_IF_ERROR(reader->AddInput(node, 0));
555 RETURN_IF_ERROR(reader->AddOutputs(node));
556
557 // Quantization attributes should already be present in the input tensor.
558 auto input_value = graph->FindInputs(node->id)[0];
559 if (!input_value->quant_params) {
560 return absl::InvalidArgumentError(
561 "Encountered Dequantize input with no quant params");
562 }
563 QuantizeAndDequantizeAttributes attr;
564 attr.min = input_value->quant_params.value().min;
565 attr.max = input_value->quant_params.value().max;
566 attr.scale = input_value->quant_params.value().scale;
567
568 node->operation.attributes = attr;
569 return absl::OkStatus();
570 }
571 };
572
573 class ElementwiseOperationParser : public TFLiteOperationParser {
574 public:
ElementwiseOperationParser(OperationType operation_type)575 explicit ElementwiseOperationParser(OperationType operation_type)
576 : operation_type_(operation_type) {}
577
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)578 absl::Status IsSupported(const TfLiteContext* context,
579 const TfLiteNode* tflite_node,
580 const TfLiteRegistration* registration) final {
581 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
582 if (IsOneArgumentOperation()) {
583 RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
584 /*runtime_inputs=*/1,
585 /*const_inputs=*/0,
586 /*outputs=*/1));
587 // For some elementwise operations (currently only for SUB operation)
588 // second condition may be false. But it's worth checking the next case
589 // with const input, which may be supported.
590 } else if (IsTwoArgumentOperation() &&
591 CheckInputsConstsOutputs(context, tflite_node,
592 /*runtime_inputs=*/2,
593 /*const_inputs=*/0,
594 /*outputs=*/1)
595 .ok()) {
596 } else if (IsTwoArgumentOperationWithConst()) {
597 RETURN_IF_ERROR(CheckInputsConstsOutputs(context, tflite_node,
598 /*runtime_inputs=*/1,
599 /*const_inputs=*/1,
600 /*outputs=*/1));
601 } else {
602 return absl::InvalidArgumentError(
603 "Op can only handle 1 or 2 operand(s).");
604 }
605 TfLiteFusedActivation activation;
606 RETURN_IF_ERROR(GetActivation(tflite_node, &activation));
607 return IsActivationSupported(activation);
608 }
609
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)610 absl::Status Parse(const TfLiteNode* tflite_node,
611 const TfLiteRegistration* registration,
612 GraphFloat32* graph, ObjectReader* reader) final {
613 Node* node = graph->NewNode();
614 node->operation.type = ToString(operation_type_);
615
616 if (IsOneArgumentOperation()) {
617 RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
618 /*runtime_inputs=*/1,
619 /*const_inputs=*/0,
620 /*outputs=*/1));
621
622 RETURN_IF_ERROR(reader->AddInput(node, 0));
623 } else if (IsTwoArgumentOperation() &&
624 reader
625 ->VerifyInputsConstsOutputs(tflite_node,
626 /*runtime_inputs=*/2,
627 /*const_inputs=*/0,
628 /*outputs=*/1)
629 .ok()) {
630 if (tflite_node->inputs->size != 2) {
631 return absl::InvalidArgumentError("Applies only two input tensors");
632 }
633 RETURN_IF_ERROR(reader->AddInput(node, 0));
634 RETURN_IF_ERROR(reader->AddInput(node, 1));
635
636 TfLiteFusedActivation activation = kTfLiteActNone;
637 switch (operation_type_) {
638 case OperationType::SUB: {
639 const TfLiteSubParams* tf_options;
640 if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
641 activation = tf_options->activation;
642 }
643 break;
644 }
645 case OperationType::DIV: {
646 const TfLiteDivParams* tf_options;
647 if (RetrieveBuiltinData(tflite_node, &tf_options).ok()) {
648 activation = tf_options->activation;
649 }
650 break;
651 }
652 default:
653 // No activation expected.
654 activation = kTfLiteActNone;
655 }
656
657 if (activation) {
658 RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
659 }
660 } else if (IsTwoArgumentOperationWithConst()) {
661 RETURN_IF_ERROR(reader->VerifyInputsConstsOutputs(tflite_node,
662 /*runtime_inputs=*/1,
663 /*const_inputs=*/1,
664 /*outputs=*/1));
665 ElementwiseAttributes attr;
666 RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
667 attr.runtime_tensor_is_second =
668 IsConstantTensor(reader->GetInputTensor(0));
669 node->operation.attributes = std::move(attr);
670 } else {
671 return absl::InvalidArgumentError("Incorrect operation type passed");
672 }
673
674 return reader->AddOutputs(node);
675 }
676
677 private:
GetActivation(const TfLiteNode * tflite_node,TfLiteFusedActivation * activation) const678 absl::Status GetActivation(const TfLiteNode* tflite_node,
679 TfLiteFusedActivation* activation) const {
680 if (operation_type_ == OperationType::DIV) {
681 const TfLiteDivParams* tf_options;
682 auto status = RetrieveBuiltinData(tflite_node, &tf_options);
683 *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
684 return absl::OkStatus();
685 }
686 if (operation_type_ == OperationType::SUB) {
687 const TfLiteSubParams* tf_options;
688 auto status = RetrieveBuiltinData(tflite_node, &tf_options);
689 *activation = status.ok() ? tf_options->activation : kTfLiteActNone;
690 return absl::OkStatus();
691 }
692
693 // Return kTfLiteActNone as other ops either do not have TfLiteXxxParams or
694 // TfLiteXxxParams.activation.
695 *activation = kTfLiteActNone;
696 return absl::OkStatus();
697 }
698
IsOneArgumentOperation() const699 bool IsOneArgumentOperation() const {
700 switch (operation_type_) {
701 case OperationType::ABS:
702 case OperationType::COPY:
703 case OperationType::COS:
704 case OperationType::ELU:
705 case OperationType::EXP:
706 case OperationType::LOG:
707 case OperationType::NEG:
708 case OperationType::RSQRT:
709 case OperationType::SIGMOID:
710 case OperationType::SIN:
711 case OperationType::SQRT:
712 case OperationType::SQUARE:
713 case OperationType::TANH:
714 return true;
715 default:
716 return false;
717 }
718 }
719
IsTwoArgumentOperation() const720 bool IsTwoArgumentOperation() const {
721 switch (operation_type_) {
722 case OperationType::DIV:
723 case OperationType::MAXIMUM:
724 case OperationType::MINIMUM:
725 case OperationType::POW:
726 case OperationType::SQUARED_DIFF:
727 case OperationType::SUB:
728 return true;
729 default:
730 return false;
731 }
732 }
733
IsTwoArgumentOperationWithConst() const734 bool IsTwoArgumentOperationWithConst() const {
735 switch (operation_type_) {
736 case OperationType::DIV:
737 case OperationType::MAXIMUM:
738 case OperationType::MINIMUM:
739 case OperationType::POW:
740 case OperationType::SQUARED_DIFF:
741 case OperationType::SUB:
742 return true;
743 default:
744 return false;
745 }
746 }
747
748 OperationType operation_type_;
749 };
750
751 class FullyConnectedOperationParser : public TFLiteOperationParser {
752 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)753 absl::Status IsSupported(const TfLiteContext* context,
754 const TfLiteNode* tflite_node,
755 const TfLiteRegistration* registration) final {
756 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 9));
757 const TfLiteFullyConnectedParams* tf_options;
758 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
759 if (tf_options->weights_format !=
760 kTfLiteFullyConnectedWeightsFormatDefault) {
761 return absl::UnimplementedError(
762 "Unsupported FullyConnected weights format.");
763 }
764 if (GetNumberOfRuntimeInputsForNode(context, tflite_node) > 2) {
765 return absl::UnimplementedError(
766 "FullyConnected doesn't support more than 2 runtime inputs.");
767 }
768 if (tf_options->keep_num_dims == true) {
769 const auto* input = context->tensors + tflite_node->inputs->data[0];
770 const auto* output = context->tensors + tflite_node->outputs->data[0];
771 if (input->dims->size != output->dims->size) {
772 return absl::UnimplementedError(
773 "Input and output dimensions different and FullyConnected doesn't "
774 "support keep_num_dims.");
775 }
776 }
777 // TODO(eignasheva): check input shape
778 return absl::OkStatus();
779 }
780
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)781 absl::Status Parse(const TfLiteNode* tflite_node,
782 const TfLiteRegistration* registration,
783 GraphFloat32* graph, ObjectReader* reader) final {
784 const TfLiteFullyConnectedParams* tf_options;
785 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
786
787 if (reader->GetNumberOfRuntimeInputs() == 2) {
788 // Create Convolution2D, so as it supports runtime weights.
789 Node* node = graph->NewNode();
790 node->operation.type = ToString(OperationType::CONVOLUTION_2D);
791 RETURN_IF_ERROR(reader->AddInput(node, 0));
792 RETURN_IF_ERROR(reader->AddInput(node, 1));
793 RETURN_IF_ERROR(reader->AddOutputs(node));
794
795 Convolution2DAttributes attr;
796 reader->ReadTensor(2, &attr.bias).IgnoreError(); // bias is optional
797
798 attr.strides = HW(1, 1);
799 attr.dilations = HW(1, 1);
800 attr.padding.appended = HW(0, 0);
801 attr.padding.prepended = HW(0, 0);
802 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
803 node->operation.attributes = std::move(attr);
804 return absl::OkStatus();
805 }
806 Node* node = graph->NewNode();
807 RETURN_IF_ERROR(reader->AddInput(node, 0));
808
809 if (tf_options->weights_format !=
810 kTfLiteFullyConnectedWeightsFormatDefault) {
811 return absl::UnimplementedError(
812 "Unsupported FullyConnected weights format.");
813 }
814
815 FullyConnectedAttributes attr;
816 RETURN_IF_ERROR(GetFullyConnectedAttributes(1, 2, reader, &attr));
817 const int weights_width = attr.weights.shape.i;
818
819 auto input = graph->FindInputs(node->id)[0];
820 int batch_size = input->tensor.shape.b;
821 if (input->tensor.shape.DimensionsProduct() / batch_size != weights_width) {
822 return absl::UnimplementedError(
823 "Amount of input data should match weights width");
824 }
825
826 Node* conv = node;
827 if (input->tensor.shape.h != 1 || input->tensor.shape.w != 1) {
828 auto& reshape = node;
829 conv = graph->NewNode(); // reset conv pointer!
830 Value* reshaped_value = graph->NewValue();
831 reshaped_value->tensor.type = DataType::FLOAT32;
832 reshaped_value->tensor.shape =
833 BHWC(input->tensor.shape.b, 1, 1, weights_width);
834 RETURN_IF_ERROR(graph->SetProducer(reshape->id, reshaped_value->id));
835 reshape->operation.type = ToString(OperationType::RESHAPE);
836 ReshapeAttributes attr;
837 attr.new_shape = reshaped_value->tensor.shape;
838 reshape->operation.attributes = attr;
839 RETURN_IF_ERROR(graph->AddConsumer(conv->id, reshaped_value->id));
840 }
841
842 conv->operation.type = ToString(OperationType::FULLY_CONNECTED);
843 conv->operation.attributes = std::move(attr);
844 absl::Status result = reader->AddOutputs(conv);
845 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, conv));
846
847 return result;
848 }
849 };
850
851 class HardSwishOperationParser : public TFLiteOperationParser {
852 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration *)853 absl::Status IsSupported(const TfLiteContext* context,
854 const TfLiteNode* tflite_node,
855 const TfLiteRegistration*) final {
856 return CheckInputsOutputs(context, tflite_node, /*runtime_inputs=*/1,
857 /*outputs=*/1);
858 }
859
Parse(const TfLiteNode *,const TfLiteRegistration *,GraphFloat32 * graph,ObjectReader * reader)860 absl::Status Parse(const TfLiteNode*, const TfLiteRegistration*,
861 GraphFloat32* graph, ObjectReader* reader) final {
862 Node* node = graph->NewNode();
863 node->operation.type = ToString(OperationType::HARD_SWISH);
864 RETURN_IF_ERROR(reader->AddInput(node, 0));
865 return reader->AddOutputs(node);
866 }
867 };
868
869 // Basic LSTM Cell:
870 //
871 // 1name = name is at input index 1
872 // name1 = name is at output index 1
873 //
874 // 0input 1prev_activ
875 // \ /
876 // [[concat]]
877 // \
878 // concat_temp2 2weights 3biases
879 // \ / /
880 // [[fully-connected]]
881 // \
882 // activ_temp3 4prev_state
883 // \ /
884 // [[LSTM]]
885 // / \
886 // new_state1 activation0
887 //
888 // For full LSTM cells, see this blog post:
889 // https://colah.github.io/posts/2015-08-Understanding-LSTMs/
890 // In addition to Peephole connections and Combined Input Forget Gates (CIFG)
891 // described in that post, this code also adds the following optional features:
892 // - Configurable activations (sigmoid or TANH)
893 // - L2 Normalization of gates: https://arxiv.org/abs/1607.06450
894 // - Output projection:
895 // https://www.isca-speech.org/archive/interspeech_2014/i14_0338.html
896 // - Configurable clipping of cell state and output state.
897 class LSTMOperationParser : public TFLiteOperationParser {
898 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)899 absl::Status IsSupported(const TfLiteContext* context,
900 const TfLiteNode* tflite_node,
901 const TfLiteRegistration* registration) final {
902 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 4));
903 const TfLiteLSTMParams* tf_options;
904 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
905 switch (tf_options->kernel_type) {
906 case kTfLiteLSTMFullKernel: {
907 const int inputs = NumInputs(tflite_node);
908 if (inputs != 20 && inputs != 24) {
909 return absl::InternalError(
910 absl::StrCat("Expected 20 or 24 input tensors, but node has ",
911 inputs, " input(s)."));
912 }
913 const int runtime_outputs = NumOutputs(tflite_node);
914 if (runtime_outputs != 1) {
915 return absl::InternalError(
916 absl::StrCat("Expected 1 output tensor, but node has ",
917 runtime_outputs, " output(s)."));
918 }
919 return CheckFullParameters(tf_options);
920 }
921 case kTfLiteLSTMBasicKernel:
922 RETURN_IF_ERROR(
923 CheckInputsConstsOutputs(context, tflite_node, /*runtime_inputs=*/3,
924 /*const_inputs=*/2, /*outputs=*/4));
925 return CheckBasicParameters(tf_options);
926 }
927 }
928
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)929 absl::Status Parse(const TfLiteNode* tflite_node,
930 const TfLiteRegistration* registration,
931 GraphFloat32* graph, ObjectReader* reader) final {
932 const TfLiteLSTMParams* tf_options;
933 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
934 switch (tf_options->kernel_type) {
935 case kTfLiteLSTMFullKernel:
936 return ParseFull(tflite_node, registration, graph, reader, tf_options);
937 case kTfLiteLSTMBasicKernel:
938 return ParseBasic(tflite_node, registration, graph, reader, tf_options);
939 }
940 }
941
GetNewValueIdsForVariableInputNodes()942 absl::flat_hash_map<int, ValueId> GetNewValueIdsForVariableInputNodes()
943 final {
944 return new_variable_input_value_map_;
945 }
946
947 private:
ParseBasic(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * tf_options)948 absl::Status ParseBasic(const TfLiteNode* tflite_node,
949 const TfLiteRegistration* registration,
950 GraphFloat32* graph, ObjectReader* reader,
951 const TfLiteLSTMParams* tf_options) {
952 if (tflite_node->inputs->size != 5) {
953 return absl::InvalidArgumentError("LSTM should have 5 input tensors");
954 }
955 if (tflite_node->outputs->size != 4) {
956 return absl::InvalidArgumentError("LSTM should have 4 output tensors");
957 }
958 RETURN_IF_ERROR(CheckBasicParameters(tf_options));
959
960 Node* concat_node = graph->NewNode();
961 concat_node->operation.type = ToString(OperationType::CONCAT);
962 ConcatAttributes concat_attr;
963 concat_attr.axis = Axis::CHANNELS;
964 concat_node->operation.attributes = concat_attr;
965
966 Node* fc_node = graph->NewNode();
967 fc_node->operation.type = ToString(OperationType::FULLY_CONNECTED);
968 FullyConnectedAttributes fc_attr;
969 RETURN_IF_ERROR(GetFullyConnectedAttributes(2, 3, reader, &fc_attr));
970 fc_node->operation.attributes = std::move(fc_attr);
971
972 Node* lstm_node = graph->NewNode();
973 lstm_node->operation.type = ToString(OperationType::LSTM);
974 LstmAttributes lstm_attr;
975 lstm_attr.kernel_type = LstmKernelType::BASIC;
976 lstm_node->operation.attributes = lstm_attr;
977
978 Value* concat_temp;
979 int concat_tensor_idx = tflite_node->outputs->data[2];
980 RETURN_IF_ERROR(
981 reader->ReadValueByTensorIdx(concat_tensor_idx, &concat_temp));
982 Value* activ_temp;
983 int activ_tensor_idx = tflite_node->outputs->data[3];
984 RETURN_IF_ERROR(
985 reader->ReadValueByTensorIdx(activ_tensor_idx, &activ_temp));
986
987 RETURN_IF_ERROR(reader->AddInput(concat_node, 0)); // input
988 RETURN_IF_ERROR(reader->AddInput(concat_node, 1)); // prev_activ
989 RETURN_IF_ERROR(graph->SetProducer(concat_node->id, concat_temp->id));
990
991 RETURN_IF_ERROR(graph->AddConsumer(fc_node->id, concat_temp->id));
992 RETURN_IF_ERROR(graph->SetProducer(fc_node->id, activ_temp->id));
993
994 RETURN_IF_ERROR(graph->AddConsumer(lstm_node->id, activ_temp->id));
995 RETURN_IF_ERROR(reader->AddInput(lstm_node, 4)); // prev_state
996 RETURN_IF_ERROR(reader->AddOutput(lstm_node, 1)); // new_state
997 RETURN_IF_ERROR(reader->AddOutput(lstm_node, 0)); // activation
998
999 return absl::OkStatus();
1000 }
1001
CheckBasicParameters(const TfLiteLSTMParams * tf_options)1002 absl::Status CheckBasicParameters(const TfLiteLSTMParams* tf_options) {
1003 if (tf_options->activation != kTfLiteActTanh) {
1004 return absl::UnimplementedError("Only TANH activation is supported.");
1005 }
1006 if (tf_options->cell_clip != 0.0f) {
1007 return absl::UnimplementedError("cell_clip is not supported.");
1008 }
1009 if (tf_options->proj_clip != 0.0f) {
1010 return absl::UnimplementedError("proj_clip is not supported.");
1011 }
1012 return absl::OkStatus();
1013 }
1014
ParseFull(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * tf_options)1015 absl::Status ParseFull(const TfLiteNode* tflite_node,
1016 const TfLiteRegistration* registration,
1017 GraphFloat32* graph, ObjectReader* reader,
1018 const TfLiteLSTMParams* tf_options) {
1019 // Invoke full LSTM parser
1020 RETURN_IF_ERROR(ParseLSTMAttributes(tflite_node, registration, graph,
1021 reader, tf_options,
1022 &new_variable_input_value_map_));
1023 return absl::OkStatus();
1024 }
1025
CheckFullParameters(const TfLiteLSTMParams * tf_options)1026 absl::Status CheckFullParameters(const TfLiteLSTMParams* tf_options) {
1027 if (tf_options->activation != kTfLiteActSigmoid &&
1028 tf_options->activation != kTfLiteActTanh) {
1029 return absl::UnimplementedError(
1030 "Only sigmoid or tanh activation is supported.");
1031 }
1032
1033 return absl::OkStatus();
1034 }
1035
1036 absl::flat_hash_map<int, ValueId> new_variable_input_value_map_;
1037 };
1038
1039 class MulOperationParser : public TFLiteOperationParser {
1040 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1041 absl::Status IsSupported(const TfLiteContext* context,
1042 const TfLiteNode* tflite_node,
1043 const TfLiteRegistration* registration) final {
1044 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1045 if (tflite_node->inputs->size != 2) {
1046 return absl::UnimplementedError("MUL requires two input tensors.");
1047 }
1048 const TfLiteTensor* input0 = GetInput(context, tflite_node, 0);
1049 const TfLiteTensor* input1 = GetInput(context, tflite_node, 1);
1050 if (input0 == nullptr || input1 == nullptr) {
1051 return absl::InvalidArgumentError("At least one input tensor is null");
1052 }
1053 if (input0->dims->size == input1->dims->size) {
1054 // this code checks that at least one input of Mul not smaller in all
1055 // dimensions. Sometimes Mul used for matrix-vector multiplication that we
1056 // currently don't support. For example input0 HWC(1, 256, 1), input1
1057 // HWC(1, 1, 256) -> output HWC (1, 256, 256). In this case it can be
1058 // replaced with Convolution operation.
1059 bool first_has_smaller_dim = false;
1060 bool second_has_smaller_dim = false;
1061 for (int i = 0; i < input0->dims->size; ++i) {
1062 if (input0->dims->data[i] < input1->dims->data[i]) {
1063 first_has_smaller_dim = true;
1064 }
1065 if (input1->dims->data[i] < input0->dims->data[i]) {
1066 second_has_smaller_dim = true;
1067 }
1068 }
1069 if (first_has_smaller_dim && second_has_smaller_dim) {
1070 return absl::UnimplementedError(
1071 "MUL requires one tensor that not less than second in all "
1072 "dimensions.");
1073 }
1074 }
1075 const TfLiteMulParams* tf_options;
1076 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1077 return IsActivationSupported(tf_options->activation);
1078 }
1079
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1080 absl::Status Parse(const TfLiteNode* tflite_node,
1081 const TfLiteRegistration* registration,
1082 GraphFloat32* graph, ObjectReader* reader) final {
1083 const TfLiteTensor* input0 = reader->GetInputTensor(0);
1084 if (!input0) {
1085 return absl::InvalidArgumentError(
1086 "Couldn't get the 1st input tensor for MUL.");
1087 }
1088 const TfLiteTensor* input1 = reader->GetInputTensor(1);
1089 if (!input1) {
1090 return absl::InvalidArgumentError(
1091 "Couldn't get the 2nd input tensor for MUL.");
1092 }
1093 const bool constant_tensor0 = IsConstantTensor(input0);
1094 const bool constant_tensor1 = IsConstantTensor(input1);
1095 if (constant_tensor0 && constant_tensor1) {
1096 return absl::InvalidArgumentError("No runtime input tensors for MUL.");
1097 }
1098 const bool runtime_tensor0 = !constant_tensor0;
1099 const bool runtime_tensor1 = !constant_tensor1;
1100
1101 Node* node = graph->NewNode();
1102 node->operation.type = ToString(OperationType::MUL);
1103 RETURN_IF_ERROR(reader->AddOutputs(node));
1104
1105 // Determine runtime/constant tensors.
1106 if (runtime_tensor0 && runtime_tensor1) {
1107 if (input0 == input1) {
1108 // replace MUL(A, A) with POW(A, 2.0)
1109 // TODO(b/166831113): Support the same inputs for operations.
1110 node->operation.type = ToString(OperationType::POW);
1111 ElementwiseAttributes attr;
1112 attr.param = 2.0f;
1113 node->operation.attributes = std::move(attr);
1114 return reader->AddInput(node, 0);
1115 }
1116
1117 // The "larger" input tensor must be bound to 1st input and the "smaller"
1118 // input tensor must be bound to 2nd input.
1119 BHWC shape0;
1120 RETURN_IF_ERROR(ExtractTensorShape(*input0, &shape0));
1121 BHWC shape1;
1122 RETURN_IF_ERROR(ExtractTensorShape(*input1, &shape1));
1123 int input_tensor0 = 0;
1124 int input_tensor1 = 1;
1125 if (shape0.h <= shape1.h && shape0.w <= shape1.w &&
1126 shape0.c == shape1.c) {
1127 input_tensor0 = 1;
1128 input_tensor1 = 0;
1129 }
1130 RETURN_IF_ERROR(reader->AddInput(node, input_tensor0));
1131 RETURN_IF_ERROR(reader->AddInput(node, input_tensor1));
1132 } else {
1133 ElementwiseAttributes attr;
1134 RETURN_IF_ERROR(ParseInputsWithConstTensor(node, reader, &attr.param));
1135 node->operation.attributes = std::move(attr);
1136 }
1137
1138 const TfLiteMulParams* tf_options;
1139 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1140 return MaybeFuseActivation(tf_options->activation, graph, node);
1141 }
1142 };
1143
1144 class PackOperationParser : public TFLiteOperationParser {
1145 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1146 absl::Status IsSupported(const TfLiteContext* context,
1147 const TfLiteNode* tflite_node,
1148 const TfLiteRegistration* registration) final {
1149 const TfLitePackParams* tf_options;
1150 return RetrieveBuiltinData(tflite_node, &tf_options);
1151 }
1152
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1153 absl::Status Parse(const TfLiteNode* tflite_node,
1154 const TfLiteRegistration* registration,
1155 GraphFloat32* graph, ObjectReader* reader) final {
1156 if (tflite_node->inputs->size == 1) {
1157 // Pack with single input can be replaced with Reshape
1158 Node* node = graph->NewNode();
1159 node->operation.type = ToString(OperationType::RESHAPE);
1160 RETURN_IF_ERROR(reader->AddInput(node, 0));
1161 RETURN_IF_ERROR(reader->AddOutputs(node));
1162 // New shape comes from output shape.
1163 ReshapeAttributes attr;
1164 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1165 node->operation.attributes = attr;
1166 return absl::OkStatus();
1167 } else {
1168 // Pack with few inputs can be replaced with Concat
1169 const TfLitePackParams* tf_options;
1170 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1171
1172 // Read inputs first to make sure const node is added to a graph before
1173 // concat node to ensure topological order.
1174 std::vector<const Value*> inputs;
1175 for (uint32_t idx = 0; idx < tflite_node->inputs->size; ++idx) {
1176 Value* value;
1177 const auto status = reader->ReadValue(idx, &value);
1178 if (status.ok()) {
1179 inputs.push_back(value);
1180 } else {
1181 TensorFloat32 tensor;
1182 RETURN_IF_ERROR(reader->ReadTensor(idx, &tensor));
1183 Value* value;
1184 RETURN_IF_ERROR(NewConstNode(std::move(tensor), graph, &value));
1185 inputs.push_back(value);
1186 }
1187 }
1188
1189 Node* node = graph->NewNode();
1190 node->operation.type = ToString(OperationType::CONCAT);
1191 RETURN_IF_ERROR(reader->AddOutputs(node));
1192 for (const Value* input : inputs) {
1193 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1194 }
1195 const TfLiteTensor* output = reader->GetOutputTensor(0);
1196 ConcatAttributes attr;
1197 RETURN_IF_ERROR(
1198 ExtractAxisFromIndex(*output, tf_options->axis, &attr.axis));
1199 node->operation.attributes = attr;
1200 return absl::OkStatus();
1201 }
1202 }
1203 };
1204
1205 class PReLUOperationParser : public TFLiteOperationParser {
1206 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1207 absl::Status IsSupported(const TfLiteContext* context,
1208 const TfLiteNode* tflite_node,
1209 const TfLiteRegistration* registration) final {
1210 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1211 // TODO(eignasheva): add params check
1212 return absl::OkStatus();
1213 }
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1214 absl::Status Parse(const TfLiteNode* tflite_node,
1215 const TfLiteRegistration* registration,
1216 GraphFloat32* graph, ObjectReader* reader) final {
1217 Node* node = graph->NewNode();
1218 node->operation.type = ToString(OperationType::PRELU);
1219 RETURN_IF_ERROR(reader->AddInput(node, 0));
1220 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1221
1222 PReLUAttributes attr;
1223 Tensor<Linear, DataType::FLOAT32> linear_alpha;
1224 absl::Status status = reader->ReadTensor(1, &linear_alpha);
1225 if (status.ok()) {
1226 if (linear_alpha.shape.v != input_shape.c) {
1227 return absl::InvalidArgumentError(
1228 "Linear alpha shape does not match the number of input channels.");
1229 }
1230 attr.alpha = std::move(linear_alpha);
1231 } else {
1232 Tensor<HWC, DataType::FLOAT32> hwc_alpha;
1233 RETURN_IF_ERROR(reader->ReadTensor(1, &hwc_alpha));
1234 if (hwc_alpha.shape.h != input_shape.h ||
1235 hwc_alpha.shape.w != input_shape.w ||
1236 hwc_alpha.shape.c != input_shape.c) {
1237 return absl::InvalidArgumentError(
1238 "Alpha shape does not match input shape.");
1239 }
1240 attr.alpha = std::move(hwc_alpha);
1241 }
1242 node->operation.attributes = std::move(attr);
1243 return reader->AddOutputs(node);
1244 }
1245 };
1246
1247 class PadOperationParser : public TFLiteOperationParser {
1248 public:
PadOperationParser(bool mirror_pad)1249 explicit PadOperationParser(bool mirror_pad) : mirror_pad_(mirror_pad) {}
1250
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1251 absl::Status IsSupported(const TfLiteContext* context,
1252 const TfLiteNode* tflite_node,
1253 const TfLiteRegistration* registration) final {
1254 if (mirror_pad_) {
1255 const TfLiteMirrorPaddingParams* tf_options;
1256 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1257 if (tf_options->mode !=
1258 TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect) {
1259 return absl::InvalidArgumentError(
1260 "Only Reflective padding is supported for Mirror Pad operation.");
1261 }
1262 }
1263 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1264 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1265 /*runtime_inputs=*/1, /*outputs=*/1));
1266 RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
1267 const TfLiteTensor* pad_tensor = GetInput(context, tflite_node, 1);
1268 if (pad_tensor == nullptr) {
1269 return absl::InvalidArgumentError("Padding tensor was null");
1270 }
1271 if (pad_tensor->dims->size != 2) {
1272 return absl::InvalidArgumentError(absl::StrCat(
1273 "Invalid paddings tensor dimension: expected 2 dim, got ",
1274 pad_tensor->dims->size, " dim"));
1275 }
1276 bool supported =
1277 pad_tensor->dims->data[0] == 3 || pad_tensor->dims->data[0] == 4;
1278 if (!supported || pad_tensor->dims->data[1] != 2) {
1279 return absl::InvalidArgumentError(absl::StrCat(
1280 "Invalid paddings tensor shape: expected 4x2 or 3x2, got ",
1281 pad_tensor->dims->data[0], "x", pad_tensor->dims->data[1]));
1282 }
1283 return absl::OkStatus();
1284 }
1285
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1286 absl::Status Parse(const TfLiteNode* tflite_node,
1287 const TfLiteRegistration* registration,
1288 GraphFloat32* graph, ObjectReader* reader) final {
1289 Node* node = graph->NewNode();
1290 node->operation.type = ToString(OperationType::PAD);
1291 RETURN_IF_ERROR(reader->AddInput(node, 0));
1292 RETURN_IF_ERROR(reader->AddOutputs(node));
1293
1294 PadAttributes attr;
1295 if (mirror_pad_) {
1296 attr.type = PaddingContentType::REFLECT;
1297 } else /*zero pad*/ {
1298 attr.type = PaddingContentType::ZEROS;
1299 }
1300
1301 Tensor<HW, DataType::INT32> paddings;
1302 RETURN_IF_ERROR(reader->ReadTensor(1, &paddings));
1303
1304 if (paddings.shape.h == 4 && paddings.shape.w == 2) {
1305 // 4x2 tensor with paddings.
1306 attr.prepended = BHWC(paddings.data[0], paddings.data[2],
1307 paddings.data[4], paddings.data[6]);
1308 attr.appended = BHWC(paddings.data[1], paddings.data[3], paddings.data[5],
1309 paddings.data[7]);
1310 } else if (paddings.shape.h == 3 && paddings.shape.w == 2) {
1311 // 3x2 tensor with paddings.
1312 attr.prepended =
1313 BHWC(1, paddings.data[0], paddings.data[2], paddings.data[4]);
1314 attr.appended =
1315 BHWC(1, paddings.data[1], paddings.data[3], paddings.data[5]);
1316 } else {
1317 // It shouldn't fail here since it's checked at IsSupported().
1318 return absl::InvalidArgumentError(
1319 "Paddings tensor has unexpected shape.");
1320 }
1321 node->operation.attributes = attr;
1322 return absl::OkStatus();
1323 }
1324
1325 private:
1326 bool mirror_pad_ = false;
1327 };
1328
1329 class Pooling2DOperationParser : public TFLiteOperationParser {
1330 public:
IsSupported(const TfLiteContext * context,const TfLiteNode * tflite_node,const TfLiteRegistration * registration)1331 absl::Status IsSupported(const TfLiteContext* context,
1332 const TfLiteNode* tflite_node,
1333 const TfLiteRegistration* registration) final {
1334 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1335 const TfLitePoolParams* tf_options;
1336 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1337 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1338 /*runtime_inputs=*/1,
1339 /*outputs=*/1));
1340 RETURN_IF_ERROR(CheckKernelsAndStrides(
1341 tf_options->filter_height, tf_options->filter_width,
1342 tf_options->stride_height, tf_options->stride_width));
1343 return IsActivationSupported(tf_options->activation);
1344 }
1345
1346 public:
Pooling2DOperationParser(PoolingType type)1347 explicit Pooling2DOperationParser(PoolingType type) : type_(type) {}
1348
Parse(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader)1349 absl::Status Parse(const TfLiteNode* tflite_node,
1350 const TfLiteRegistration* registration,
1351 GraphFloat32* graph, ObjectReader* reader) final {
1352 Node* node = graph->NewNode();
1353 node->operation.type = ToString(OperationType::POOLING_2D);
1354 RETURN_IF_ERROR(reader->AddInput(node, 0));
1355 RETURN_IF_ERROR(reader->AddOutput(node, 0));
1356
1357 Pooling2DAttributes attr;
1358 attr.type = type_;
1359
1360 auto input_shape = graph->FindInputs(node->id)[0]->tensor.shape;
1361
1362 const TfLitePoolParams* tf_options;
1363 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1364
1365 RETURN_IF_ERROR(MaybeFuseActivation(tf_options->activation, graph, node));
1366
1367 attr.output_indices = false;
1368 RETURN_IF_ERROR(ParsePoolingAttributes(tf_options, input_shape, &attr));
1369 node->operation.attributes = attr;
1370 return absl::OkStatus();
1371 }
1372
1373 private:
1374 const PoolingType type_;
1375 };
1376
1377 class ReduceOperationParser : public TFLiteOperationParser {
1378 public:
1379   explicit ReduceOperationParser(OperationType operation_type)
1380 : operation_type_(operation_type) {}
1381
1382   absl::Status IsSupported(const TfLiteContext* context,
1383 const TfLiteNode* tflite_node,
1384 const TfLiteRegistration* registration) final {
1385 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1386 /*runtime_inputs=*/1, /*outputs=*/1));
1387 auto* axes = &context->tensors[tflite_node->inputs->data[1]];
1388 if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
1389 return absl::UnimplementedError(
1390 "Reduce has unsupported tensor for axes.");
1391 }
1392 return absl::OkStatus();
1393 }
1394
1395   absl::Status Parse(const TfLiteNode* tflite_node,
1396 const TfLiteRegistration* registration,
1397 GraphFloat32* graph, ObjectReader* reader) final {
1398 Node* node = graph->NewNode();
1399 node->operation.type = ToString(operation_type_);
1400 RETURN_IF_ERROR(reader->AddInput(node, 0));
1401 RETURN_IF_ERROR(reader->AddOutputs(node));
1402
1403 const TfLiteReducerParams* tf_options;
1404 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1405
1406 ReduceAttributes attr;
1407 const TfLiteTensor* input = reader->GetInputTensor(0);
1408 const TfLiteTensor* axes = reader->GetInputTensor(1);
1409 for (int i = 0; i < NumElements(axes->dims); i++) {
1410 Axis axis;
1411 RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
1412 attr.dims.insert(axis);
1413 }
1414 node->operation.attributes = attr;
1415 return absl::OkStatus();
1416 }
1417
1418 private:
1419 const OperationType operation_type_;
1420 };
1421
1422 class QuantizeOperationParser : public TFLiteOperationParser {
1423 public:
1424   absl::Status IsSupported(const TfLiteContext* context,
1425 const TfLiteNode* tflite_node,
1426 const TfLiteRegistration* registration) final {
1427 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1428 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1429 /*runtime_inputs=*/1, /*outputs=*/1));
1430 return absl::OkStatus();
1431 }
1432
1433   absl::Status Parse(const TfLiteNode* tflite_node,
1434 const TfLiteRegistration* registration,
1435 GraphFloat32* graph, ObjectReader* reader) final {
1436 // 'Quantize' is rewritten as QuantizeAndDequantize since we are dealing
1437 // with floating-point versions of the original tensors.
1438 Node* node = graph->NewNode();
1439 node->operation.type = ToString(OperationType::QUANTIZE_AND_DEQUANTIZE);
1440 RETURN_IF_ERROR(reader->AddInput(node, 0));
1441 RETURN_IF_ERROR(reader->AddOutputs(node));
1442
1443 // Quantization attributes should already be present in the output tensor.
1444 auto output_value = graph->FindOutputs(node->id)[0];
1445 if (!output_value->quant_params) {
1446 return absl::InvalidArgumentError(
1447 "Encountered Quantize output with no quant params");
1448 }
1449 QuantizeAndDequantizeAttributes attr;
1450 attr.min = output_value->quant_params.value().min;
1451 attr.max = output_value->quant_params.value().max;
1452 attr.scale = output_value->quant_params.value().scale;
1453
1454 node->operation.attributes = attr;
1455 return absl::OkStatus();
1456 }
1457 };
1458
1459 class ReLUOperationParser : public TFLiteOperationParser {
1460 public:
1461   explicit ReLUOperationParser(int clip) : clip_(clip) {}
1462
1463   absl::Status IsSupported(const TfLiteContext* context,
1464 const TfLiteNode* tflite_node,
1465 const TfLiteRegistration* registration) final {
1466 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1467 return absl::OkStatus();
1468 }
1469
1470   absl::Status Parse(const TfLiteNode* tflite_node,
1471 const TfLiteRegistration* registration,
1472 GraphFloat32* graph, ObjectReader* reader) final {
1473 Node* node = graph->NewNode();
1474 node->operation.type = ToString(OperationType::RELU);
1475 RETURN_IF_ERROR(reader->AddInput(node, 0));
1476
1477 ReLUAttributes attr;
1478 const TfLiteLeakyReluParams* tf_options;
1479 auto status = RetrieveBuiltinData(tflite_node, &tf_options);
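    // Only LEAKY_RELU carries TfLiteLeakyReluParams as builtin data; for RELU
    // and RELU6 the retrieval fails and alpha falls back to 0.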
1480 attr.alpha = status.ok() ? tf_options->alpha : 0;
1481 attr.clip = clip_;
1482 node->operation.attributes = attr;
1483 return reader->AddOutputs(node);
1484 }
1485
1486 private:
1487 const int clip_;
1488 };
1489
1490 class ReshapeOperationParser : public TFLiteOperationParser {
1491 public:
1492   absl::Status IsSupported(const TfLiteContext* context,
1493 const TfLiteNode* tflite_node,
1494 const TfLiteRegistration* registration) final {
1495 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 1));
1496 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1497 /*runtime_inputs=*/1, /*outputs=*/1));
1498 // TODO(eignasheva): add shape checking
1499 return absl::OkStatus();
1500 }
1501
1502   absl::Status Parse(const TfLiteNode* tflite_node,
1503 const TfLiteRegistration* registration,
1504 GraphFloat32* graph, ObjectReader* reader) final {
1505 Node* node = graph->NewNode();
1506 node->operation.type = ToString(OperationType::RESHAPE);
1507 RETURN_IF_ERROR(reader->AddInput(node, 0));
1508 RETURN_IF_ERROR(reader->AddOutputs(node));
1509     // There may be extra inputs here. The other tensors were supposed to
1510     // define the new shape, but in TFLite they are ignored.
1511 // TODO(akulik): check that shapes match?
1512
1513 // New shape comes from output shape.
1514 ReshapeAttributes attr;
1515 attr.new_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1516 node->operation.attributes = attr;
1517 return absl::OkStatus();
1518 }
1519 };
1520
1521 class Resize2DOperationParser : public TFLiteOperationParser {
1522 public:
1523   explicit Resize2DOperationParser(SamplingType sampling_type)
1524 : sampling_type_(sampling_type) {}
1525
1526   absl::Status IsSupported(const TfLiteContext* context,
1527 const TfLiteNode* tflite_node,
1528 const TfLiteRegistration* registration) final {
1529 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
1530 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1531 /*runtime_inputs=*/1, /*outputs=*/1));
1532
1533 bool align_corners;
1534 RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &align_corners));
1535 bool half_pixel_centers;
1536 RETURN_IF_ERROR(GetHalfPixelCentersValue(tflite_node, &half_pixel_centers));
1537 return absl::OkStatus();
1538 }
1539
1540   absl::Status Parse(const TfLiteNode* tflite_node,
1541 const TfLiteRegistration* registration,
1542 GraphFloat32* graph, ObjectReader* reader) final {
1543 Node* node = graph->NewNode();
1544 node->operation.type = ToString(OperationType::RESIZE);
1545 RETURN_IF_ERROR(reader->AddInput(node, 0));
1546 RETURN_IF_ERROR(reader->AddOutputs(node));
1547     // There may be extra inputs here. The other tensors were supposed to
1548     // define the new shape, but in TFLite they are ignored.
1549
1550 Resize2DAttributes attr;
1551 RETURN_IF_ERROR(GetAlignCornersValue(tflite_node, &attr.align_corners));
1552 RETURN_IF_ERROR(
1553 GetHalfPixelCentersValue(tflite_node, &attr.half_pixel_centers));
1554 attr.type = sampling_type_;
1555 attr.new_shape.CopyAllDefinedAxis(
1556 graph->FindOutputs(node->id)[0]->tensor.shape);
1557 node->operation.attributes = attr;
1558 return absl::OkStatus();
1559 }
1560
1561 private:
1562   absl::Status GetAlignCornersValue(const TfLiteNode* tflite_node,
1563 bool* align_corners) {
1564 switch (sampling_type_) {
1565 case SamplingType::BILINEAR:
1566 return GetAlignCornersValueForType<TfLiteResizeBilinearParams>(
1567 tflite_node, align_corners);
1568 case SamplingType::NEAREST:
1569 return GetAlignCornersValueForType<TfLiteResizeNearestNeighborParams>(
1570 tflite_node, align_corners);
1571 case SamplingType::UNKNOWN:
1572 return absl::InternalError("Sampling type is not specified");
1573 }
1574 return absl::OkStatus();
1575 }
1576
1577 template <class T>
1578   absl::Status GetAlignCornersValueForType(const TfLiteNode* tflite_node,
1579 bool* align_corners) {
1580 const T* tf_options;
1581 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1582 *align_corners = tf_options->align_corners;
1583 return absl::OkStatus();
1584 }
1585
1586   absl::Status GetHalfPixelCentersValue(const TfLiteNode* tflite_node,
1587 bool* half_pixel_centers) {
1588 if (sampling_type_ == SamplingType::BILINEAR) {
1589 const TfLiteResizeBilinearParams* tf_options;
1590 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1591 if (tf_options->align_corners && tf_options->half_pixel_centers) {
1592 return absl::InternalError(
1593 "If half_pixel_centers is True, align_corners must be False.");
1594 }
1595 *half_pixel_centers = tf_options->half_pixel_centers;
1596 } else {
1597 const TfLiteResizeNearestNeighborParams* tf_options;
1598 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1599 *half_pixel_centers = tf_options->half_pixel_centers;
1600 }
1601 return absl::OkStatus();
1602 }
1603
1604 SamplingType sampling_type_ = SamplingType::UNKNOWN;
1605 };
1606
1607 class SliceOperationParser : public TFLiteOperationParser {
1608 public:
1609   absl::Status IsSupported(const TfLiteContext* context,
1610 const TfLiteNode* tflite_node,
1611 const TfLiteRegistration* registration) final {
1612 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1613 if (tflite_node->inputs->size < 3) {
1614 return absl::UnimplementedError("SLICE requires 3 inputs.");
1615 }
1616 const TfLiteTensor* input = GetInput(context, tflite_node, 0);
1617 if (input->dims->size != 3 && input->dims->size != 4) {
1618 return absl::UnimplementedError(
1619 "SLICE supports for 3 or 4 dimensional tensors only.");
1620 }
1621
1622 return absl::OkStatus();
1623 }
1624
1625   absl::Status Parse(const TfLiteNode* tflite_node,
1626 const TfLiteRegistration* registration,
1627 GraphFloat32* graph, ObjectReader* reader) final {
1628 Node* node = graph->NewNode();
1629 node->operation.type = ToString(OperationType::SLICE);
1630 RETURN_IF_ERROR(reader->AddOutputs(node));
1631 Value* input;
1632 RETURN_IF_ERROR(reader->ReadValue(0, &input));
1633 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1634
1635 const TfLiteTensor* tfl_input = reader->GetInputTensor(0);
1636 const int input_dims = tfl_input->dims->size;
1637
1638 SliceAttributes attr;
1639 attr.strides = BHWC(1, 1, 1, 1);
1640 Tensor<Linear, DataType::INT32> starts, sizes;
1641 RETURN_IF_ERROR(reader->ReadTensor(1, &starts));
1642 RETURN_IF_ERROR(reader->ReadTensor(2, &sizes));
1643 if (starts.data.size() != sizes.data.size()) {
1644 return absl::InvalidArgumentError("Starts amount != sizes amount.");
1645 }
1646 BHWC bhwc_starts(0, 0, 0, 0);
1647 BHWC bhwc_sizes = input->tensor.shape;
1648 if (input_dims == 4) {
1649 // input in BHWC layout
1650 if (starts.data.size() == 4) {
1651 bhwc_starts.b = starts.data[0];
1652 bhwc_starts.h = starts.data[1];
1653 bhwc_starts.w = starts.data[2];
1654 bhwc_starts.c = starts.data[3];
1655 bhwc_sizes.b = sizes.data[0];
1656 bhwc_sizes.h = sizes.data[1];
1657 bhwc_sizes.w = sizes.data[2];
1658 bhwc_sizes.c = sizes.data[3];
1659 } else if (starts.data.size() == 3) {
1660         // If the input is 4D (BHWC) and the args are 3D, assume the args are in HWC layout.
1661 bhwc_starts.h = starts.data[0];
1662 bhwc_starts.w = starts.data[1];
1663 bhwc_starts.c = starts.data[2];
1664 bhwc_sizes.h = sizes.data[0];
1665 bhwc_sizes.w = sizes.data[1];
1666 bhwc_sizes.c = sizes.data[2];
1667 } else {
1668 return absl::UnimplementedError(
1669 "Slicing is supported for 3 or 4 dimensional tensors only.");
1670 }
1671 } else if (input_dims == 3) {
1672 // input in BWC layout
1673 if (starts.data.size() == 3) {
1674 bhwc_starts.b = starts.data[0];
1675 bhwc_starts.w = starts.data[1];
1676 bhwc_starts.c = starts.data[2];
1677 bhwc_sizes.b = sizes.data[0];
1678 bhwc_sizes.w = sizes.data[1];
1679 bhwc_sizes.c = sizes.data[2];
1680 } else {
1681 return absl::UnimplementedError(
1682 "Slicing is supported for 3 or 4 dimensional tensors only.");
1683 }
1684 } else {
1685 return absl::UnimplementedError(
1686 "Slicing is supported for 3 or 4 dimensional tensors only.");
1687 }
1688 const auto& in_shape = input->tensor.shape;
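    // In TFLite SLICE, a size of -1 means "take everything from the start
    // index to the end of that dimension".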
1689 if (bhwc_sizes.b == -1) {
1690 bhwc_sizes.b = in_shape.b - bhwc_starts.b;
1691 }
1692 if (bhwc_sizes.h == -1) {
1693 bhwc_sizes.h = in_shape.h - bhwc_starts.h;
1694 }
1695 if (bhwc_sizes.w == -1) {
1696 bhwc_sizes.w = in_shape.w - bhwc_starts.w;
1697 }
1698 if (bhwc_sizes.c == -1) {
1699 bhwc_sizes.c = in_shape.c - bhwc_starts.c;
1700 }
1701 attr.starts = bhwc_starts;
1702 attr.ends =
1703 BHWC(bhwc_starts.b + bhwc_sizes.b, bhwc_starts.h + bhwc_sizes.h,
1704 bhwc_starts.w + bhwc_sizes.w, bhwc_starts.c + bhwc_sizes.c);
1705 RETURN_IF_ERROR(UpdateIfNegative(in_shape, &attr));
1706
1707 auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1708 if ((attr.ends.b - attr.starts.b) != out_shape.b) {
1709 return absl::UnimplementedError("Output batch don't match");
1710 }
1711 if ((attr.ends.h - attr.starts.h) != out_shape.h) {
1712 return absl::UnimplementedError("Output height doesn't match");
1713 }
1714 if ((attr.ends.w - attr.starts.w) != out_shape.w) {
1715 return absl::UnimplementedError("Output width doesn't match");
1716 }
1717 if ((attr.ends.c - attr.starts.c) != out_shape.c) {
1718 return absl::UnimplementedError("Output channels don't match");
1719 }
1720 node->operation.attributes = attr;
1721 return absl::OkStatus();
1722 }
1723
1724 private:
1725   absl::Status UpdateIfNegative(const BHWC& input_shape,
1726 SliceAttributes* attr) {
1727 if (attr->ends.h < 0) {
1728 attr->ends.h = input_shape.h + attr->ends.h;
1729 }
1730 if (attr->ends.w < 0) {
1731 attr->ends.w = input_shape.w + attr->ends.w;
1732 }
1733 if (attr->ends.c < 0) {
1734 attr->ends.c = input_shape.c + attr->ends.c;
1735 }
1736 if (attr->ends.b < 0) {
1737 attr->ends.b = input_shape.b + attr->ends.b;
1738 }
1739 return absl::OkStatus();
1740 }
1741 };
1742
1743 class SoftmaxOperationParser : public TFLiteOperationParser {
1744 public:
1745   absl::Status IsSupported(const TfLiteContext* context,
1746 const TfLiteNode* tflite_node,
1747 const TfLiteRegistration* registration) final {
1748 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1749 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1750 /*runtime_inputs=*/1, /*outputs=*/1));
1751 const TfLiteSoftmaxParams* tf_options;
1752 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1753 if (tf_options->beta != 1) {
1754       // TODO(eignasheva): figure out what's wrong with softmax.
1755 return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
1756 }
1757 return absl::OkStatus();
1758 }
1759
1760   absl::Status Parse(const TfLiteNode* tflite_node,
1761 const TfLiteRegistration* registration,
1762 GraphFloat32* graph, ObjectReader* reader) final {
1763 Node* node = graph->NewNode();
1764 node->operation.type = ToString(OperationType::SOFTMAX);
1765 RETURN_IF_ERROR(reader->AddInput(node, 0));
1766 RETURN_IF_ERROR(reader->AddOutputs(node));
1767
1768 const TfLiteSoftmaxParams* tf_options;
1769 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1770 if (tf_options->beta != 1) {
1771       // There is a multiply-by-scalar operation fused into softmax. Make a
1772       // layer out of it before the softmax.
1773 return absl::UnimplementedError("Softmax.beta != 1 is not supported.");
1774 // auto mul_node = reader->NewPassthroughNode(node);
1775 // mul_node->operation.type = ToString(OperationType::MUL);
1776 }
1777 SoftmaxAttributes attr;
1778 attr.axis = Axis::CHANNELS; // always by channels
1779 node->operation.attributes = attr;
1780 return absl::OkStatus();
1781 }
1782 };
1783
1784 class SpaceToDepthOperationParser : public TFLiteOperationParser {
1785 public:
1786   absl::Status IsSupported(const TfLiteContext* context,
1787 const TfLiteNode* tflite_node,
1788 const TfLiteRegistration* registration) final {
1789 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1790 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
1791 /*runtime_inputs=*/1, /*outputs=*/1));
1792 // TODO(impjdi): Dims check.
1793 const TfLiteSpaceToDepthParams* s2d_params;
1794 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &s2d_params));
1795 if (s2d_params->block_size == 1) {
1796 return absl::InvalidArgumentError(
1797 "SPACE_TO_DEPTH block_size = 1 is a no-op.");
1798 }
1799 if (s2d_params->block_size < 1) {
1800 return absl::InvalidArgumentError(
1801 "SPACE_TO_DEPTH block_size must be > 1.");
1802 }
1803 return absl::OkStatus();
1804 }
1805
1806   absl::Status Parse(const TfLiteNode* tflite_node,
1807 const TfLiteRegistration* registration,
1808 GraphFloat32* graph, ObjectReader* reader) final {
1809 Node* node = graph->NewNode();
1810 node->operation.type = ToString(OperationType::SPACE_TO_DEPTH);
1811 RETURN_IF_ERROR(reader->AddInput(node, 0));
1812 RETURN_IF_ERROR(reader->AddOutputs(node));
1813 const TfLiteSpaceToDepthParams* tf_options;
1814 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1815 SpaceToDepthAttributes attr;
1816 attr.block_size = tf_options->block_size;
1817 node->operation.attributes = attr;
1818 return absl::OkStatus();
1819 }
1820 };
1821
1822 class SplitVOperationParser : public TFLiteOperationParser {
1823 public:
1824   absl::Status IsSupported(const TfLiteContext* context,
1825 const TfLiteNode* tflite_node,
1826 const TfLiteRegistration* registration) final {
1827 const TfLiteSplitVParams* split_params;
1828 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &split_params));
1829 if (split_params->num_splits == 1) {
1830 return absl::InvalidArgumentError(
1831 "SplitV with num_splits = 1 is a no-op.");
1832 }
1833 return absl::OkStatus();
1834 }
1835
1836   absl::Status Parse(const TfLiteNode* tflite_node,
1837 const TfLiteRegistration* registration,
1838 GraphFloat32* graph, ObjectReader* reader) final {
1839 const TfLiteTensor* input = reader->GetInputTensor(0);
1840 const TfLiteTensor* axis_tensor = reader->GetInputTensor(2);
1841 SplitAttributes attr;
1842 RETURN_IF_ERROR(
1843 ExtractAxisFromIndex(*input, axis_tensor->data.i32[0], &attr.axis));
1844
1845 Node* node = graph->NewNode();
1846 node->operation.type = ToString(OperationType::SPLIT);
1847 node->operation.attributes = attr;
1848 RETURN_IF_ERROR(reader->AddInput(node, 0));
1849 for (int i = 0; i < tflite_node->outputs->size; ++i) {
1850 RETURN_IF_ERROR(reader->AddOutput(node, i));
1851 }
1852 return absl::OkStatus();
1853 }
1854 };
1855
1856 class StridedSliceOperationParser : public TFLiteOperationParser {
1857 public:
1858   absl::Status IsSupported(const TfLiteContext* context,
1859 const TfLiteNode* tflite_node,
1860 const TfLiteRegistration* registration) final {
1861 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
1862 const TfLiteStridedSliceParams* tf_options;
1863 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1864 RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
1865
1866 if (tflite_node->inputs->size < 4) {
1867 return absl::UnimplementedError("STRIDED_SLICE requires 4 inputs.");
1868 }
1869 const TfLiteTensor* input = GetInput(context, tflite_node, 0);
1870 if (input->dims->size != 3 && input->dims->size != 4) {
1871 return absl::UnimplementedError(
1872 "STRIDED_SLICE supports for 3 or 4 dimensional tensors only.");
1873 }
1874 return absl::OkStatus();
1875 }
1876
1877   absl::Status Parse(const TfLiteNode* tflite_node,
1878 const TfLiteRegistration* registration,
1879 GraphFloat32* graph, ObjectReader* reader) final {
1880 Node* node = graph->NewNode();
1881 node->operation.type = ToString(OperationType::SLICE);
1882 RETURN_IF_ERROR(reader->AddOutputs(node));
1883 Value* input;
1884 RETURN_IF_ERROR(reader->ReadValue(0, &input));
1885 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
1886
1887 Tensor<Linear, DataType::INT32> tmp;
1888 RETURN_IF_ERROR(reader->ReadTensor(1, &tmp));
1889
1890 bool read_without_batch = tmp.data.size() == 3;
1891 bool read_with_batch = tmp.data.size() == 4;
1892 if (!read_without_batch && !read_with_batch) {
1893       // Error: must be caught in IsSupported().
1894 return absl::UnimplementedError(
1895 "Slicing is supported for 3 or 4 dimensional tensors only.");
1896 }
1897
1898 const TfLiteStridedSliceParams* tf_options;
1899 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
1900 RETURN_IF_ERROR(CheckOptionsSupport(tf_options));
1901
1902 auto out_shape = graph->FindOutputs(node->id)[0]->tensor.shape;
1903
1904 SliceAttributes attr;
1905 if (read_without_batch) {
1906 RETURN_IF_ERROR(ReadAttribsWithoutBatch(reader, tf_options,
1907 input->tensor.shape, &attr));
1908 }
1909 if (read_with_batch) {
1910 RETURN_IF_ERROR(
1911 ReadAttribsWithBatch(reader, tf_options, input->tensor.shape, &attr));
1912 }
1913 if (attr.strides.b == 0 || attr.strides.h == 0 || attr.strides.w == 0 ||
1914 attr.strides.c == 0) {
1915 return absl::InvalidArgumentError("stride values must be non-zero");
1916 }
1917 if (attr.strides.b < 0 || attr.strides.h < 0 || attr.strides.w < 0 ||
1918 attr.strides.c < 0) {
1919 return absl::UnimplementedError("Reverse slices are not supported.");
1920 }
1921 if ((attr.ends.b - attr.starts.b + attr.strides.b - 1) / attr.strides.b !=
1922 out_shape.b) {
1923 return absl::UnimplementedError("Output batch don't match");
1924 }
1925 if ((attr.ends.h - attr.starts.h + attr.strides.h - 1) / attr.strides.h !=
1926 out_shape.h) {
1927 return absl::UnimplementedError("Output height doesn't match");
1928 }
1929 if ((attr.ends.w - attr.starts.w + attr.strides.w - 1) / attr.strides.w !=
1930 out_shape.w) {
1931 return absl::UnimplementedError("Output width doesn't match");
1932 }
1933 if ((attr.ends.c - attr.starts.c + attr.strides.c - 1) / attr.strides.c !=
1934 out_shape.c) {
1935 return absl::UnimplementedError("Output channels don't match");
1936 }
1937 node->operation.attributes = attr;
1938 return absl::OkStatus();
1939 }
1940
1941 private:
1942   absl::Status UpdateWithMask(const TfLiteStridedSliceParams* tf_options,
1943 const BHWC& input_shape, int ignore_b,
1944 int ignore_h, int ignore_w, int ignore_c,
1945 SliceAttributes* attr) {
1946 if (tf_options->begin_mask & ignore_h) {
1947 attr->starts.h = 0;
1948 }
1949 if (tf_options->begin_mask & ignore_w) {
1950 attr->starts.w = 0;
1951 }
1952 if (tf_options->begin_mask & ignore_c) {
1953 attr->starts.c = 0;
1954 }
1955 if (tf_options->begin_mask & ignore_b) {
1956 attr->starts.b = 0;
1957 }
1958
1959 if (tf_options->end_mask & ignore_h) {
1960 attr->ends.h = input_shape.h;
1961 }
1962 if (tf_options->end_mask & ignore_w) {
1963 attr->ends.w = input_shape.w;
1964 }
1965 if (tf_options->end_mask & ignore_c) {
1966 attr->ends.c = input_shape.c;
1967 }
1968 if (tf_options->end_mask & ignore_b) {
1969 attr->ends.b = input_shape.b;
1970 }
1971 return absl::OkStatus();
1972 }
1973
1974   absl::Status UpdateIfNegative(const BHWC& input_shape,
1975 SliceAttributes* attr) {
1976 if (attr->ends.h < 0) {
1977 attr->ends.h = input_shape.h + attr->ends.h;
1978 }
1979 if (attr->ends.w < 0) {
1980 attr->ends.w = input_shape.w + attr->ends.w;
1981 }
1982 if (attr->ends.c < 0) {
1983 attr->ends.c = input_shape.c + attr->ends.c;
1984 }
1985 if (attr->ends.b < 0) {
1986 attr->ends.b = input_shape.b + attr->ends.b;
1987 }
1988
1989 if (attr->starts.h < 0) {
1990 attr->starts.h = input_shape.h + attr->starts.h;
1991 }
1992 if (attr->starts.w < 0) {
1993 attr->starts.w = input_shape.w + attr->starts.w;
1994 }
1995 if (attr->starts.c < 0) {
1996 attr->starts.c = input_shape.c + attr->starts.c;
1997 }
1998 if (attr->starts.b < 0) {
1999 attr->starts.b = input_shape.b + attr->starts.b;
2000 }
2001
2002 return absl::OkStatus();
2003 }
2004
2005   absl::Status ReadAttribsWithBatch(const ObjectReader* reader,
2006 const TfLiteStridedSliceParams* tf_options,
2007 const BHWC& input_shape,
2008 SliceAttributes* attr) {
2009 auto read_bhwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2010 Tensor<Linear, DataType::INT32> t;
2011 RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2012 *bhwc = BHWC(t.data[0], t.data[1], t.data[2], t.data[3]);
2013 return absl::OkStatus();
2014 };
2015
2016 RETURN_IF_ERROR(read_bhwc(1, &attr->starts));
2017 RETURN_IF_ERROR(read_bhwc(2, &attr->ends));
2018 RETURN_IF_ERROR(read_bhwc(3, &attr->strides));
2019 RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
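    // With a 4D (BHWC) input, the begin/end mask bits follow the tensor
    // dimension order: bit 1 -> B, 2 -> H, 4 -> W, 8 -> C.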
2020 RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 1, 2, 4, 8, attr));
2021 return absl::OkStatus();
2022 }
2023
2024   absl::Status ReadAttribsWithoutBatch(
2025 const ObjectReader* reader, const TfLiteStridedSliceParams* tf_options,
2026 const BHWC& input_shape, SliceAttributes* attr) {
2027 auto read_hwc = [&](int tensor_index, BHWC* bhwc) -> absl::Status {
2028 Tensor<Linear, DataType::INT32> t;
2029 RETURN_IF_ERROR(reader->ReadTensor(tensor_index, &t));
2030 *bhwc = BHWC(0, t.data[0], t.data[1], t.data[2]);
2031 return absl::OkStatus();
2032 };
2033
2034 RETURN_IF_ERROR(read_hwc(1, &attr->starts));
2035 RETURN_IF_ERROR(read_hwc(2, &attr->ends));
2036 RETURN_IF_ERROR(read_hwc(3, &attr->strides));
2037 RETURN_IF_ERROR(UpdateIfNegative(input_shape, attr));
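    // Without a batch dimension, mask bits 1, 2 and 4 apply to the H, W and C
    // fields filled above; the batch bit is disabled (0).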
2038 RETURN_IF_ERROR(UpdateWithMask(tf_options, input_shape, 0, 1, 2, 4, attr));
2039 attr->starts.b = 0;
2040 attr->ends.b = input_shape.b;
2041 attr->strides.b = 1;
2042 return absl::OkStatus();
2043 }
2044   absl::Status CheckOptionsSupport(const TfLiteStridedSliceParams* tf_options) {
2045 if (tf_options->ellipsis_mask) {
2046 return absl::UnimplementedError("Slice does not support ellipsis_mask.");
2047 }
2048 if (tf_options->new_axis_mask) {
2049 return absl::UnimplementedError("Slice does not support new_axis_mask.");
2050 }
2051 if (tf_options->shrink_axis_mask) {
2052 return absl::UnimplementedError(
2053 "Slice does not support shrink_axis_mask parameter. ");
2054 }
2055 return absl::OkStatus();
2056 }
2057 };
2058
2059 // Builtin op version of TRANSPOSE_CONV.
2060 class TransposeConvBuiltinOperationParser : public TFLiteOperationParser {
2061 public:
2062   absl::Status IsSupported(const TfLiteContext* context,
2063 const TfLiteNode* tflite_node,
2064 const TfLiteRegistration* registration) final {
2065 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 3));
2066 const int runtime_inputs =
2067 GetNumberOfRuntimeInputsForNode(context, tflite_node);
2068 if (runtime_inputs > 2) {
2069 return absl::InternalError(
2070 absl::StrCat("Expected 1 or 2 input tensor(s), but node has ",
2071 runtime_inputs, " runtime inputs."));
2072 }
2073 const int runtime_outputs = NumOutputs(tflite_node);
2074 if (runtime_outputs != 1) {
2075 return absl::InternalError(
2076 absl::StrCat("Expected 1 output tensor(s), but node has ",
2077 runtime_outputs, " runtime outputs."));
2078 }
2079 if (runtime_inputs == 1) {
2080 RETURN_IF_ERROR(CheckTensorIsAvailable(context, tflite_node, 1));
2081 }
2082 const TfLiteTransposeConvParams* tf_options;
2083 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2084 RETURN_IF_ERROR(
2085 CheckStrides(tf_options->stride_height, tf_options->stride_width));
2086 return absl::OkStatus();
2087 }
2088
2089 // TFLite's TRANSPOSE_CONV expects 3-4 input tensors (output shape, weights,
2090 // input, and an optional bias) and allows configurable padding & stride.
2091 // TODO(impjdi): Translate output_shape to attr.adjacent.
2092   absl::Status Parse(const TfLiteNode* tflite_node,
2093 const TfLiteRegistration* registration,
2094 GraphFloat32* graph, ObjectReader* reader) final {
2095 auto* node = graph->NewNode();
2096 node->operation.type = ToString(OperationType::CONVOLUTION_TRANSPOSED);
2097 Value* input;
2098 RETURN_IF_ERROR(reader->ReadValue(2, &input));
2099 RETURN_IF_ERROR(graph->AddConsumer(node->id, input->id));
2100 RETURN_IF_ERROR(reader->AddOutputs(node));
2101
2102 const TfLiteTransposeConvParams* tf_options;
2103 RETURN_IF_ERROR(RetrieveBuiltinData(tflite_node, &tf_options));
2104
2105 ConvolutionTransposedAttributes attr;
2106 attr.stride = tf_options
2107 ? HW(tf_options->stride_height, tf_options->stride_width)
2108 : HW(1, 1);
2109 const int runtime_inputs = reader->GetNumberOfRuntimeInputs();
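    // With 2 runtime inputs the weights arrive as a runtime tensor (input 1);
    // otherwise they are read as a constant tensor from the model.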
2110 if (runtime_inputs == 2) {
2111 RETURN_IF_ERROR(reader->AddInput(node, 1));
2112 auto weights_shape = graph->FindInputs(node->id)[1]->tensor.shape;
2113 attr.weights.shape = OHWI(weights_shape.b, weights_shape.h,
2114 weights_shape.w, weights_shape.c);
2115 } else { // runtime_inputs == 1;
2116 RETURN_IF_ERROR(reader->ReadTensor(1, &attr.weights));
2117 }
2118 reader->ReadTensor(3, &attr.bias).IgnoreError(); // bias is optional
2119
2120 UpdatePadding(tf_options->padding,
2121 graph->FindInputs(node->id)[0]->tensor.shape, &attr);
2122 node->operation.attributes = std::move(attr);
2123 return absl::OkStatus();
2124 }
2125 };
2126
2127 class TransposeOperationParser : public TFLiteOperationParser {
2128 public:
2129   absl::Status IsSupported(const TfLiteContext* context,
2130 const TfLiteNode* tflite_node,
2131 const TfLiteRegistration* registration) final {
2132 RETURN_IF_ERROR(CheckMaxSupportedOpVersion(registration, 2));
2133 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
2134 /*runtime_inputs=*/1, /*outputs=*/1));
2135 return absl::OkStatus();
2136 }
2137
2138   absl::Status Parse(const TfLiteNode* tflite_node,
2139 const TfLiteRegistration* registration,
2140 GraphFloat32* graph, ObjectReader* reader) final {
2141 Node* node = graph->NewNode();
2142 node->operation.type = ToString(OperationType::TRANSPOSE);
2143 RETURN_IF_ERROR(reader->AddInput(node, 0));
2144 RETURN_IF_ERROR(reader->AddOutputs(node));
2145
2146 TransposeAttributes attr;
2147 Tensor<Linear, DataType::INT32> perm;
2148 RETURN_IF_ERROR(reader->ReadTensor(1, &perm));
2149 std::map<Axis, int> axis_to_index = {{Axis::BATCH, 0},
2150 {Axis::HEIGHT, 1},
2151 {Axis::WIDTH, 2},
2152 {Axis::CHANNELS, 3}};
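    // For inputs with fewer than 4 dimensions, the permutation refers to the
    // reduced layout (B, W, C for 3D; B, C for 2D) and is remapped onto a full
    // BHWC permutation with the missing axes kept in place.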
2153 if (perm.data.size() == 4) {
2154 attr.perm = BHWC(perm.data[0], perm.data[1], perm.data[2], perm.data[3]);
2155 } else if (perm.data.size() == 3) {
2156 std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::WIDTH,
2157 Axis::CHANNELS};
2158 attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2159 attr.perm.h = 1;
2160 attr.perm.w = axis_to_index[index_to_axis[perm.data[1]]];
2161 attr.perm.c = axis_to_index[index_to_axis[perm.data[2]]];
2162 } else if (perm.data.size() == 2) {
2163 std::vector<Axis> index_to_axis = {Axis::BATCH, Axis::CHANNELS};
2164 attr.perm.b = axis_to_index[index_to_axis[perm.data[0]]];
2165 attr.perm.h = 1;
2166 attr.perm.w = 2;
2167 attr.perm.c = axis_to_index[index_to_axis[perm.data[1]]];
2168 } else {
2169 return absl::InvalidArgumentError(
2170 "Permutation for transpose is invalid.");
2171 }
2172
2173 node->operation.attributes = attr;
2174 return absl::OkStatus();
2175 }
2176 };
2177
2178 // TODO(impjdi): BATCH_TO_SPACE/SPACE_TO_BATCH shouldn't be supported.
2179 class BatchToSpaceOperationParser : public TFLiteOperationParser {
2180 public:
2181   absl::Status IsSupported(const TfLiteContext* context,
2182 const TfLiteNode* tflite_node,
2183 const TfLiteRegistration* registration) final {
2184 return absl::OkStatus();
2185 }
2186
2187   absl::Status Parse(const TfLiteNode* tflite_node,
2188 const TfLiteRegistration* registration,
2189 GraphFloat32* graph, ObjectReader* reader) final {
2190 auto* node = graph->NewNode();
2191 node->operation.type = ToString(OperationType::BATCH_TO_SPACE);
2192 RETURN_IF_ERROR(reader->AddInput(node, 0));
2193 RETURN_IF_ERROR(reader->AddOutputs(node));
2194
2195 BatchToSpaceAttributes bs_attr;
2196 Tensor<Linear, DataType::INT32> block;
2197 RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2198 if (block.shape.v != 2) {
2199 return absl::InternalError("Space has to be HxW.");
2200 }
2201 bs_attr.block.h = block.data[0];
2202 bs_attr.block.w = block.data[1];
2203
2204 Tensor<HW, DataType::INT32> crop;
2205 RETURN_IF_ERROR(reader->ReadTensor(2, &crop));
2206 auto crop_shape = crop.shape;
2207 if (crop_shape.h != 2 && crop_shape.w != 2) {
2208 return absl::InternalError("Space has to be HxW.");
2209 }
2210
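    // crop is a 2x2 tensor laid out as [[h_before, h_after], [w_before, w_after]].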
2211 bs_attr.crop.prepended.h = crop.data[0];
2212 bs_attr.crop.prepended.w = crop.data[2];
2213
2214 bs_attr.crop.appended.h = crop.data[1];
2215 bs_attr.crop.appended.w = crop.data[3];
2216
2217 node->operation.attributes = std::move(bs_attr);
2218 return absl::OkStatus();
2219 }
2220 };
2221
2222 class SpaceToBatchOperationParser : public TFLiteOperationParser {
2223 public:
2224   absl::Status IsSupported(const TfLiteContext* context,
2225 const TfLiteNode* tflite_node,
2226 const TfLiteRegistration* registration) final {
2227 return absl::OkStatus();
2228 }
2229
2230   absl::Status Parse(const TfLiteNode* tflite_node,
2231 const TfLiteRegistration* registration,
2232 GraphFloat32* graph, ObjectReader* reader) final {
2233 auto* node = graph->NewNode();
2234 node->operation.type = ToString(OperationType::SPACE_TO_BATCH);
2235 RETURN_IF_ERROR(reader->AddInput(node, 0));
2236 RETURN_IF_ERROR(reader->AddOutputs(node));
2237 SpaceToBatchAttributes sb_attr;
2238 Tensor<Linear, DataType::INT32> block;
2239 RETURN_IF_ERROR(reader->ReadTensor(1, &block));
2240 if (block.shape.v != 2) {
2241 return absl::InternalError("Space has to be HxW.");
2242 }
2243 sb_attr.block.h = block.data[0];
2244 sb_attr.block.w = block.data[1];
2245
2246 Tensor<HW, DataType::INT32> padding;
2247 RETURN_IF_ERROR(reader->ReadTensor(2, &padding));
2248 auto padding_shape = padding.shape;
2249
2250 if (padding_shape.h != 2 && padding_shape.w != 2) {
2251 return absl::InternalError("Space has to be HxW.");
2252 }
2253
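    // padding is a 2x2 tensor laid out as [[h_before, h_after], [w_before, w_after]].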
2254 sb_attr.padding.prepended.h = padding.data[0];
2255 sb_attr.padding.prepended.w = padding.data[2];
2256
2257 sb_attr.padding.appended.h = padding.data[1];
2258 sb_attr.padding.appended.w = padding.data[3];
2259
2260 node->operation.attributes = std::move(sb_attr);
2261 return absl::OkStatus();
2262 }
2263 };
2264
2265 class MeanOperationParser : public TFLiteOperationParser {
2266 public:
2267   absl::Status IsSupported(const TfLiteContext* context,
2268 const TfLiteNode* tflite_node,
2269 const TfLiteRegistration* registration) final {
2270 RETURN_IF_ERROR(CheckInputsOutputs(context, tflite_node,
2271 /*runtime_inputs=*/1,
2272 /*outputs=*/1));
2273
2274 auto* axes = &context->tensors[tflite_node->inputs->data[1]];
2275 if (axes->allocation_type != kTfLiteMmapRo || axes->type != kTfLiteInt32) {
2276 return absl::UnimplementedError("Mean has unsupported tensor for axes");
2277 }
2278 return absl::OkStatus();
2279 }
2280
2281   absl::Status Parse(const TfLiteNode* tflite_node,
2282 const TfLiteRegistration* registration,
2283 GraphFloat32* graph, ObjectReader* reader) final {
2284 auto* node = graph->NewNode();
2285 node->operation.type = ToString(OperationType::MEAN);
2286 RETURN_IF_ERROR(reader->AddInput(node, 0));
2287 RETURN_IF_ERROR(reader->AddOutputs(node));
2288
2289 MeanAttributes attr;
2290 const TfLiteTensor* input = reader->GetInputTensor(0);
2291 const TfLiteTensor* axes = reader->GetInputTensor(1);
2292 for (int i = 0; i < NumElements(axes->dims); i++) {
2293 Axis axis;
2294 RETURN_IF_ERROR(ExtractAxisFromIndex(*input, axes->data.i32[i], &axis));
2295 attr.dims.insert(axis);
2296 }
2297 node->operation.attributes = attr;
2298 return absl::OkStatus();
2299 }
2300 };
2301
2302 class UnsupportedOperationParser : public TFLiteOperationParser {
2303 public:
2304   absl::Status IsSupported(const TfLiteContext* context,
2305 const TfLiteNode* tflite_node,
2306 const TfLiteRegistration* registration) final {
2307 return absl::UnimplementedError("Operation is not supported.");
2308 }
2309
2310   absl::Status Parse(const TfLiteNode* tflite_node,
2311 const TfLiteRegistration* registration,
2312 GraphFloat32* graph, ObjectReader* reader) final {
2313 return absl::UnimplementedError("Operation is not supported.");
2314 }
2315 };
2316
2317 std::unique_ptr<TFLiteOperationParser> NewOperationParser(
2318 const TfLiteRegistration* registration, bool allow_quant_ops = false) {
2319 const auto builtin_code = registration->builtin_code;
2320 switch (builtin_code) {
2321 case kTfLiteBuiltinAbs:
2322 return std::make_unique<ElementwiseOperationParser>(OperationType::ABS);
2323 case kTfLiteBuiltinAdd:
2324 return std::make_unique<AddOperationParser>();
2325 case kTfLiteBuiltinAveragePool2d:
2326 return std::make_unique<Pooling2DOperationParser>(PoolingType::AVERAGE);
2327 case kTfLiteBuiltinBatchMatmul:
2328 return std::make_unique<BatchedMatMulOperationParser>();
2329 case kTfLiteBuiltinConcatenation:
2330 return std::make_unique<ConcatenationOperationParser>();
2331 case kTfLiteBuiltinConv2d:
2332 return std::make_unique<Conv2DOperationParser>();
2333 case kTfLiteBuiltinCos:
2334 return std::make_unique<ElementwiseOperationParser>(OperationType::COS);
2335 case kTfLiteBuiltinDepthwiseConv2d:
2336 return std::make_unique<DepthwiseConvolutionOperationParser>();
2337 case kTfLiteBuiltinDequantize:
2338 if (allow_quant_ops) {
2339 return std::make_unique<DequantizeOperationParser>();
2340 }
2341 break;
2342 case kTfLiteBuiltinDiv:
2343 return std::make_unique<ElementwiseOperationParser>(OperationType::DIV);
2344 case kTfLiteBuiltinElu:
2345 return std::make_unique<ElementwiseOperationParser>(OperationType::ELU);
2346 case kTfLiteBuiltinExp:
2347 return std::make_unique<ElementwiseOperationParser>(OperationType::EXP);
2348 case kTfLiteBuiltinFullyConnected:
2349 return std::make_unique<FullyConnectedOperationParser>();
2350 case kTfLiteBuiltinHardSwish:
2351 return std::make_unique<HardSwishOperationParser>();
2352 case kTfLiteBuiltinLogistic:
2353 return std::make_unique<ElementwiseOperationParser>(
2354 OperationType::SIGMOID);
2355 case kTfLiteBuiltinLog:
2356 return std::make_unique<ElementwiseOperationParser>(OperationType::LOG);
2357 case kTfLiteBuiltinLstm:
2358 return std::make_unique<LSTMOperationParser>();
2359 case kTfLiteBuiltinMaximum:
2360 return std::make_unique<ElementwiseOperationParser>(
2361 OperationType::MAXIMUM);
2362 case kTfLiteBuiltinMaxPool2d:
2363 return std::make_unique<Pooling2DOperationParser>(PoolingType::MAX);
2364 case kTfLiteBuiltinMean:
2365 return std::make_unique<MeanOperationParser>();
2366 case kTfLiteBuiltinMinimum:
2367 return std::make_unique<ElementwiseOperationParser>(
2368 OperationType::MINIMUM);
2369 case kTfLiteBuiltinMirrorPad:
2370 return std::make_unique<PadOperationParser>(/*mirror_pad=*/true);
2371 case kTfLiteBuiltinMul:
2372 return std::make_unique<MulOperationParser>();
2373 case kTfLiteBuiltinNeg:
2374 return std::make_unique<ElementwiseOperationParser>(OperationType::NEG);
2375 case kTfLiteBuiltinPack:
2376 return std::make_unique<PackOperationParser>();
2377 case kTfLiteBuiltinPad:
2378 return std::make_unique<PadOperationParser>(/*mirror_pad=*/false);
2379 case kTfLiteBuiltinPow:
2380 return std::make_unique<ElementwiseOperationParser>(OperationType::POW);
2381 case kTfLiteBuiltinReduceMax:
2382 return std::make_unique<ReduceOperationParser>(
2383 OperationType::REDUCE_MAXIMUM);
2384 case kTfLiteBuiltinReduceMin:
2385 return std::make_unique<ReduceOperationParser>(
2386 OperationType::REDUCE_MINIMUM);
2387 case kTfLiteBuiltinReduceProd:
2388 return std::make_unique<ReduceOperationParser>(
2389 OperationType::REDUCE_PRODUCT);
2390 case kTfLiteBuiltinQuantize:
2391 if (allow_quant_ops) {
2392 return std::make_unique<QuantizeOperationParser>();
2393 }
2394 break;
2395 case kTfLiteBuiltinRelu:
2396 return std::make_unique<ReLUOperationParser>(0);
2397 case kTfLiteBuiltinRelu6:
2398 return std::make_unique<ReLUOperationParser>(6);
2399 case kTfLiteBuiltinLeakyRelu:
2400 return std::make_unique<ReLUOperationParser>(0);
2401 case kTfLiteBuiltinPrelu:
2402 return std::make_unique<PReLUOperationParser>();
2403 case kTfLiteBuiltinReshape:
2404 return std::make_unique<ReshapeOperationParser>();
2405 case kTfLiteBuiltinResizeBilinear:
2406 return std::make_unique<Resize2DOperationParser>(SamplingType::BILINEAR);
2407 case kTfLiteBuiltinResizeNearestNeighbor:
2408 return std::make_unique<Resize2DOperationParser>(SamplingType::NEAREST);
2409 case kTfLiteBuiltinRsqrt:
2410 return std::make_unique<ElementwiseOperationParser>(OperationType::RSQRT);
2411 case kTfLiteBuiltinSin:
2412 return std::make_unique<ElementwiseOperationParser>(OperationType::SIN);
2413 case kTfLiteBuiltinSlice:
2414 return std::make_unique<SliceOperationParser>();
2415 case kTfLiteBuiltinSoftmax:
2416 return std::make_unique<SoftmaxOperationParser>();
2417 case kTfLiteBuiltinSpaceToDepth:
2418 return std::make_unique<SpaceToDepthOperationParser>();
2419 case kTfLiteBuiltinSplitV:
2420 return std::make_unique<SplitVOperationParser>();
2421 case kTfLiteBuiltinSqrt:
2422 return std::make_unique<ElementwiseOperationParser>(OperationType::SQRT);
2423 case kTfLiteBuiltinSquare:
2424 return std::make_unique<ElementwiseOperationParser>(
2425 OperationType::SQUARE);
2426 case kTfLiteBuiltinSquaredDifference:
2427 return std::make_unique<ElementwiseOperationParser>(
2428 OperationType::SQUARED_DIFF);
2429 case kTfLiteBuiltinStridedSlice:
2430 return std::make_unique<StridedSliceOperationParser>();
2431 case kTfLiteBuiltinSub:
2432 return std::make_unique<ElementwiseOperationParser>(OperationType::SUB);
2433 case kTfLiteBuiltinSum:
2434 return std::make_unique<ReduceOperationParser>(OperationType::REDUCE_SUM);
2435 case kTfLiteBuiltinTanh:
2436 return std::make_unique<ElementwiseOperationParser>(OperationType::TANH);
2437 case kTfLiteBuiltinTranspose:
2438 return std::make_unique<TransposeOperationParser>();
2439 case kTfLiteBuiltinTransposeConv:
2440 return std::make_unique<TransposeConvBuiltinOperationParser>();
2441 case kTfLiteBuiltinCustom:
2442 return NewCustomOperationParser(registration->custom_name);
2443 }
2444 return std::make_unique<UnsupportedOperationParser>();
2445 }
2446
2447 absl::Status IsSupported(const TfLiteContext* context, TfLiteNode* node,
2448 const TfLiteRegistration* registration,
2449 bool allow_quant_ops = false) {
2450 return NewOperationParser(registration, allow_quant_ops)
2451 ->IsSupported(context, node, registration);
2452 }
2453
2454 bool IsAllAllowedTensors(TfLiteContext* context,
2455 const TfLiteIntArray* tensor_indices,
2456 bool allow_quant_ops = false) {
2457 for (int i = 0; i < tensor_indices->size; ++i) {
2458 int tensor_idx = tensor_indices->data[i];
2459 if (tensor_idx == kTfLiteOptionalTensor) continue;
2460 const TfLiteTensor* t = &context->tensors[tensor_idx];
2461 bool type_supported =
2462 (t->type == kTfLiteFloat32 || t->type == kTfLiteFloat16);
2463 if (allow_quant_ops) {
2464 // Since we only check non-constant tensors, type cannot be Int32.
2465 type_supported =
2466 type_supported || t->type == kTfLiteInt8 || t->type == kTfLiteUInt8;
2467 }
2468 if (t->allocation_type == kTfLiteArenaRw && !type_supported) {
2469 return false;
2470 }
2471 }
2472 return true;
2473 }
2474 } // namespace
2475
2476 // TODO(impjdi): Check number of input/output tensors and their dimensions.
2477 // TODO(impjdi): Check ops' parameters.
2478 TfLiteIntArray* GetOpsToReplace(TfLiteContext* context, bool allow_quant_ops,
2479 int max_delegated_partitions) {
2480 delegates::IsNodeSupportedFn node_supported_fn =
2481 [=](TfLiteContext* context, TfLiteNode* node,
2482 TfLiteRegistration* registration,
2483 std::string* unsupported_details) -> bool {
2484 const auto status =
2485 IsSupported(context, node, registration, allow_quant_ops);
2486 if (!status.ok()) {
2487 if (unsupported_details) {
2488 *unsupported_details = std::string(status.message());
2489 }
2490 return false;
2491 }
2492
2493 if (!IsAllAllowedTensors(context, node->inputs, allow_quant_ops) ||
2494 !IsAllAllowedTensors(context, node->outputs, allow_quant_ops)) {
2495 if (unsupported_details) {
2496 *unsupported_details =
2497 "OP is supported, but tensor type isn't matched!";
2498 }
2499 return false;
2500 }
2501 return true;
2502 };
2503
2504 delegates::FP16GraphPartitionHelper partition_helper(context,
2505 node_supported_fn);
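  // FP16GraphPartitionHelper is used so that fp16 Dequantize nodes can be
  // grouped with the delegated ops that consume them; BuildModel() later skips
  // those Dequantize nodes explicitly.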
2506 std::set<std::string> unsupported_nodes_info;
2507 if (partition_helper.Partition(&unsupported_nodes_info) != kTfLiteOk) {
2508 return TfLiteIntArrayCreate(0);
2509 }
2510
2511   // By default, we simply take the largest partition, since
2512   // 'max_delegated_partitions' is set to 1 by default.
2513 std::vector<int> ops_to_replace =
2514 partition_helper.GetNodesOfFirstNLargestPartitions(
2515 max_delegated_partitions);
2516
2517 if (!unsupported_nodes_info.empty()) {
2518 std::string unsupported = absl::StrJoin(unsupported_nodes_info, "\n");
2519 std::string error_message = absl::StrCat(
2520 "Following operations are not supported by GPU delegate:\n",
2521 unsupported, "\n");
2522 if (!ops_to_replace.empty()) {
2523 absl::StrAppend(
2524 &error_message, ops_to_replace.size(),
2525 " operations will run on the GPU, and the remaining ",
2526 partition_helper.num_total_nodes() - ops_to_replace.size());
2527 } else {
2528 absl::StrAppend(&error_message,
2529 "No operations will run on the GPU, and all ",
2530 partition_helper.num_total_nodes());
2531 }
2532 absl::StrAppend(&error_message, " operations will run on the CPU.");
2533 TF_LITE_KERNEL_LOG(context, error_message.c_str());
2534 }
2535 return ConvertVectorToTfLiteIntArray(ops_to_replace);
2536 }
2537
2538 // Creates the inputs and outputs passed via the io_tensors parameter in the
2539 // resulting graph. We force this to ensure the delegated subgraph has the same
2540 // order of inputs and outputs as the original one. When the delegated model is
2541 // built from the TFLite model representation, tensors are created lazily, so
2542 // there is no guarantee that the order will match the source model's order.
2543 absl::Status PrecreateIOTensors(
2544 TfLiteContext* context, GraphFloat32* graph, TfLiteIntArray* io_tensors,
2545 absl::flat_hash_map<int, int>* quant_conversion_map,
2546 absl::flat_hash_map<int, Value*>* tensor_to_value) {
2547 for (int i = 0; i < io_tensors->size; ++i) {
2548 const int tensor_index = io_tensors->data[i];
2549 const TfLiteTensor& tflite_tensor = context->tensors[tensor_index];
2550 if (tflite::IsConstantTensor(&tflite_tensor)) continue;
2551 RETURN_IF_ERROR(ObjectReader::ReadNonConstantTensor(
2552 context, tensor_to_value, quant_conversion_map, graph, tensor_index));
2553 }
2554 return absl::OkStatus();
2555 }
2556
2557 absl::Status CopyVariableTensorOutputs(
2558 TfLiteNode* tflite_node, TfLiteRegistration* registration,
2559 GraphFloat32* graph, ObjectReader& reader,
2560 const absl::flat_hash_map<int, ValueId>& new_variable_tensor_values) {
2561 absl::flat_hash_map<int, ValueId> new_variable_tensor_values_copy(
2562 new_variable_tensor_values);
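  // Work on a copy so that entries can be erased as they are matched to
  // variable input tensors; any leftover entries are reported as an error
  // below.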
2563 // Retrieve the final value id for the variable input tensors.
2564 for (int i = 0; i < tflite_node->inputs->size; i++) {
2565 int tensor_idx = tflite_node->inputs->data[i];
2566 Value* value;
2567 if (!reader.ReadValueByTensorIdx(tensor_idx, &value).ok()) continue;
2568 if (value->tensor.is_variable_input) {
2569 if (new_variable_tensor_values_copy.find(i) ==
2570 new_variable_tensor_values_copy.end()) {
2571 return absl::InvalidArgumentError(
2572 absl::StrCat(GetOpNameByRegistration(*registration),
2573 " did not provide a new value for the variable input "
2574 "tensor with index ",
2575 tensor_idx));
2576 } else {
2577 Node* node = graph->NewNode();
2578 node->operation.type = ToString(OperationType::COPY);
2579 RETURN_IF_ERROR(graph->AddConsumer(
2580 node->id, new_variable_tensor_values_copy.at(i)));
2581 RETURN_IF_ERROR(reader.AddUpdate(node, i));
2582 new_variable_tensor_values_copy.erase(
2583 new_variable_tensor_values_copy.find(i));
2584 }
2585 }
2586 }
2587 if (!new_variable_tensor_values_copy.empty()) {
2588 return absl::InvalidArgumentError(
2589 "More input variable tensors asked to be copied than present on the "
2590 "node");
2591 }
2592 return absl::OkStatus();
2593 }
2594
2595 absl::Status BuildModel(TfLiteContext* context,
2596 const TfLiteDelegateParams* delegate_params,
2597 GraphFloat32* graph,
2598 absl::flat_hash_map<int, int>* quant_conversion_map) {
2599 std::vector<std::unique_ptr<TFLiteOperationParser>> operations;
2600 std::vector<int> tflite_nodes;
2601 for (int i = 0; i < delegate_params->nodes_to_replace->size; ++i) {
2602 TfLiteNode* tflite_node = nullptr;
2603 TfLiteRegistration* registration = nullptr;
2604 RETURN_IF_ERROR(GetNodeAndRegistration(
2605 context, delegate_params->nodes_to_replace->data[i], &tflite_node,
2606         &registration));
2607 if (registration->builtin_code == kTfLiteBuiltinDequantize &&
2608 context->tensors[tflite_node->inputs->data[0]].type ==
2609 TfLiteType::kTfLiteFloat16) {
2610 // Ignore Fp16 Dequantize nodes.
2611 continue;
2612 }
2613 auto op_parser = NewOperationParser(
2614 registration, /*allow_quant_ops=*/quant_conversion_map != nullptr);
2615 if (!op_parser) {
2616 return absl::UnimplementedError(
2617 absl::StrCat("Operation ", registration->builtin_code, "(",
2618 registration->custom_name,
2619 ") is not supported by TFLite GPU Delegate."));
2620 }
2621 operations.push_back(std::move(op_parser));
2622 tflite_nodes.push_back(i);
2623 }
2624 absl::flat_hash_map<int, Value*> tensor_to_value;
2625 std::vector<ValueId> variable_inputs_to_value_id;
2626 RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
2627 delegate_params->input_tensors,
2628 quant_conversion_map, &tensor_to_value));
2629 RETURN_IF_ERROR(PrecreateIOTensors(context, graph,
2630 delegate_params->output_tensors,
2631 quant_conversion_map, &tensor_to_value));
2632 for (int i = 0; i < operations.size(); ++i) {
2633 TfLiteNode* tflite_node;
2634 TfLiteRegistration* registration;
2635 RETURN_IF_ERROR(GetNodeAndRegistration(
2636 context, delegate_params->nodes_to_replace->data[tflite_nodes[i]],
2637         &tflite_node, &registration));
2638 ObjectReader reader(graph, context, tflite_node, &tensor_to_value,
2639 quant_conversion_map);
2640 const auto status =
2641 operations[i]->Parse(tflite_node, registration, graph, &reader);
2642 if (!status.ok()) {
2643 return absl::InternalError(absl::StrCat(
2644 GetOpNameByRegistration(*registration), ": ", status.message()));
2645 }
2646
2647 absl::flat_hash_map<int, ValueId> new_value_for_variable_input_tensors =
2648 operations[i]->GetNewValueIdsForVariableInputNodes();
2649
2650 RETURN_IF_ERROR(
2651 CopyVariableTensorOutputs(tflite_node, registration, graph, reader,
2652 new_value_for_variable_input_tensors));
2653 }
2654
2655   // Variable input tensors are expected to be unchanged throughout model
2656   // execution. They need to be outputs of the graph in order to stay unchanged.
2657 for (auto value_id : variable_inputs_to_value_id) {
2658 if (!graph->IsGraphOutput(value_id)) {
2659 return absl::InvalidArgumentError(
2660 absl::StrCat("Variable input tensors must be a graph output. Value ",
2661 value_id, " is not a graph output"));
2662 }
2663 }
2664 return absl::OkStatus();
2665 }
2666
2667 absl::Status BuildFinalModel(
2668 TfLiteContext* context, const TfLiteDelegateParams* delegate_params,
2669 GraphFloat32* graph, absl::flat_hash_map<int, int>* quant_conversion_map) {
2670 RETURN_IF_ERROR(
2671 BuildModel(context, delegate_params, graph, quant_conversion_map));
2672
2673 // Apply general transformations on the graph.
2674 NullTransformationReporter reporter;
2675 ModelTransformer transformer(graph, &reporter);
2676 if (!ApplyModelTransformations(&transformer)) {
2677 return absl::InternalError("Graph transformations failed");
2678 }
2679 return absl::OkStatus();
2680 }
2681
2682 } // namespace gpu
2683 } // namespace tflite
2684