/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/versioning/op_version.h"

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

// #include "absl/memory/memory.h"
// #include "absl/strings/numbers.h"
// #include "absl/strings/str_split.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/lite/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"

namespace tflite {
namespace {

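// Returns true when a binary op's first two inputs have different shapes, so
// the kernel has to broadcast one input to match the other.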
bool NeedBroadcastForBinaryInputs(const OpSignature& op_sig) {
  if (op_sig.inputs.size() < 2) {
    return false;
  }
  return (op_sig.inputs.at(0).dims != op_sig.inputs.at(1).dims);
}

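// Returns the maximum rank (number of dimensions) across all of the op's
// input tensors.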
int GetInputMaxDims(const OpSignature& op_sig) {
  int max_dims = 0;
  for (auto& input : op_sig.inputs) {
    if (input.dims.size() > max_dims) {
      max_dims = input.dims.size();
    }
  }
  return max_dims;
}

}  // namespace

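// Returns the minimum operator version required to run the operator described
// by `op_sig`, based on its input/output tensor types, builtin parameters, and
// the extended options gathered in OpSignature.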
int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
  switch (op_sig.op) {
    case BuiltinOperator_CONV_2D:
      if (op_sig.ext_options.conv_2d.is_grouped_convolution) {
        return 6;
      }
      // If the op has signed int16 inputs and outputs, the version is 4.
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.inputs.at(1).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 4;
      }

      // If the op has signed int8 inputs and outputs, the version is 3.
      if (op_sig.inputs.at(0).type == kTfLiteInt8 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteInt8) {
        return 3;
      }
      // If the op is a signed int8 hybrid operation, the version is 2, or 5
      // when the filter is per-channel quantized.
      if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        if (op_sig.ext_options.conv_2d.is_per_channel_quantized) {
          return 5;
        }
        return 2;
      }
      return 1;

    case BuiltinOperator_DEPTHWISE_CONV_2D: {
      // If the op accepts int16, we return version 5.
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.inputs.at(1).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 5;
      }

      // If the op is a signed int8 hybrid operation, the version is 4, or 6
      // when the filter is per-channel quantized.
      if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        if (op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized) {
          return 6;
        }
        return 4;
      }
      // If the op has signed int8 inputs and outputs, the version is 3.
      if (op_sig.inputs.at(0).type == kTfLiteInt8 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteInt8) {
        return 3;
      }
      auto depthwise_conv_params =
          reinterpret_cast<TfLiteDepthwiseConvParams*>(op_sig.builtin_data);
      TFLITE_DCHECK(depthwise_conv_params != nullptr);
      if (depthwise_conv_params->dilation_width_factor != 1 ||
          depthwise_conv_params->dilation_height_factor != 1) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_FAKE_QUANT: {
      auto fake_quant_params =
          reinterpret_cast<TfLiteFakeQuantParams*>(op_sig.builtin_data);
      TFLITE_DCHECK(fake_quant_params != nullptr);
      if (fake_quant_params->narrow_range) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_FULLY_CONNECTED: {
      // +-----------------+--------------------+--------------------------+
      // |                 | Weight::Default    | Weight::Shuffled4x16Int8 |
      // +-----------------+--------------------+--------------------------+
      // | Float           | 1                  | 2                        |
      // | Quantized Uint8 | 1                  | 2                        |
      // | Hybrid          | 3                  | 3                        |
      // | Quantized Int8  | 4                  | 4                        |
      // +-----------------+--------------------+--------------------------+

      // FullyConnected with sparse weight is supported at version 8.
      if (op_sig.ext_options.fully_connected.sparse_weight) {
        return 8;
      }

      // Int16 fully fixed point kernel is at version 7.
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.inputs.at(1).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 7;
      }

      // The 2-input (no bias) use case is supported starting from version 6.
      if (op_sig.inputs.size() == 2) {
        return 6;
      }
      auto fully_connected_params =
          reinterpret_cast<TfLiteFullyConnectedParams*>(op_sig.builtin_data);
      TFLITE_DCHECK(fully_connected_params != nullptr);
      // `keep_num_dims` is supported at version 5.
      if (fully_connected_params->keep_num_dims) {
        return 5;
      }
      // Int8 fully fixed point kernel is at version 4.
      if (op_sig.inputs.at(0).type == kTfLiteInt8 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteInt8) {
        return 4;
      }
      // If the op is a signed int8 hybrid operation, the version is 3, or 9
      // when it uses the updated (asymmetric) quantization scheme.
      if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        if (fully_connected_params->asymmetric_quantize_inputs) {
          // This is to use the updated quantization scheme.
          return 9;
        }
        return 3;
      }
      // For float and uint8 fixed point kernels, if the weights are
      // Shuffled4x16Int8, the version is 2.
      if (fully_connected_params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
        return 2;
      }
      // Otherwise (default weights format), the version is 1.
      return 1;
    }

    case BuiltinOperator_GATHER: {
      auto gather_params =
          reinterpret_cast<TfLiteGatherParams*>(op_sig.builtin_data);
      if (gather_params && gather_params->batch_dims != 0) {
        return 5;
      }

      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      // If the op takes bool input, it is version 3.
      if (op_sig.inputs.at(0).type == kTfLiteBool) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_SVDF: {
      // Fully integer SVDF has int8 as input and is of version 3.
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 3;
      }
      // If the op is a signed int8 hybrid operation, we need to return
      // version 2.
      if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto svdf_params =
            reinterpret_cast<TfLiteSVDFParams*>(op_sig.builtin_data);
        // This is to use the updated quantization scheme
        if (svdf_params && svdf_params->asymmetric_quantize_inputs) {
          return 4;
        }
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_MUL:
      // Version 6 supports complex64 inputs.
      if (op_sig.inputs.at(0).type == kTfLiteComplex64) {
        return 6;
      }
      // Version 5 supports int64 inputs.
      if (op_sig.inputs.at(0).type == kTfLiteInt64) {
        return 5;
      }
      // Version 4 supports int16 inputs.
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      // Version 3 supports a rescale value (input1_scale * input2_scale /
      // output_scale) greater than or equal to 1.
      if (op_sig.ext_options.mul.input1_scale != 0 &&
          op_sig.ext_options.mul.input2_scale != 0 &&
          op_sig.ext_options.mul.output_scale != 0 &&
          (op_sig.ext_options.mul.input1_scale *
               op_sig.ext_options.mul.input2_scale /
               op_sig.ext_options.mul.output_scale) >= 1.0) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

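    // Pooling ops: int16 inputs/outputs are version 3; int8 inputs are
    // version 2.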
    case BuiltinOperator_MAX_POOL_2D:
    case BuiltinOperator_AVERAGE_POOL_2D:
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 3;
      }

      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_TRANSPOSE:
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 5;
      }
      if (op_sig.inputs.at(0).dims.size() > 4) {
        return 4;
      }
      // If the op takes bool input, it is version 3.
      if (op_sig.inputs.at(0).type == kTfLiteBool) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_TRANSPOSE_CONV: {
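      // Version 3 supports an explicit (non-null) bias, passed as a fourth
      // input tensor.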
      if (op_sig.inputs.size() == 4 &&
          op_sig.inputs.at(3).type != kTfLiteNoType) {
        return 3;
      }
      // If the op takes int8 input, it is version 2.
      if (op_sig.inputs.at(1).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_LSTM: {
      // If the input tensor is float and a weight is int8, this is a version
      // 3 hybrid operation.
      auto lstm_params =
          reinterpret_cast<TfLiteLSTMParams*>(op_sig.builtin_data);
      TFLITE_DCHECK(lstm_params != nullptr);
      if (lstm_params->kernel_type == kTfLiteLSTMFullKernel &&
          op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(2).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        if (lstm_params->asymmetric_quantize_inputs) {
          return 4;
        }
        return 3;
      }
      // KERNEL_BASIC was added in version 2.
      if (lstm_params->kernel_type == kTfLiteLSTMBasicKernel) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_SPLIT:
      // If the op takes int16 input, it is version 4.
      if (op_sig.inputs.at(1).type == kTfLiteInt16) {
        return 4;
      }
      // If the op takes int8 input, it is version 2; for int32 it is version
      // 3. The input tensor is at index 1, not 0; index 0 is the axis.
      if (op_sig.inputs.at(1).type == kTfLiteInt32) {
        return 3;
      }
      if (op_sig.inputs.at(1).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_SPARSE_TO_DENSE:
      // Version 3 supports Int8 and Uint8 type.
      if (op_sig.inputs.at(2).type == kTfLiteInt8 ||
          op_sig.inputs.at(2).type == kTfLiteUInt8) {
        return 3;
      }
      // Version 2 supports Int64 value type.
      if (op_sig.inputs.at(2).type == kTfLiteInt64) {
        return 2;
      }
      return 1;

    case BuiltinOperator_SLICE:
      if (op_sig.inputs.at(0).dims.size() > 4) {
        return 5;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      // Version 3 supports string input types.
      if (op_sig.inputs.at(0).type == kTfLiteString) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_UNPACK:
      // If the op takes int8/uint8 input, it is version 2.
      if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
          op_sig.inputs.at(0).type == kTfLiteUInt8) {
        return 2;
      }
      // If the op takes bool input, it is version 3.
      if (op_sig.inputs.at(0).type == kTfLiteBool) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      return 1;

    case BuiltinOperator_DEQUANTIZE:
      // Version 3 supports signed int16 and float16 input types.
      if (op_sig.inputs.at(0).type == kTfLiteInt16 ||
          op_sig.inputs.at(0).type == kTfLiteFloat16) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        if (op_sig.ext_options.dequantize.is_per_channel_quantized) {
          return 5;
        }
        return 2;
      }
      return 1;

    case BuiltinOperator_QUANTIZE:
      if (op_sig.ext_options.quantize.is_per_channel_quantized) {
        return 3;
      }
      if (op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 2;
      }
      return 1;

    case BuiltinOperator_FLOOR_DIV:
      if (op_sig.inputs.at(0).type == kTfLiteFloat32) {
        return 2;
      }
      return 1;

    case BuiltinOperator_L2_NORMALIZATION:
      if (op_sig.outputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_ABS:
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return op_sig.ext_options.abs.input_quantized ? 3 : 4;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
          op_sig.inputs.at(0).type == kTfLiteUInt8) {
        return 2;
      }
      return 1;
    case BuiltinOperator_RELU:
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
          op_sig.inputs.at(0).type == kTfLiteUInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_STRIDED_SLICE: {
      auto strided_slice_params =
          reinterpret_cast<TfLiteStridedSliceParams*>(op_sig.builtin_data);
      TFLITE_DCHECK(strided_slice_params != nullptr);
      if (strided_slice_params->ellipsis_mask != 0 ||
          strided_slice_params->new_axis_mask != 0) {
        return 6;
      }
      if (op_sig.inputs.at(0).type == kTfLiteString) {
        return 5;
      }
      if (op_sig.ext_options.strided_slice.num_dims > 4) {
        return 4;
      }
      // If the op takes bool input, it is version 3.
      if (op_sig.inputs.at(0).type == kTfLiteBool) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }
    case BuiltinOperator_REVERSE_V2:
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteBool) {
        return 2;
      }
      return 1;
    case BuiltinOperator_RESIZE_BILINEAR: {
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      auto resize_bilinear_params =
          reinterpret_cast<TfLiteResizeBilinearParams*>(op_sig.builtin_data);
      TFLITE_DCHECK(resize_bilinear_params != nullptr);
      if (resize_bilinear_params->half_pixel_centers) {
        return 3;
      } else if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }
    case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: {
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      auto resize_nearest_neighbor_params =
          reinterpret_cast<TfLiteResizeNearestNeighborParams*>(
              op_sig.builtin_data);
      TFLITE_DCHECK(resize_nearest_neighbor_params != nullptr);
      if (resize_nearest_neighbor_params->half_pixel_centers ||
          resize_nearest_neighbor_params->align_corners) {
        return 3;
      } else if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_MAXIMUM:
    case BuiltinOperator_MINIMUM:
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 4;
      }
      if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_PACK:
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }

      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 3;
      }
      return 1;

    case BuiltinOperator_TILE:
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteString) {
        return 2;
      }
      return 1;

    case BuiltinOperator_SQUEEZE:
      if (op_sig.inputs.at(0).type == kTfLiteString) {
        return 2;
      }
      return 1;

    case BuiltinOperator_SPACE_TO_BATCH_ND:
    case BuiltinOperator_BATCH_TO_SPACE_ND:
      if (op_sig.inputs.at(0).dims.size() != 4) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_ADD: {
      if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) {
        return 4;
      }
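      // Version 3 covers int16 inputs/outputs whose scales are not restricted
      // to powers of two, signalled by pot_scale_int16 == false.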
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        auto add_params =
            reinterpret_cast<TfLiteAddParams*>(op_sig.builtin_data);
        if (add_params && !add_params->pot_scale_int16) {
          return 3;
        }
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_SUB: {
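      // Version 5 covers int16 inputs/outputs whose scales are not restricted
      // to powers of two, signalled by pot_scale_int16 == false.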
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        auto sub_params =
            reinterpret_cast<TfLiteSubParams*>(op_sig.builtin_data);
        if (sub_params && !sub_params->pot_scale_int16) {
          return 5;
        }
      }
      if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) {
        return 4;
      }
      if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_GATHER_ND:
      if (!op_sig.inputs.empty() &&
          (op_sig.inputs.at(0).type == kTfLiteInt16)) {
        return 3;
      }
      if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteString) {
        return 2;
      }
      return 1;

    case BuiltinOperator_DIV:
      if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) {
        return 2;
      }
      return 1;
    case BuiltinOperator_TANH:
    case BuiltinOperator_LOGISTIC:
      if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
          op_sig.outputs.at(0).type == kTfLiteInt16) {
        return 3;
      }

      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_FILL:
      if (op_sig.inputs.size() >= 2) {
        if (op_sig.inputs.at(1).type == kTfLiteInt8 ||
            op_sig.inputs.at(1).type == kTfLiteInt16) {
          return 3;
        } else if ((op_sig.inputs.at(1).type == kTfLiteBool ||
                    op_sig.inputs.at(1).type == kTfLiteString)) {
          return 2;
        }
      }
      return 1;

    case BuiltinOperator_EQUAL:
    case BuiltinOperator_NOT_EQUAL:
      if (!op_sig.inputs.empty()) {
        if (op_sig.inputs.at(0).type == kTfLiteString) {
          return 3;
        }
        if (op_sig.inputs.at(0).type == kTfLiteInt8) {
          return 2;
        }
      }
      return 1;

    case BuiltinOperator_LEAKY_RELU:
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 2;
      }
      return 1;

    case BuiltinOperator_BATCH_MATMUL: {
      // In case of int16 inputs, the version is 3.
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto batch_mat_mul_params =
            reinterpret_cast<TfLiteBatchMatMulParams*>(op_sig.builtin_data);
        if (batch_mat_mul_params &&
            batch_mat_mul_params->asymmetric_quantize_inputs) {
          // This is to use the updated quantization scheme.
          return 4;
        }
      }
      return 1;
    }

    case BuiltinOperator_PAD:
    case BuiltinOperator_PADV2:
      if (op_sig.inputs.at(0).dims.size() > 4) {
        return 4;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_CONCATENATION:
    case BuiltinOperator_SOFTMAX:
    case BuiltinOperator_MEAN:
    case BuiltinOperator_REDUCE_MAX:
    case BuiltinOperator_REDUCE_MIN:
    case BuiltinOperator_RELU6:
      // In case of int16 inputs, the version is 3.
      if (op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_RNN: {
      if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto rnn_params =
            reinterpret_cast<TfLiteRNNParams*>(op_sig.builtin_data);
        if (rnn_params && rnn_params->asymmetric_quantize_inputs) {
          return 3;
        } else {
          return 2;
        }
      }
      return 1;
    }

    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
      if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto sequence_rnn_params =
            reinterpret_cast<TfLiteSequenceRNNParams*>(op_sig.builtin_data);
        if (sequence_rnn_params &&
            sequence_rnn_params->asymmetric_quantize_inputs) {
          return 3;
        } else {
          return 2;
        }
      }
      return 1;
    }

    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: {
      if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto bidirectional_sequence_rnn_params =
            reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
                op_sig.builtin_data);
        if (bidirectional_sequence_rnn_params &&
            bidirectional_sequence_rnn_params->asymmetric_quantize_inputs) {
          return 3;
        } else {
          return 2;
        }
      }
      return 1;
    }

    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
      if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto bidirectional_sequence_lstm_params =
            reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
                op_sig.builtin_data);
        if (bidirectional_sequence_lstm_params &&
            bidirectional_sequence_lstm_params->asymmetric_quantize_inputs) {
          return 3;
        } else {
          return 2;
        }
      }
      return 1;
    }

    case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
      // If the input tensor is float and a weight is int8, this is a version
      // 2 hybrid operation.
      if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
          op_sig.inputs.at(2).type == kTfLiteInt8 &&
          op_sig.outputs.at(0).type == kTfLiteFloat32) {
        auto unidirectional_sequence_lstm_params =
            reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
                op_sig.builtin_data);
        if (unidirectional_sequence_lstm_params &&
            unidirectional_sequence_lstm_params->asymmetric_quantize_inputs) {
          return 3;
        }
        return 2;
      }
      return 1;
    }

    case BuiltinOperator_ARG_MAX:
    case BuiltinOperator_ARG_MIN:
      if (op_sig.inputs.at(0).type == kTfLiteBool) {
        return 3;
      }
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_SELECT: {
      if (op_sig.inputs.at(0).dims.size() == 5 ||
          op_sig.inputs.at(1).dims.size() == 5 ||
          op_sig.inputs.at(2).dims.size() == 5)
        return 3;
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;
    }
    case BuiltinOperator_SPACE_TO_DEPTH:
    case BuiltinOperator_SPLIT_V:
    case BuiltinOperator_SUM:
    case BuiltinOperator_LOG_SOFTMAX:
    case BuiltinOperator_TOPK_V2:
    case BuiltinOperator_GREATER:
    case BuiltinOperator_GREATER_EQUAL:
    case BuiltinOperator_LESS:
    case BuiltinOperator_LESS_EQUAL:
    case BuiltinOperator_RSQRT:
    case BuiltinOperator_SQUARED_DIFFERENCE:
    case BuiltinOperator_DEPTH_TO_SPACE:
    case BuiltinOperator_MIRROR_PAD:
      if (op_sig.inputs.at(0).type == kTfLiteInt8) {
        return 2;
      }
      return 1;

    case BuiltinOperator_REDUCE_PROD:
      if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
          op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 2;
      }
      return 1;

    // Version 1 of BROADCAST_TO is not supported: it was rolled back, and the
    // builtin op code was changed because of the builtin op code shortage
    // problem, so the op starts at version 2.
    // Quantized broadcast_to is version 3.
    case BuiltinOperator_BROADCAST_TO:
      if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
          op_sig.inputs.at(0).type == kTfLiteInt16) {
        return 3;
      }
      return 2;
    case BuiltinOperator_CAST:
      if (op_sig.inputs.at(0).type == kTfLiteUInt16 ||
          op_sig.outputs.at(0).type == kTfLiteUInt16) {
        return 4;
      } else if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
                 op_sig.outputs.at(0).type == kTfLiteInt8) {
        return 3;
      } else if (op_sig.inputs.at(0).type == kTfLiteUInt32 ||
                 op_sig.outputs.at(0).type == kTfLiteUInt32) {
        return 2;
      }
      return 1;
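    // Version 1 of WHERE only supports a bool condition input; any other
    // condition type requires version 2.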
    case BuiltinOperator_WHERE:
      if (op_sig.inputs.at(0).type == kTfLiteBool) return 1;
      return 2;
    case BuiltinOperator_GELU:
      if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
          op_sig.inputs.at(0).type == kTfLiteUInt8) {
        return 2;
      }
      return 1;
    default:
      return 1;
  }
  // Prevent lint error about this function being too long.
  // NOLINTNEXTLINE
}

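// Walks every operator in the model buffer and bumps each operator code's
// version to the minimum version required by its op signature. Versions that
// are already higher are left untouched.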
void UpdateOpVersion(uint8_t* model_buffer_pointer) {
  auto model = GetMutableModel(model_buffer_pointer);
  auto subgraphs = model->subgraphs();

  for (int i = 0; i < subgraphs->Length(); ++i) {
    const SubGraph* subgraph = subgraphs->Get(i);
    for (int j = 0; j < subgraph->operators()->Length(); ++j) {
      const Operator* op = subgraph->operators()->Get(j);
      OperatorCode* op_code =
          model->mutable_operator_codes()->GetMutableObject(op->opcode_index());

      auto builtin_code = GetBuiltinCode(op_code);
      if (builtin_code != BuiltinOperator_CUSTOM) {
        OpSignature op_sig = GetOpSignature(op_code, op, subgraph, model);
        // Update builtin operator version.
        int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
        if (op_sig.builtin_data) {
          free(op_sig.builtin_data);
        }
        // Skip the update if this node does not require a higher version than
        // the one already stored for the op code.
        // TODO(b/184366869): Populate multiple versions of operator once MLIR
        // quantizer is ready.
        if (op_ver <= op_code->version()) {
          continue;
        }
        if (!op_code->mutate_version(op_ver)) {
          LOG(ERROR) << "Can't set operator "
                     << EnumNameBuiltinOperator(builtin_code) << " to version "
                     << op_ver;
        }
      }
    }
  }
}

}  // namespace tflite