1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
17
18 #include "absl/strings/numbers.h"
19 #include "absl/strings/str_split.h"
20 #include "absl/strings/string_view.h"
21 #include "llvm/Support/Debug.h"
22 #include "llvm/Support/raw_ostream.h"
23 #include "tensorflow/core/framework/types.pb.h"
24
25 // Is this dtype a quantization type from TensorFlow.
IsQuantizationType(tensorflow::DataType dtype)26 static bool IsQuantizationType(tensorflow::DataType dtype) {
27 switch (dtype) {
28 case tensorflow::DT_QINT8:
29 case tensorflow::DT_QUINT8:
30 case tensorflow::DT_QINT16:
31 case tensorflow::DT_QUINT16:
32 case tensorflow::DT_QINT32:
33 return true;
34 default:
35 return false;
36 }
37 }
38
39 namespace mlir {
40 namespace TFL {
41
ParseInputNodeQuantSpecs(absl::string_view node_names,absl::string_view min_values,absl::string_view max_values,absl::string_view inference_type,QuantizationSpecs * quant_specs)42 bool ParseInputNodeQuantSpecs(absl::string_view node_names,
43 absl::string_view min_values,
44 absl::string_view max_values,
45 absl::string_view inference_type,
46 QuantizationSpecs* quant_specs) {
47 std::vector<std::string> input_nodes = absl::StrSplit(node_names, ',');
48 std::vector<llvm::Optional<double>> node_mins;
49 if (!min_values.empty()) {
50 std::vector<std::string> node_mins_str = absl::StrSplit(min_values, ',');
51 for (int i = 0, e = node_mins_str.size(); i < e; i++) {
52 double value;
53 if (!absl::SimpleAtod(node_mins_str[i], &value)) {
54 return true;
55 }
56 node_mins.push_back(value);
57 }
58 }
59
60 std::vector<llvm::Optional<double>> node_maxs;
61 if (!max_values.empty()) {
62 std::vector<std::string> node_maxs_str = absl::StrSplit(max_values, ',');
63 for (int i = 0, e = node_maxs_str.size(); i < e; i++) {
64 double value;
65 if (!absl::SimpleAtod(node_maxs_str[i], &value)) {
66 llvm::errs() << "Unexpected mins: " << node_maxs_str[i] << "\n";
67 return true;
68 }
69 node_maxs.push_back(value);
70 }
71 }
72
73 tensorflow::DataType final_type = tensorflow::DT_FLOAT;
74 if (!inference_type.empty() &&
75 !DataType_Parse(std::string(inference_type), &final_type)) {
76 return true;
77 }
78 return GetInputNodeQuantSpecs(input_nodes, node_mins, node_maxs, final_type,
79 quant_specs);
80 }
81
GetInputNodeQuantSpecs(const std::vector<std::string> & node_names,const std::vector<llvm::Optional<double>> & node_mins,const std::vector<llvm::Optional<double>> & node_maxs,tensorflow::DataType inference_type,QuantizationSpecs * quant_specs)82 bool GetInputNodeQuantSpecs(
83 const std::vector<std::string>& node_names,
84 const std::vector<llvm::Optional<double>>& node_mins,
85 const std::vector<llvm::Optional<double>>& node_maxs,
86 tensorflow::DataType inference_type, QuantizationSpecs* quant_specs) {
87 quant_specs->inference_type = inference_type;
88
89 // If min/max are not specified, just return;
90 if (node_mins.empty() || node_maxs.empty()) return false;
91
92 // Otherwise make sure min/max has the same size as inputs.
93 if (IsQuantizationType(inference_type)) {
94 // min/max should have same size as inputs, or shouldn't be specified.
95 if (node_names.size() != node_mins.size() ||
96 node_names.size() != node_maxs.size()) {
97 return true;
98 }
99 for (int i = 0, e = node_names.size(); i != e; ++i) {
100 quant_specs->input_ranges.push_back({node_mins[i], node_maxs[i]});
101 }
102 return false;
103 }
104 if (!node_mins.empty()) {
105 llvm::dbgs() << "Ignored input_min_values.";
106 }
107 if (!node_maxs.empty()) {
108 llvm::dbgs() << "Ignored input_max_values.";
109 }
110 return false;
111 }
112
113 } // namespace TFL
114 } // namespace mlir
115