• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/versioning/op_version.h"
16 
17 #include <algorithm>
18 #include <string>
19 #include <utility>
20 #include <vector>
21 
22 // #include "absl/memory/memory.h"
23 // #include "absl/strings/numbers.h"
24 // #include "absl/strings/str_split.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/lite/builtin_op_data.h"
27 #include "tensorflow/lite/c/c_api_types.h"
28 #include "tensorflow/lite/kernels/internal/compatibility.h"
29 #include "tensorflow/lite/schema/schema_generated.h"
30 #include "tensorflow/lite/schema/schema_utils.h"
31 
32 namespace tflite {
33 namespace {
34 
NeedBroadcastForBinaryInputs(const OpSignature & op_sig)35 bool NeedBroadcastForBinaryInputs(const OpSignature& op_sig) {
36   if (op_sig.inputs.size() < 2) {
37     return false;
38   }
39   return (op_sig.inputs.at(0).dims != op_sig.inputs.at(1).dims);
40 }
41 
GetInputMaxDims(const OpSignature & op_sig)42 int GetInputMaxDims(const OpSignature& op_sig) {
43   int max_dims = 0;
44   for (auto& input : op_sig.inputs) {
45     if (input.dims.size() > max_dims) {
46       max_dims = input.dims.size();
47     }
48   }
49   return max_dims;
50 }
51 
52 }  // namespace
53 
GetBuiltinOperatorVersion(const OpSignature & op_sig)54 int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
55   switch (op_sig.op) {
56     case BuiltinOperator_CONV_2D:
57       if (op_sig.ext_options.conv_2d.is_grouped_convolution) {
58         return 6;
59       }
60       // If the op has signed int16 op_sig.inputs and op_sig.outputs, its
61       // version 4.
62       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
63           op_sig.inputs.at(1).type == kTfLiteInt16 &&
64           op_sig.outputs.at(1).type == kTfLiteInt16) {
65         return 4;
66       }
67 
68       // If the op has signed int8 op_sig.inputs and op_sig.outputs, its
69       // version 3.
70       if (op_sig.inputs.at(0).type == kTfLiteInt8 &&
71           op_sig.inputs.at(1).type == kTfLiteInt8 &&
72           op_sig.outputs.at(0).type == kTfLiteInt8) {
73         return 3;
74       }
75       // If the op is a signed int8 hybrid operation, we need to return
76       // version 2 or 5 if per channel.
77       if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
78           op_sig.inputs.at(1).type == kTfLiteInt8 &&
79           op_sig.outputs.at(0).type == kTfLiteFloat32) {
80         if (op_sig.ext_options.conv_2d.is_per_channel_quantized) {
81           return 5;
82         }
83         return 2;
84       }
85       return 1;
86 
87     case BuiltinOperator_DEPTHWISE_CONV_2D: {
88       // If the op accepts int16, we return version 5.
89       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
90           op_sig.inputs.at(1).type == kTfLiteInt16 &&
91           op_sig.outputs.at(1).type == kTfLiteInt16) {
92         return 5;
93       }
94 
95       // If the op is a signed int8 hybrid operation, we need to return
96       // version 4 or 6 if per-channel.
97       if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
98           op_sig.inputs.at(1).type == kTfLiteInt8 &&
99           op_sig.outputs.at(0).type == kTfLiteFloat32) {
100         if (op_sig.ext_options.depthwise_conv_2d.is_per_channel_quantized) {
101           return 6;
102         }
103         return 4;
104       }
105       // If the op has signed int8 op_sig.inputs and op_sig.outputs, its
106       // version 3.
107       if (op_sig.inputs.at(0).type == kTfLiteInt8 &&
108           op_sig.inputs.at(1).type == kTfLiteInt8 &&
109           op_sig.outputs.at(0).type == kTfLiteInt8) {
110         return 3;
111       }
112       auto depthwise_conv_params =
113           reinterpret_cast<TfLiteDepthwiseConvParams*>(op_sig.builtin_data);
114       TFLITE_DCHECK(depthwise_conv_params != nullptr);
115       if (depthwise_conv_params->dilation_width_factor != 1 ||
116           depthwise_conv_params->dilation_height_factor != 1) {
117         return 2;
118       }
119       return 1;
120     }
121 
122     case BuiltinOperator_FAKE_QUANT: {
123       auto fake_quant_params =
124           reinterpret_cast<TfLiteFakeQuantParams*>(op_sig.builtin_data);
125       TFLITE_DCHECK(fake_quant_params != nullptr);
126       if (fake_quant_params->narrow_range) {
127         return 2;
128       }
129       return 1;
130     }
131 
132     case BuiltinOperator_FULLY_CONNECTED: {
133       // +-----------------+--------------------+--------------------------+
134       // |                 |    Weight::Default | Weight::Shuffled4x16Int8 |
135       // +-----------------+--------------------+--------------------------+
136       // | Float           |                  1 |                        2 |
137       // | Quantized Uint8 |                  1 |                        2 |
138       // | Hybrid          |                  3 |                        3 |
139       // | Quantized Int8  |                  4 |                        4 |
140       // +-----------------+--------------------+--------------------------+
141 
142       // FullyConnected with sparse weight is supported at version 8.
143       if (op_sig.ext_options.fully_connected.sparse_weight) {
144         return 8;
145       }
146 
147       // Int16 fully fixed point kernel is at version 7.
148       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
149           op_sig.inputs.at(1).type == kTfLiteInt16 &&
150           op_sig.outputs.at(0).type == kTfLiteInt16) {
151         return 7;
152       }
153 
154       // 2 op_sig.inputs (no bias) use case is supported starting from
155       // version 6.
156       if (op_sig.inputs.size() == 2) {
157         return 6;
158       }
159       auto fully_connected_params =
160           reinterpret_cast<TfLiteFullyConnectedParams*>(op_sig.builtin_data);
161       TFLITE_DCHECK(fully_connected_params != nullptr);
162       // `keep_num_dims` is supported at version 5.
163       if (fully_connected_params->keep_num_dims) {
164         return 5;
165       }
166       // Int8 fully fixed point kernel is at version 4.
167       if (op_sig.inputs.at(0).type == kTfLiteInt8 &&
168           op_sig.inputs.at(1).type == kTfLiteInt8 &&
169           op_sig.outputs.at(0).type == kTfLiteInt8) {
170         return 4;
171       }
172       // If the op is a signed int8 hybrid operation, we need to return
173       // version 3.
174       if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
175           op_sig.inputs.at(1).type == kTfLiteInt8 &&
176           op_sig.outputs.at(0).type == kTfLiteFloat32) {
177         if (fully_connected_params->asymmetric_quantize_inputs) {
178           // This is to use the updated quantization scheme.
179           return 9;
180         }
181         return 3;
182       }
183       // For float and uint8 fixed point kernels, if the weight is
184       // Shuffled4x16Int8, it is version 2.
185       if (fully_connected_params->weights_format ==
186           kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
187         return 2;
188       }
189       // Otherwise (weight is default), the version is 1.
190       return 1;
191     }
192 
193     case BuiltinOperator_GATHER: {
194       auto gather_params =
195           reinterpret_cast<TfLiteGatherParams*>(op_sig.builtin_data);
196       if (gather_params && gather_params->batch_dims != 0) {
197         return 5;
198       }
199 
200       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
201         return 4;
202       }
203       // If the op takes bool input, it is version 3.
204       if (op_sig.inputs.at(0).type == kTfLiteBool) {
205         return 3;
206       }
207       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
208         return 2;
209       }
210       return 1;
211     }
212 
213     case BuiltinOperator_SVDF: {
214       // Fully integer SVDF has int8 as input and is of version 3.
215       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
216         return 3;
217       }
218       // If the op is a signed int8 hybrid operation, we need to return
219       // version 2.
220       if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
221           op_sig.inputs.at(1).type == kTfLiteInt8 &&
222           op_sig.outputs.at(0).type == kTfLiteFloat32) {
223         auto svdf_params =
224             reinterpret_cast<TfLiteSVDFParams*>(op_sig.builtin_data);
225         // This is to use the updated quantization scheme
226         if (svdf_params && svdf_params->asymmetric_quantize_inputs) {
227           return 4;
228         }
229         return 2;
230       }
231       return 1;
232     }
233 
234     case BuiltinOperator_MUL:
235       // Version 6 supports complex32 inputs
236       if (op_sig.inputs.at(0).type == kTfLiteComplex64) {
237         return 6;
238       }
239       // Version 5 supports int64 inputs
240       if (op_sig.inputs.at(0).type == kTfLiteInt64) {
241         return 5;
242       }
243       // Version 4 supports int16 inputs
244       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
245         return 4;
246       }
247       // Version 3 supports have a rescale value greater than or equal to 1.
248       if (op_sig.ext_options.mul.input1_scale != 0 &&
249           op_sig.ext_options.mul.input2_scale != 0 &&
250           op_sig.ext_options.mul.output_scale != 0 &&
251           (op_sig.ext_options.mul.input1_scale *
252            op_sig.ext_options.mul.input2_scale /
253            op_sig.ext_options.mul.output_scale) >= 1.0) {
254         return 3;
255       }
256       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
257         return 2;
258       }
259       return 1;
260 
261     case BuiltinOperator_MAX_POOL_2D:
262     case BuiltinOperator_AVERAGE_POOL_2D:
263       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
264           op_sig.outputs.at(0).type == kTfLiteInt16) {
265         return 3;
266       }
267 
268       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
269         return 2;
270       }
271       return 1;
272 
273     case BuiltinOperator_TRANSPOSE:
274       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
275         return 5;
276       }
277       if (op_sig.inputs.at(0).dims.size() > 4) {
278         return 4;
279       }
280       // If the op takes bool input, it is version 3.
281       if (op_sig.inputs.at(0).type == kTfLiteBool) {
282         return 3;
283       }
284       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
285         return 2;
286       }
287       return 1;
288 
289     case BuiltinOperator_TRANSPOSE_CONV: {
290       if (op_sig.inputs.size() == 4 &&
291           op_sig.inputs.at(3).type != kTfLiteNoType) {
292         return 3;
293       }
294       // If the op takes int8 input, it is version 2.
295       if (op_sig.inputs.at(1).type == kTfLiteInt8) {
296         return 2;
297       }
298       return 1;
299     }
300 
301     case BuiltinOperator_LSTM: {
302       // If the input tensor is float and a weight is int8, this is a version
303       // 3 hybrid operation.
304       auto lstm_params =
305           reinterpret_cast<TfLiteLSTMParams*>(op_sig.builtin_data);
306       TFLITE_DCHECK(lstm_params != nullptr);
307       if (lstm_params->kernel_type == kTfLiteLSTMFullKernel &&
308           op_sig.inputs.at(0).type == kTfLiteFloat32 &&
309           op_sig.inputs.at(2).type == kTfLiteInt8 &&
310           op_sig.outputs.at(0).type == kTfLiteFloat32) {
311         if (lstm_params->asymmetric_quantize_inputs) {
312           return 4;
313         }
314         return 3;
315       }
316       // KERNEL_BASIC was added in version 2.
317       if (lstm_params->kernel_type == kTfLiteLSTMBasicKernel) {
318         return 2;
319       }
320       return 1;
321     }
322 
323     case BuiltinOperator_SPLIT:
324       // If the op take in16 input, it is version 4.
325       if (op_sig.inputs.at(1).type == kTfLiteInt16) {
326         return 4;
327       }
328       // If the op take int8 input, it is version 2, for int32 it's version 3.
329       // The input tensor is at index 1 not 0, 0 is the axis.
330       if (op_sig.inputs.at(1).type == kTfLiteInt32) {
331         return 3;
332       }
333       if (op_sig.inputs.at(1).type == kTfLiteInt8) {
334         return 2;
335       }
336       return 1;
337 
338     case BuiltinOperator_SPARSE_TO_DENSE:
339       // Version 3 supports Int8 and Uint8 type.
340       if (op_sig.inputs.at(2).type == kTfLiteInt8 ||
341           op_sig.inputs.at(2).type == kTfLiteUInt8) {
342         return 3;
343       }
344       // Version 2 supports Int64 value type.
345       if (op_sig.inputs.at(2).type == kTfLiteInt64) {
346         return 2;
347       }
348       return 1;
349 
350     case BuiltinOperator_SLICE:
351       if (op_sig.inputs.at(0).dims.size() > 4) {
352         return 5;
353       }
354       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
355         return 4;
356       }
357       // Version 3 supports string input types.
358       if (op_sig.inputs.at(0).type == kTfLiteString) {
359         return 3;
360       }
361       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
362         return 2;
363       }
364       return 1;
365 
366     case BuiltinOperator_UNPACK:
367       // If the op take int8/uint8 input, it is version 2.
368       if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
369           op_sig.inputs.at(0).type == kTfLiteUInt8) {
370         return 2;
371       }
372       // If the op take bool input, it is version 3.
373       if (op_sig.inputs.at(0).type == kTfLiteBool) {
374         return 3;
375       }
376       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
377           op_sig.outputs.at(0).type == kTfLiteInt16) {
378         return 4;
379       }
380       return 1;
381 
382     case BuiltinOperator_DEQUANTIZE:
383       // Version 3 supports signed int16 input types.
384       if (op_sig.inputs.at(0).type == kTfLiteInt16 ||
385           op_sig.inputs.at(0).type == kTfLiteFloat16) {
386         return 3;
387       }
388       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
389         if (op_sig.ext_options.dequantize.is_per_channel_quantized) {
390           return 5;
391         }
392         return 2;
393       }
394       return 1;
395 
396     case BuiltinOperator_QUANTIZE:
397       if (op_sig.ext_options.quantize.is_per_channel_quantized) {
398         return 3;
399       }
400       if (op_sig.outputs.at(0).type == kTfLiteInt16) {
401         return 2;
402       }
403       return 1;
404 
405     case BuiltinOperator_FLOOR_DIV:
406       if (op_sig.inputs.at(0).type == kTfLiteFloat32) {
407         return 2;
408       }
409       return 1;
410 
411     case BuiltinOperator_L2_NORMALIZATION:
412       if (op_sig.outputs.at(0).type == kTfLiteInt8) {
413         return 2;
414       }
415       return 1;
416 
417     case BuiltinOperator_ABS:
418       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
419         return op_sig.ext_options.abs.input_quantized ? 3 : 4;
420       }
421       if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
422           op_sig.inputs.at(0).type == kTfLiteUInt8) {
423         return 2;
424       }
425       return 1;
426     case BuiltinOperator_RELU:
427       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
428         return 3;
429       }
430       if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
431           op_sig.inputs.at(0).type == kTfLiteUInt8) {
432         return 2;
433       }
434       return 1;
435 
436     case BuiltinOperator_STRIDED_SLICE: {
437       auto strided_slice_params =
438           reinterpret_cast<TfLiteStridedSliceParams*>(op_sig.builtin_data);
439       TFLITE_DCHECK(strided_slice_params != nullptr);
440       if (strided_slice_params->ellipsis_mask != 0 ||
441           strided_slice_params->new_axis_mask != 0) {
442         return 6;
443       }
444       if (op_sig.inputs.at(0).type == kTfLiteString) {
445         return 5;
446       }
447       if (op_sig.ext_options.strided_slice.num_dims > 4) {
448         return 4;
449       }
450       // If the op takes bool input, it is version 3.
451       if (op_sig.inputs.at(0).type == kTfLiteBool) {
452         return 3;
453       }
454       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
455         return 2;
456       }
457       return 1;
458     }
459     case BuiltinOperator_REVERSE_V2:
460       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
461         return 3;
462       }
463       if (op_sig.inputs.at(0).type == kTfLiteBool) {
464         return 2;
465       }
466       return 1;
467     case BuiltinOperator_RESIZE_BILINEAR: {
468       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
469         return 4;
470       }
471       auto resize_bilinear_params =
472           reinterpret_cast<TfLiteResizeBilinearParams*>(op_sig.builtin_data);
473       TFLITE_DCHECK(resize_bilinear_params != nullptr);
474       if (resize_bilinear_params->half_pixel_centers) {
475         return 3;
476       } else if (op_sig.inputs.at(0).type == kTfLiteInt8) {
477         return 2;
478       }
479       return 1;
480     }
481     case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: {
482       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
483         return 4;
484       }
485       auto resize_nearest_neighbor_params =
486           reinterpret_cast<TfLiteResizeNearestNeighborParams*>(
487               op_sig.builtin_data);
488       TFLITE_DCHECK(resize_nearest_neighbor_params != nullptr);
489       if (resize_nearest_neighbor_params->half_pixel_centers ||
490           resize_nearest_neighbor_params->align_corners) {
491         return 3;
492       } else if (op_sig.inputs.at(0).type == kTfLiteInt8) {
493         return 2;
494       }
495       return 1;
496     }
497 
498     case BuiltinOperator_MAXIMUM:
499     case BuiltinOperator_MINIMUM:
500       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
501           op_sig.outputs.at(0).type == kTfLiteInt16) {
502         return 4;
503       }
504       if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) {
505         return 3;
506       }
507       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
508         return 2;
509       }
510       return 1;
511 
512     case BuiltinOperator_PACK:
513       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
514         return 2;
515       }
516 
517       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
518           op_sig.outputs.at(0).type == kTfLiteInt16) {
519         return 3;
520       }
521       return 1;
522 
523     case BuiltinOperator_TILE:
524       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
525         return 3;
526       }
527       if (op_sig.inputs.at(0).type == kTfLiteString) {
528         return 2;
529       }
530       return 1;
531 
532     case BuiltinOperator_SQUEEZE:
533       if (op_sig.inputs.at(0).type == kTfLiteString) {
534         return 2;
535       }
536       return 1;
537 
538     case BuiltinOperator_SPACE_TO_BATCH_ND:
539     case BuiltinOperator_BATCH_TO_SPACE_ND:
540       if (op_sig.inputs.at(0).dims.size() != 4) {
541         return 3;
542       }
543       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
544         return 2;
545       }
546       return 1;
547 
548     case BuiltinOperator_ADD: {
549       if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) {
550         return 4;
551       }
552       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
553           op_sig.outputs.at(0).type == kTfLiteInt16) {
554         auto add_params =
555             reinterpret_cast<TfLiteAddParams*>(op_sig.builtin_data);
556         if (add_params && !add_params->pot_scale_int16) {
557           return 3;
558         }
559       }
560       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
561         return 2;
562       }
563       return 1;
564     }
565 
566     case BuiltinOperator_SUB: {
567       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
568           op_sig.outputs.at(0).type == kTfLiteInt16) {
569         auto sub_params =
570             reinterpret_cast<TfLiteSubParams*>(op_sig.builtin_data);
571         if (sub_params && !sub_params->pot_scale_int16) {
572           return 5;
573         }
574       }
575       if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteInt64) {
576         return 4;
577       }
578       if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) {
579         return 3;
580       }
581       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
582         return 2;
583       }
584       return 1;
585     }
586 
587     case BuiltinOperator_GATHER_ND:
588       if (!op_sig.inputs.empty() &&
589           (op_sig.inputs.at(0).type == kTfLiteInt16)) {
590         return 3;
591       }
592       if (!op_sig.inputs.empty() && op_sig.inputs.at(0).type == kTfLiteString) {
593         return 2;
594       }
595       return 1;
596 
597     case BuiltinOperator_DIV:
598       if (NeedBroadcastForBinaryInputs(op_sig) && GetInputMaxDims(op_sig) > 4) {
599         return 2;
600       }
601       return 1;
602     case BuiltinOperator_TANH:
603     case BuiltinOperator_LOGISTIC:
604       if (op_sig.inputs.at(0).type == kTfLiteInt16 &&
605           op_sig.outputs.at(0).type == kTfLiteInt16) {
606         return 3;
607       }
608 
609       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
610         return 2;
611       }
612       return 1;
613 
614     case BuiltinOperator_FILL:
615       if (op_sig.inputs.size() >= 2) {
616         if (op_sig.inputs.at(1).type == kTfLiteInt8 ||
617             op_sig.inputs.at(1).type == kTfLiteInt16) {
618           return 3;
619         } else if ((op_sig.inputs.at(1).type == kTfLiteBool ||
620                     op_sig.inputs.at(1).type == kTfLiteString)) {
621           return 2;
622         }
623       }
624       return 1;
625 
626     case BuiltinOperator_EQUAL:
627     case BuiltinOperator_NOT_EQUAL:
628       if (!op_sig.inputs.empty()) {
629         if (op_sig.inputs.at(0).type == kTfLiteString) {
630           return 3;
631         }
632         if (op_sig.inputs.at(0).type == kTfLiteInt8) {
633           return 2;
634         }
635       }
636       return 1;
637 
638     case BuiltinOperator_LEAKY_RELU:
639       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
640         return 2;
641       }
642       return 1;
643 
644     case BuiltinOperator_BATCH_MATMUL: {
645       // In case of int16 inputs, the version is 3.
646       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
647         return 3;
648       }
649       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
650         return 2;
651       }
652       if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
653           op_sig.inputs.at(1).type == kTfLiteInt8 &&
654           op_sig.outputs.at(0).type == kTfLiteFloat32) {
655         auto batch_mat_mul_params =
656             reinterpret_cast<TfLiteBatchMatMulParams*>(op_sig.builtin_data);
657         if (batch_mat_mul_params &&
658             batch_mat_mul_params->asymmetric_quantize_inputs) {
659           // This is to use the updated quantization scheme.
660           return 4;
661         }
662       }
663       return 1;
664     }
665 
666     case BuiltinOperator_PAD:
667     case BuiltinOperator_PADV2:
668       if (op_sig.inputs.at(0).dims.size() > 4) {
669         return 4;
670       }
671       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
672         return 3;
673       }
674       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
675         return 2;
676       }
677       return 1;
678 
679     case BuiltinOperator_CONCATENATION:
680     case BuiltinOperator_SOFTMAX:
681     case BuiltinOperator_MEAN:
682     case BuiltinOperator_REDUCE_MAX:
683     case BuiltinOperator_REDUCE_MIN:
684     case BuiltinOperator_RELU6:
685       // In case of int16 inputs, the version is 3.
686       if (op_sig.inputs.at(0).type == kTfLiteInt16) {
687         return 3;
688       }
689       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
690         return 2;
691       }
692       return 1;
693 
694     case BuiltinOperator_RNN: {
695       if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
696           op_sig.outputs.at(0).type == kTfLiteFloat32) {
697         auto rnn_params =
698             reinterpret_cast<TfLiteRNNParams*>(op_sig.builtin_data);
699         if (rnn_params && rnn_params->asymmetric_quantize_inputs) {
700           return 3;
701         } else {
702           return 2;
703         }
704       }
705       return 1;
706     }
707 
708     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
709       if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
710           op_sig.outputs.at(0).type == kTfLiteFloat32) {
711         auto sequence_rnn_params =
712             reinterpret_cast<TfLiteSequenceRNNParams*>(op_sig.builtin_data);
713         if (sequence_rnn_params &&
714             sequence_rnn_params->asymmetric_quantize_inputs) {
715           return 3;
716         } else {
717           return 2;
718         }
719       }
720       return 1;
721     }
722 
723     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: {
724       if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
725           op_sig.outputs.at(0).type == kTfLiteFloat32) {
726         auto bidirectional_sequence_rnn_params =
727             reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
728                 op_sig.builtin_data);
729         if (bidirectional_sequence_rnn_params &&
730             bidirectional_sequence_rnn_params->asymmetric_quantize_inputs) {
731           return 3;
732         } else {
733           return 2;
734         }
735       }
736       return 1;
737     }
738 
739     case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
740       if (op_sig.inputs.at(1).type == kTfLiteInt8 &&
741           op_sig.outputs.at(0).type == kTfLiteFloat32) {
742         auto bidirectional_sequence_lstm_params =
743             reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
744                 op_sig.builtin_data);
745         if (bidirectional_sequence_lstm_params &&
746             bidirectional_sequence_lstm_params->asymmetric_quantize_inputs) {
747           return 3;
748         } else {
749           return 2;
750         }
751       }
752       return 1;
753     }
754 
755     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
756       // If the input tensor is float and a weight is int8, this is a version
757       // 2 hybrid operation.
758       if (op_sig.inputs.at(0).type == kTfLiteFloat32 &&
759           op_sig.inputs.at(2).type == kTfLiteInt8 &&
760           op_sig.outputs.at(0).type == kTfLiteFloat32) {
761         auto unidirectional_sequence_lstm_params =
762             reinterpret_cast<TfLiteUnidirectionalSequenceLSTMParams*>(
763                 op_sig.builtin_data);
764         if (unidirectional_sequence_lstm_params &&
765             unidirectional_sequence_lstm_params->asymmetric_quantize_inputs) {
766           return 3;
767         }
768         return 2;
769       }
770       return 1;
771     }
772 
773     case BuiltinOperator_ARG_MAX:
774     case BuiltinOperator_ARG_MIN:
775       if (op_sig.inputs.at(0).type == kTfLiteBool) {
776         return 3;
777       }
778       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
779         return 2;
780       }
781       return 1;
782 
783     case BuiltinOperator_SELECT: {
784       if (op_sig.inputs.at(0).dims.size() == 5 ||
785           op_sig.inputs.at(1).dims.size() == 5 ||
786           op_sig.inputs.at(2).dims.size() == 5)
787         return 3;
788       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
789         return 2;
790       }
791       return 1;
792     }
793     case BuiltinOperator_SPACE_TO_DEPTH:
794     case BuiltinOperator_SPLIT_V:
795     case BuiltinOperator_SUM:
796     case BuiltinOperator_LOG_SOFTMAX:
797     case BuiltinOperator_TOPK_V2:
798     case BuiltinOperator_GREATER:
799     case BuiltinOperator_GREATER_EQUAL:
800     case BuiltinOperator_LESS:
801     case BuiltinOperator_LESS_EQUAL:
802     case BuiltinOperator_RSQRT:
803     case BuiltinOperator_SQUARED_DIFFERENCE:
804     case BuiltinOperator_DEPTH_TO_SPACE:
805     case BuiltinOperator_MIRROR_PAD:
806       if (op_sig.inputs.at(0).type == kTfLiteInt8) {
807         return 2;
808       }
809       return 1;
810 
811     case BuiltinOperator_REDUCE_PROD:
812       if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
813           op_sig.inputs.at(0).type == kTfLiteInt16) {
814         return 2;
815       }
816       return 1;
817 
818     // The version one of broadcast to op won't be not supported since the
819     // version one was rollbacked and the builtin op code number has been
820     // changed because of builtin op code shortage problem.
821     // Quantized broadcast_to is version 3
822     case BuiltinOperator_BROADCAST_TO:
823       if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
824           op_sig.inputs.at(0).type == kTfLiteInt16) {
825         return 3;
826       }
827       return 2;
828     case BuiltinOperator_CAST:
829       if (op_sig.inputs.at(0).type == kTfLiteUInt16 ||
830           op_sig.outputs.at(0).type == kTfLiteUInt16) {
831         return 4;
832       } else if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
833                  op_sig.outputs.at(0).type == kTfLiteInt8) {
834         return 3;
835       } else if (op_sig.inputs.at(0).type == kTfLiteUInt32 ||
836                  op_sig.outputs.at(0).type == kTfLiteUInt32) {
837         return 2;
838       }
839       return 1;
840     case BuiltinOperator_WHERE:
841       if (op_sig.inputs.at(0).type == kTfLiteBool) return 1;
842       return 2;
843     case BuiltinOperator_GELU:
844       if (op_sig.inputs.at(0).type == kTfLiteInt8 ||
845           op_sig.inputs.at(0).type == kTfLiteUInt8) {
846         return 2;
847       }
848       return 1;
849     default:
850       return 1;
851   }
852   // Prevent lint error about this function being too long.
853   // NOLINTNEXTLINE
854 }
855 
UpdateOpVersion(uint8_t * model_buffer_pointer)856 void UpdateOpVersion(uint8_t* model_buffer_pointer) {
857   auto model = GetMutableModel(model_buffer_pointer);
858   auto subgraphs = model->subgraphs();
859 
860   for (int i = 0; i < subgraphs->Length(); ++i) {
861     const SubGraph* subgraph = subgraphs->Get(i);
862     for (int j = 0; j < subgraph->operators()->Length(); ++j) {
863       const Operator* op = subgraph->operators()->Get(j);
864       OperatorCode* op_code =
865           model->mutable_operator_codes()->GetMutableObject(op->opcode_index());
866 
867       auto builtin_code = GetBuiltinCode(op_code);
868       if (builtin_code != BuiltinOperator_CUSTOM) {
869         OpSignature op_sig = GetOpSignature(op_code, op, subgraph, model);
870         // Update builtin operator version.
871         int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
872         if (op_sig.builtin_data) {
873           free(op_sig.builtin_data);
874         }
875         // Skip updating op version if the current node uses lower version.
876         // TODO(b/184366869): Populate multiple versions of operator once MLIR
877         // quantizer is ready.
878         if (op_ver <= op_code->version()) {
879           continue;
880         }
881         if (!op_code->mutate_version(op_ver)) {
882           LOG(ERROR) << "Can't set operator "
883                      << EnumNameBuiltinOperator(builtin_code) << " to version "
884                      << op_ver;
885         }
886       }
887     }
888   }
889 }
890 
891 }  // namespace tflite
892