/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_WEIGHTS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_WEIGHTS_H_

#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace mlir {
namespace lite {

// Supported result types produced by the quantization process.
enum class BufferType { QUANTIZED_INT8, QUANTIZED_FLOAT16 };

// Stores information about how to quantize a user-specified custom operation.
// Each CustomOpInfo is registered under its corresponding custom op code in
// the CustomOpMap. 'quantizable_input_indices' determines which input indices
// of the custom op are quantizable. 'is_weight_only' specifies whether the
// custom op is quantized only for storage and dequantized at runtime.
// 'no_side_effect' determines whether the op can be pruned when it is
// considered trivially dead.
struct CustomOpInfo {
  std::vector<std::int32_t> quantizable_input_indices;
  bool is_weight_only = false;
  bool no_side_effect = true;
};

using StringSet = absl::flat_hash_set<std::string>;
using BuiltinOperatorSet = absl::flat_hash_set<tflite::BuiltinOperator>;
// Map from custom op code to custom op quantization information.
using CustomOpMap = std::unordered_map<std::string, CustomOpInfo>;
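
// Example (illustrative sketch, not part of this API): registering a
// hypothetical custom op "MyCustomOp" whose inputs 1 and 2 hold quantizable
// weights and which should be quantized for storage only.
//
//   CustomOpMap custom_op_map;
//   CustomOpInfo info;
//   info.quantizable_input_indices = {1, 2};
//   info.is_weight_only = true;  // dequantize at runtime
//   custom_op_map["MyCustomOp"] = info;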

// Applies dynamic range quantization to the given model. The input model is
// a flatbuffer, which is converted to MLIR during the quantization process
// and converted back to a flatbuffer before being returned. Note that this is
// part of reaching feature parity with the old quantizer for dynamic range
// quantization, specifically with
// third_party/tensorflow/lite/tools/optimize/quantize_weights.h.
// TODO(b/202468183): Selective quantization + quant debugger support for
// dynamic range quantization for verify_numeric and whole_model_verify flags.
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const tflite::Model* input_model,
                             tflite::ErrorReporter* error_reporter,
                             const tflite::TensorType& inference_type,
                             const StringSet& denylisted_ops,
                             const CustomOpMap& custom_op_map,
                             int64_t minimum_elements_for_weights = 1024,
                             bool disable_per_channel = false,
                             bool weight_only_quantization = false,
                             bool legacy_float_scale = false);
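
// Example call (illustrative sketch; `input_model` is assumed to be a
// tflite::Model* the caller obtained elsewhere, e.g. from a loaded
// FlatBufferModel):
//
//   flatbuffers::FlatBufferBuilder builder;
//   TfLiteStatus status = QuantizeWeights(
//       &builder, input_model, tflite::DefaultErrorReporter(),
//       /*inference_type=*/tflite::TensorType_INT8,
//       /*denylisted_ops=*/{}, /*custom_op_map=*/{});
//   // On success, `builder` holds the finished, quantized flatbuffer.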

// Overloads provided to support the old quantizer's API.
TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const tflite::Model* input_model,
                             int64_t weights_min_num_elements,
                             bool use_hybrid_evaluation = true);

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const tflite::Model* input_model,
                             BufferType quant_type = BufferType::QUANTIZED_INT8,
                             bool use_updated_hybrid_scheme = true);

TfLiteStatus QuantizeWeights(flatbuffers::FlatBufferBuilder* builder,
                             const tflite::Model* input_model,
                             int64_t weights_min_num_elements,
                             const CustomOpMap& custom_op_map,
                             bool use_updated_hybrid_scheme = true,
                             const BuiltinOperatorSet& op_denylist = {});
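
// Example (illustrative sketch of the legacy overloads above; `input_model`
// is again assumed to come from the caller):
//
//   // Quantize weights of tensors with at least 1024 elements to int8.
//   flatbuffers::FlatBufferBuilder builder;
//   QuantizeWeights(&builder, input_model,
//                   /*weights_min_num_elements=*/1024);
//
//   // Or quantize all eligible weights to float16.
//   flatbuffers::FlatBufferBuilder fp16_builder;
//   QuantizeWeights(&fp16_builder, input_model,
//                   BufferType::QUANTIZED_FLOAT16);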

}  // namespace lite
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_LITE_QUANTIZE_WEIGHTS_H_