• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/versioning/op_version.h"
16 
17 #include <string>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/memory/memory.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/lite/kernels/internal/compatibility.h"
26 
27 namespace tflite {
28 
GetBuiltinOperatorVersion(const OpSignature & op_sig)29 int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
30   switch (op_sig.op) {
31     case BuiltinOperator_CONV_2D:
32       // If the op has signed int8 op_sig.inputs and op_sig.outputs, its
33       // version 3.
34       if (op_sig.input_types.at(0) == TensorType_INT8 &&
35           op_sig.input_types.at(1) == TensorType_INT8 &&
36           op_sig.output_types.at(0) == TensorType_INT8) {
37         return 3;
38       }
39       // If the op is a signed int8 hybrid operation, we need to return
40       // version 2.
41       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
42           op_sig.input_types.at(1) == TensorType_INT8 &&
43           op_sig.output_types.at(0) == TensorType_FLOAT32) {
44         return 2;
45       }
46       return 1;
47 
48     case BuiltinOperator_DEPTHWISE_CONV_2D:
49       // If the op is a signed int8 hybrid operation, we need to return
50       // version 4.
51       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
52           op_sig.input_types.at(1) == TensorType_INT8 &&
53           op_sig.output_types.at(0) == TensorType_FLOAT32) {
54         return 4;
55       }
56       // If the op has signed int8 op_sig.inputs and op_sig.outputs, its
57       // version 3.
58       if (op_sig.input_types.at(0) == TensorType_INT8 &&
59           op_sig.input_types.at(1) == TensorType_INT8 &&
60           op_sig.output_types.at(0) == TensorType_INT8) {
61         return 3;
62       }
63       if (op_sig.options.depthwise_conv_2d.dilation_w_factor != 1 ||
64           op_sig.options.depthwise_conv_2d.dilation_h_factor != 1) {
65         return 2;
66       }
67       return 1;
68 
69     case BuiltinOperator_FAKE_QUANT:
70       if (op_sig.options.fakequant.narrow_range) {
71         return 2;
72       }
73       return 1;
74 
75     case BuiltinOperator_FULLY_CONNECTED:
76       // +-----------------+--------------------+--------------------------+
77       // |                 |    Weight::Default | Weight::Shuffled4x16Int8 |
78       // +-----------------+--------------------+--------------------------+
79       // | Float           |                  1 |                        2 |
80       // | Quantized Uint8 |                  1 |                        2 |
81       // | Hybrid          |                  3 |                        3 |
82       // | Quantized Int8  |                  4 |                        4 |
83       // +-----------------+--------------------+--------------------------+
84       // 2 op_sig.inputs (no bias) use case is supported starting from
85       // version 6.
86       if (op_sig.input_types.size() == 2) {
87         return 6;
88       }
89       // `keep_num_dims` is supported at verison 5.
90       if (op_sig.options.fully_connected.keep_num_dims) {
91         return 5;
92       }
93       // Int8 fully fixed point kernel is at version 4.
94       if (op_sig.input_types.at(0) == TensorType_INT8 &&
95           op_sig.input_types.at(1) == TensorType_INT8 &&
96           op_sig.output_types.at(0) == TensorType_INT8) {
97         return 4;
98       }
99       // If the op is a signed int8 hybrid operation, we need to return
100       // version 3.
101       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
102           op_sig.input_types.at(1) == TensorType_INT8 &&
103           op_sig.output_types.at(0) == TensorType_FLOAT32) {
104         return 3;
105       }
106       // For float and uint8 fixed point kernels, if the weight is
107       // Shuffled4x16Int8, is is version 2.
108       if (op_sig.options.fully_connected.weights_format ==
109           FullyConnectedOptionsWeightsFormat_SHUFFLED4x16INT8) {
110         return 2;
111       }
112       // Otherwise (weight is default), the version is 1.
113       return 1;
114 
115     case BuiltinOperator_GATHER:
116       // If the op takes bool input, it is version 3.
117       if (op_sig.input_types.at(0) == TensorType_BOOL) {
118         return 3;
119       }
120       if (op_sig.input_types.at(0) == TensorType_INT8) {
121         return 2;
122       }
123       return 1;
124 
125     case BuiltinOperator_SVDF:
126       // Fully integer SVDF has int8 as input and is of version 3.
127       if (op_sig.input_types.at(0) == TensorType_INT8) {
128         return 3;
129       }
130       // If the op is a signed int8 hybrid operation, we need to return
131       // version 2.
132       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
133           op_sig.input_types.at(1) == TensorType_INT8 &&
134           op_sig.output_types.at(0) == TensorType_FLOAT32) {
135         return 2;
136       }
137       return 1;
138 
139     case BuiltinOperator_MUL:
140       // Version 3 supports have a rescale value greater than or equal to 1.
141       if (op_sig.options.mul.input1_scale != 0 &&
142           op_sig.options.mul.input2_scale != 0 &&
143           op_sig.options.mul.output_scale != 0 &&
144           (op_sig.options.mul.input1_scale * op_sig.options.mul.input2_scale /
145            op_sig.options.mul.output_scale) >= 1.0) {
146         return 3;
147       }
148       if (op_sig.input_types.at(0) == TensorType_INT8) {
149         return 2;
150       }
151       return 1;
152 
153     case BuiltinOperator_TRANSPOSE:
154       // If the op takes bool input, it is version 3.
155       if (op_sig.input_types.at(0) == TensorType_BOOL) {
156         return 3;
157       }
158       if (op_sig.input_types.at(0) == TensorType_INT8) {
159         return 2;
160       }
161       return 1;
162 
163     case BuiltinOperator_TRANSPOSE_CONV:
164       // If the op takes int8 input, it is version 2.
165       if (op_sig.input_types.at(0) == TensorType_INT8) {
166         return 2;
167       }
168       return 1;
169 
170     case BuiltinOperator_LSTM:
171       // If the input tensor is float and a weight is int8, this is a version
172       // 3 hybrid operation.
173       if (op_sig.options.lstm.kernel_type == LSTMKernelType_FULL &&
174           op_sig.input_types.at(0) == TensorType_FLOAT32 &&
175           op_sig.input_types.at(2) == TensorType_INT8 &&
176           op_sig.output_types.at(0) == TensorType_FLOAT32) {
177         return 3;
178       }
179       // KERNEL_BASIC was added in version 2.
180       if (op_sig.options.lstm.kernel_type == LSTMKernelType_BASIC) {
181         return 2;
182       }
183       return 1;
184 
185     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
186       // If the input tensor is float and a weight is int8, this is a version
187       // 2 hybrid operation.
188       if (op_sig.input_types.at(0) == TensorType_FLOAT32 &&
189           op_sig.input_types.at(2) == TensorType_INT8 &&
190           op_sig.output_types.at(0) == TensorType_FLOAT32) {
191         return 2;
192       }
193       return 1;
194 
195     case BuiltinOperator_SPLIT:
196       // If the op take int8 input, it is version 2, for int32 it's version 3.
197       if (op_sig.input_types.at(0) == TensorType_INT32) {
198         return 3;
199       }
200       if (op_sig.input_types.at(0) == TensorType_INT8) {
201         return 2;
202       }
203       return 1;
204 
205     case BuiltinOperator_SPARSE_TO_DENSE:
206       // Version 3 supports Int8 and Uint8 type.
207       if (op_sig.input_types.at(2) == TensorType_INT8 ||
208           op_sig.input_types.at(2) == TensorType_UINT8) {
209         return 3;
210       }
211       // Version 2 supports Int64 value type.
212       if (op_sig.input_types.at(2) == TensorType_INT64) {
213         return 2;
214       }
215       return 1;
216 
217     case BuiltinOperator_SLICE:
218       // Version 3 supports string input types.
219       if (op_sig.input_types.at(0) == TensorType_STRING) {
220         return 3;
221       }
222       if (op_sig.input_types.at(0) == TensorType_INT8) {
223         return 2;
224       }
225       return 1;
226 
227     case BuiltinOperator_UNPACK:
228       // If the op take int8/uint8 input, it is version 2.
229       if (op_sig.input_types.at(0) == TensorType_INT8 ||
230           op_sig.input_types.at(0) == TensorType_UINT8) {
231         return 2;
232       }
233       // If the op take bool input, it is version 3.
234       if (op_sig.input_types.at(0) == TensorType_BOOL) {
235         return 3;
236       }
237       return 1;
238 
239     case BuiltinOperator_DEQUANTIZE:
240       // Version 3 supports signed int16 input types.
241       if (op_sig.input_types.at(0) == TensorType_INT16 ||
242           op_sig.input_types.at(0) == TensorType_FLOAT16) {
243         return 3;
244       }
245       if (op_sig.input_types.at(0) == TensorType_INT8) {
246         return 2;
247       }
248       return 1;
249 
250     case BuiltinOperator_FLOOR_DIV:
251       if (op_sig.input_types.at(0) == TensorType_FLOAT32) {
252         return 2;
253       }
254       return 1;
255 
256     case BuiltinOperator_L2_NORMALIZATION:
257       if (op_sig.output_types.at(0) == TensorType_INT8) {
258         return 2;
259       }
260       return 1;
261 
262     case BuiltinOperator_RELU:
263       if (op_sig.input_types.at(0) == TensorType_INT8 ||
264           op_sig.input_types.at(0) == TensorType_UINT8) {
265         return 2;
266       }
267       return 1;
268     case BuiltinOperator_STRIDED_SLICE:
269       // If the op takes bool input, it is version 3.
270       if (op_sig.input_types.at(0) == TensorType_BOOL) {
271         return 3;
272       }
273       if (op_sig.input_types.at(0) == TensorType_INT8) {
274         return 2;
275       }
276       return 1;
277     case BuiltinOperator_REVERSE_V2:
278       if (op_sig.input_types.at(0) == TensorType_BOOL) {
279         return 2;
280       }
281       return 1;
282     case BuiltinOperator_RESIZE_BILINEAR:
283       if (op_sig.options.resize_bilinear.half_pixel_centers) {
284         return 3;
285       } else if (op_sig.input_types.at(0) == TensorType_INT8) {
286         return 2;
287       }
288       return 1;
289 
290     case BuiltinOperator_AVERAGE_POOL_2D:
291     case BuiltinOperator_ADD:
292     case BuiltinOperator_SPACE_TO_BATCH_ND:
293     case BuiltinOperator_SUB:
294     case BuiltinOperator_BATCH_TO_SPACE_ND:
295     case BuiltinOperator_CONCATENATION:
296     case BuiltinOperator_MAX_POOL_2D:
297     case BuiltinOperator_MAXIMUM:
298     case BuiltinOperator_MINIMUM:
299     case BuiltinOperator_PAD:
300     case BuiltinOperator_PADV2:
301     case BuiltinOperator_SOFTMAX:
302     case BuiltinOperator_SPACE_TO_DEPTH:
303     case BuiltinOperator_MEAN:
304     case BuiltinOperator_SUM:
305     case BuiltinOperator_REDUCE_MAX:
306     case BuiltinOperator_REDUCE_MIN:
307     case BuiltinOperator_RELU6:
308     case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR:
309     case BuiltinOperator_PACK:
310     case BuiltinOperator_TANH:
311     case BuiltinOperator_LOGISTIC:
312     case BuiltinOperator_LOG_SOFTMAX:
313     case BuiltinOperator_TOPK_V2:
314     case BuiltinOperator_ARG_MAX:
315     case BuiltinOperator_ARG_MIN:
316     case BuiltinOperator_EQUAL:
317     case BuiltinOperator_NOT_EQUAL:
318     case BuiltinOperator_GREATER:
319     case BuiltinOperator_GREATER_EQUAL:
320     case BuiltinOperator_LESS:
321     case BuiltinOperator_LESS_EQUAL:
322     case BuiltinOperator_SELECT:
323       if (op_sig.input_types.at(0) == TensorType_INT8) {
324         return 2;
325       }
326       return 1;
327 
328     default:
329       return 1;
330   }
331 }
332 
GetTensorType(int32_t idx,const SubGraph * subgraph)333 TensorType GetTensorType(int32_t idx, const SubGraph* subgraph) {
334   const auto& none_type = static_cast<::tflite::TensorType>(-1);
335   if (idx == -1)
336     // For optional input/output, return none type directly.
337     return none_type;
338 
339   // Some tests have a graph with invalid tensor index.
340   TFLITE_DCHECK_GE(idx, 0);
341   if (subgraph->tensors() && idx < subgraph->tensors()->Length()) {
342     return subgraph->tensors()->Get(idx)->type();
343   }
344   LOG(ERROR) << "Can't access tenor " << idx;
345   return none_type;
346 }
347 
348 // Generate OpSignature with the given OperatorCode, Operator and Tensors (from
349 // SubGraph). The OpSignature will be used by GetBuiltinOperatorVersion() and
350 // mostly input and output tensor types are enough to figure out op version.
351 // But some ops (DEPTHWISE_CONV_2D,  FULLY_CONNECTED, ...) require to pass their
352 // options to decide op version.
GetOpSignature(const OperatorCode * op_code,const Operator * op,const SubGraph * subgraph)353 OpSignature GetOpSignature(const OperatorCode* op_code, const Operator* op,
354                            const SubGraph* subgraph) {
355   OpSignature op_sig = {op_code->builtin_code()};
356 
357   switch (op_code->builtin_code()) {
358     case BuiltinOperator_DEPTHWISE_CONV_2D: {
359       auto conv_option = op->builtin_options_as_DepthwiseConv2DOptions();
360       if (conv_option) {
361         op_sig.options.depthwise_conv_2d.dilation_w_factor =
362             conv_option->dilation_w_factor();
363         op_sig.options.depthwise_conv_2d.dilation_h_factor =
364             conv_option->dilation_h_factor();
365       }
366     } break;
367 
368     case BuiltinOperator_FAKE_QUANT: {
369       auto fakequant_option = op->builtin_options_as_FakeQuantOptions();
370       if (fakequant_option) {
371         op_sig.options.fakequant.narrow_range =
372             fakequant_option->narrow_range();
373       }
374     } break;
375 
376     case BuiltinOperator_FULLY_CONNECTED: {
377       auto fully_connected_option =
378           op->builtin_options_as_FullyConnectedOptions();
379       if (fully_connected_option) {
380         op_sig.options.fully_connected.keep_num_dims =
381             fully_connected_option->keep_num_dims();
382         op_sig.options.fully_connected.weights_format =
383             fully_connected_option->weights_format();
384       }
385     } break;
386 
387     case BuiltinOperator_MUL: {
388       if (op->inputs()->Length() < 2 || op->outputs()->Length() < 1) {
389         break;
390       }
391       const Tensor* input1_tensor =
392           subgraph->tensors()->Get(op->inputs()->Get(0));
393       const Tensor* input2_tensor =
394           subgraph->tensors()->Get(op->inputs()->Get(1));
395       const Tensor* output_tensor =
396           subgraph->tensors()->Get(op->outputs()->Get(0));
397       const QuantizationParameters* input1_quant =
398           input1_tensor->quantization();
399       const QuantizationParameters* input2_qunt = input2_tensor->quantization();
400       const QuantizationParameters* output_quant =
401           output_tensor->quantization();
402       if (input1_quant && input1_quant->scale() &&
403           input1_quant->scale()->Length() && input2_qunt &&
404           input2_qunt->scale() && input2_qunt->scale()->Length() &&
405           output_quant && output_quant->scale() &&
406           output_quant->scale()->Length()) {
407         op_sig.options.mul.input1_scale = input1_quant->scale()->Get(0);
408         op_sig.options.mul.input2_scale = input2_qunt->scale()->Get(0);
409         op_sig.options.mul.output_scale = output_quant->scale()->Get(0);
410       }
411     } break;
412 
413     case BuiltinOperator_LSTM: {
414       auto lstm_option = op->builtin_options_as_LSTMOptions();
415       if (lstm_option) {
416         op_sig.options.lstm.kernel_type = lstm_option->kernel_type();
417       }
418     } break;
419 
420     case BuiltinOperator_RESIZE_BILINEAR: {
421       auto resize_bilinear_option =
422           op->builtin_options_as_ResizeBilinearOptions();
423       if (resize_bilinear_option) {
424         op_sig.options.resize_bilinear.half_pixel_centers =
425             resize_bilinear_option->half_pixel_centers();
426       }
427     } break;
428 
429     default:
430       break;
431   }
432 
433   for (int32_t i = 0; i < op->inputs()->Length(); ++i) {
434     TensorType tensor_type = GetTensorType(op->inputs()->Get(i), subgraph);
435     op_sig.input_types.push_back(tensor_type);
436   }
437   for (int32_t i = 0; i < op->outputs()->Length(); ++i) {
438     TensorType tensor_type = GetTensorType(op->outputs()->Get(i), subgraph);
439     op_sig.output_types.push_back(tensor_type);
440   }
441   return op_sig;
442 }
443 
UpdateOpVersion(uint8_t * model_buffer_pointer)444 void UpdateOpVersion(uint8_t* model_buffer_pointer) {
445   auto model = GetMutableModel(model_buffer_pointer);
446   auto subgraphs = model->subgraphs();
447 
448   for (int i = 0; i < subgraphs->Length(); ++i) {
449     const SubGraph* subgraph = subgraphs->Get(i);
450     for (int j = 0; j < subgraph->operators()->Length(); ++j) {
451       const Operator* op = subgraph->operators()->Get(j);
452       OperatorCode* op_code =
453           model->mutable_operator_codes()->GetMutableObject(op->opcode_index());
454 
455       if (op_code->builtin_code() != BuiltinOperator_CUSTOM) {
456         OpSignature op_sig = GetOpSignature(op_code, op, subgraph);
457         // Update builtin operator version.
458         int32_t op_ver = GetBuiltinOperatorVersion(op_sig);
459         if (!op_code->mutate_version(op_ver)) {
460           LOG(ERROR) << "Can't set operator "
461                      << EnumNameBuiltinOperator(op_code->builtin_code())
462                      << " to version " << op_ver;
463         }
464       }
465     }
466   }
467 }
468 
469 }  // namespace tflite
470