1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/node_def.pb.h"
17 #include "tensorflow/core/lib/strings/str_util.h"
18 #include "tensorflow/core/platform/env.h"
19 #include "tensorflow/tools/graph_transforms/transform_utils.h"
20
21 namespace tensorflow {
22 namespace graph_transforms {
23
// A single min/max observation parsed from one line of a requantization log.
// Field order matters: callers build records with aggregate initialization in
// the order {name, min, max}.
struct MinMaxRecord {
  // Name of the graph node the observation was logged for.
  string name;
  // Observed lower bound of the range.
  float min;
  // Observed upper bound of the range.
  float max;
};
29
30 // Try to parse a log file containing loosely-structured lines, some of which
31 // are the min/max logs we want.
ExtractMinMaxRecords(const string & log_file_name,std::vector<MinMaxRecord> * records)32 Status ExtractMinMaxRecords(const string& log_file_name,
33 std::vector<MinMaxRecord>* records) {
34 string file_data;
35 TF_RETURN_IF_ERROR(
36 ReadFileToString(Env::Default(), log_file_name, &file_data));
37 const string print_suffix("__print__");
38 const string requant_prefix("__requant_min_max:");
39 std::vector<string> file_lines = str_util::Split(file_data, '\n');
40 for (const string& file_line : file_lines) {
41 // We expect to find a line with components separated by semicolons, so to
42 // start make sure that the basic structure is in place/
43 if (!str_util::StrContains(file_line,
44 print_suffix + ";" + requant_prefix)) {
45 continue;
46 }
47 std::vector<string> line_parts = str_util::Split(file_line, ';');
48 if (line_parts.size() < 2) {
49 continue;
50 }
51 // Now we want to figure out which components have the name and min max
52 // values by scanning for the prefix we expect.
53 bool min_max_found = false;
54 int min_max_index;
55 for (int i = 1; i < line_parts.size(); ++i) {
56 if (str_util::StartsWith(line_parts[i], requant_prefix)) {
57 min_max_found = true;
58 min_max_index = i;
59 }
60 }
61 if (!min_max_found) {
62 continue;
63 }
64 // Finally we need to break out the values from the strings, and parse them
65 // into a form we can use.
66 string min_max_string = line_parts[min_max_index];
67 std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
68 if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
69 continue;
70 }
71 string min_string = min_max_parts[1];
72 std::vector<string> min_string_parts = str_util::Split(min_string, ']');
73 if (min_string_parts.size() != 2) {
74 continue;
75 }
76 string min_number_string = min_string_parts[0];
77 float min;
78 if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
79 continue;
80 }
81 string max_string = min_max_parts[2];
82 std::vector<string> max_string_parts = str_util::Split(max_string, ']');
83 if (max_string_parts.size() != 2) {
84 continue;
85 }
86 string max_number_string = max_string_parts[0];
87 float max;
88 if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
89 continue;
90 }
91 StringPiece name_string = line_parts[min_max_index - 1];
92 if (!str_util::EndsWith(name_string, print_suffix)) {
93 continue;
94 }
95 string name(
96 name_string.substr(0, name_string.size() - print_suffix.size()));
97 records->push_back({name, min, max});
98 }
99 return Status::OK();
100 }
101
102 // Uses the observed min/max values for requantization captured in a log file to
103 // replace costly RequantizationRange ops with simple Consts.
FreezeRequantizationRanges(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)104 Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
105 const TransformFuncContext& context,
106 GraphDef* output_graph_def) {
107 string min_max_log_file;
108 TF_RETURN_IF_ERROR(
109 context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
110 if (min_max_log_file.empty()) {
111 return errors::InvalidArgument(
112 "You must pass a file name to min_max_log_file");
113 }
114 float min_percentile;
115 TF_RETURN_IF_ERROR(
116 context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
117 float max_percentile;
118 TF_RETURN_IF_ERROR(
119 context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
120
121 std::vector<MinMaxRecord> records;
122 TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
123 if (records.empty()) {
124 return errors::InvalidArgument(
125 "No min/max range logs were found in the log file");
126 }
127
128 std::map<string, const NodeDef*> node_map;
129 MapNamesToNodes(input_graph_def, &node_map);
130 bool any_missing_nodes = false;
131 std::map<string, std::vector<MinMaxRecord>> records_by_node;
132 for (const MinMaxRecord& record : records) {
133 records_by_node[record.name].push_back(record);
134 if (!node_map.count(record.name)) {
135 any_missing_nodes = true;
136 LOG(WARNING) << "Node from log not found in graph: " << record.name;
137 }
138 }
139 if (any_missing_nodes) {
140 return errors::InvalidArgument(
141 "Nodes were found in the log file that aren't present in the graph");
142 }
143
144 // Now find out the largest and smallest min/max values for the node.
145 std::map<string, std::pair<float, float>> range_for_nodes;
146 for (const auto& record_info : records_by_node) {
147 const string& name = record_info.first;
148 const std::vector<MinMaxRecord> records = record_info.second;
149 std::vector<float> mins;
150 std::vector<float> maxs;
151 for (const MinMaxRecord& record : records) {
152 mins.push_back(record.min);
153 maxs.push_back(record.max);
154 }
155 std::sort(mins.begin(), mins.end());
156 std::sort(maxs.begin(), maxs.end());
157 int min_index = std::round(mins.size() * (min_percentile / 100.0f));
158 if (min_index < 0) {
159 min_index = 0;
160 }
161 int max_index =
162 std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
163 if (max_index > (maxs.size() - 1)) {
164 max_index = maxs.size() - 1;
165 }
166 const float min = mins[min_index];
167 const float max = maxs[max_index];
168 range_for_nodes[name] = {min, max};
169 }
170 std::map<string, string> inputs_to_rename;
171 GraphDef frozen_graph_def;
172 for (const NodeDef& node : input_graph_def.node()) {
173 if (range_for_nodes.count(node.name())) {
174 if (node.op() != "RequantizationRange") {
175 return errors::InvalidArgument(
176 "Node is expected to be a RequantizationRange op: ", node.name(),
177 ", but is: ", node.op());
178 }
179 const float min_value = range_for_nodes.at(node.name()).first;
180 NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
181 min_node->set_op("Const");
182 min_node->set_name(node.name() + "/frozen_min");
183 SetNodeAttr("dtype", DT_FLOAT, min_node);
184 Tensor min_tensor(DT_FLOAT, {});
185 min_tensor.flat<float>()(0) = min_value;
186 SetNodeTensorAttr<float>("value", min_tensor, min_node);
187 inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
188
189 const float max_value = range_for_nodes.at(node.name()).second;
190 NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
191 max_node->set_op("Const");
192 max_node->set_name(node.name() + "/frozen_max");
193 SetNodeAttr("dtype", DT_FLOAT, max_node);
194 Tensor max_tensor(DT_FLOAT, {});
195 max_tensor.flat<float>()(0) = max_value;
196 SetNodeTensorAttr<float>("value", max_tensor, max_node);
197 inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
198 } else {
199 NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
200 *new_node = node;
201 }
202 }
203 return RenameNodeInputs(frozen_graph_def, inputs_to_rename,
204 std::unordered_set<string>(), output_graph_def);
205 }
206
207 REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
208 FreezeRequantizationRanges);
209
210 } // namespace graph_transforms
211 } // namespace tensorflow
212