• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def.pb.h"
17 #include "tensorflow/core/lib/strings/str_util.h"
18 #include "tensorflow/core/platform/env.h"
19 #include "tensorflow/tools/graph_transforms/transform_utils.h"
20 
21 namespace tensorflow {
22 namespace graph_transforms {
23 
// A single min/max observation recovered from one requantization log line.
// Kept as a plain aggregate so it can be brace-initialized as
// {name, min, max}.
struct MinMaxRecord {
  // Node name the observation was logged for.
  string name;
  // Observed lower bound of the range.
  float min;
  // Observed upper bound of the range.
  float max;
};
29 
30 // Try to parse a log file containing loosely-structured lines, some of which
31 // are the min/max logs we want.
ExtractMinMaxRecords(const string & log_file_name,std::vector<MinMaxRecord> * records)32 Status ExtractMinMaxRecords(const string& log_file_name,
33                             std::vector<MinMaxRecord>* records) {
34   string file_data;
35   TF_RETURN_IF_ERROR(
36       ReadFileToString(Env::Default(), log_file_name, &file_data));
37   const string print_suffix("__print__");
38   const string requant_prefix("__requant_min_max:");
39   std::vector<string> file_lines = str_util::Split(file_data, '\n');
40   for (const string& file_line : file_lines) {
41     // We expect to find a line with components separated by semicolons, so to
42     // start make sure that the basic structure is in place/
43     if (!str_util::StrContains(file_line,
44                                print_suffix + ";" + requant_prefix)) {
45       continue;
46     }
47     std::vector<string> line_parts = str_util::Split(file_line, ';');
48     if (line_parts.size() < 2) {
49       continue;
50     }
51     // Now we want to figure out which components have the name and min max
52     // values by scanning for the prefix we expect.
53     bool min_max_found = false;
54     int min_max_index;
55     for (int i = 1; i < line_parts.size(); ++i) {
56       if (str_util::StartsWith(line_parts[i], requant_prefix)) {
57         min_max_found = true;
58         min_max_index = i;
59       }
60     }
61     if (!min_max_found) {
62       continue;
63     }
64     // Finally we need to break out the values from the strings, and parse them
65     // into a form we can use.
66     string min_max_string = line_parts[min_max_index];
67     std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
68     if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
69       continue;
70     }
71     string min_string = min_max_parts[1];
72     std::vector<string> min_string_parts = str_util::Split(min_string, ']');
73     if (min_string_parts.size() != 2) {
74       continue;
75     }
76     string min_number_string = min_string_parts[0];
77     float min;
78     if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
79       continue;
80     }
81     string max_string = min_max_parts[2];
82     std::vector<string> max_string_parts = str_util::Split(max_string, ']');
83     if (max_string_parts.size() != 2) {
84       continue;
85     }
86     string max_number_string = max_string_parts[0];
87     float max;
88     if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
89       continue;
90     }
91     StringPiece name_string = line_parts[min_max_index - 1];
92     if (!str_util::EndsWith(name_string, print_suffix)) {
93       continue;
94     }
95     string name(
96         name_string.substr(0, name_string.size() - print_suffix.size()));
97     records->push_back({name, min, max});
98   }
99   return Status::OK();
100 }
101 
102 // Uses the observed min/max values for requantization captured in a log file to
103 // replace costly RequantizationRange ops with simple Consts.
FreezeRequantizationRanges(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)104 Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
105                                   const TransformFuncContext& context,
106                                   GraphDef* output_graph_def) {
107   string min_max_log_file;
108   TF_RETURN_IF_ERROR(
109       context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
110   if (min_max_log_file.empty()) {
111     return errors::InvalidArgument(
112         "You must pass a file name to min_max_log_file");
113   }
114   float min_percentile;
115   TF_RETURN_IF_ERROR(
116       context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
117   float max_percentile;
118   TF_RETURN_IF_ERROR(
119       context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
120 
121   std::vector<MinMaxRecord> records;
122   TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
123   if (records.empty()) {
124     return errors::InvalidArgument(
125         "No min/max range logs were found in the log file");
126   }
127 
128   std::map<string, const NodeDef*> node_map;
129   MapNamesToNodes(input_graph_def, &node_map);
130   bool any_missing_nodes = false;
131   std::map<string, std::vector<MinMaxRecord>> records_by_node;
132   for (const MinMaxRecord& record : records) {
133     records_by_node[record.name].push_back(record);
134     if (!node_map.count(record.name)) {
135       any_missing_nodes = true;
136       LOG(WARNING) << "Node from log not found in graph: " << record.name;
137     }
138   }
139   if (any_missing_nodes) {
140     return errors::InvalidArgument(
141         "Nodes were found in the log file that aren't present in the graph");
142   }
143 
144   // Now find out the largest and smallest min/max values for the node.
145   std::map<string, std::pair<float, float>> range_for_nodes;
146   for (const auto& record_info : records_by_node) {
147     const string& name = record_info.first;
148     const std::vector<MinMaxRecord> records = record_info.second;
149     std::vector<float> mins;
150     std::vector<float> maxs;
151     for (const MinMaxRecord& record : records) {
152       mins.push_back(record.min);
153       maxs.push_back(record.max);
154     }
155     std::sort(mins.begin(), mins.end());
156     std::sort(maxs.begin(), maxs.end());
157     int min_index = std::round(mins.size() * (min_percentile / 100.0f));
158     if (min_index < 0) {
159       min_index = 0;
160     }
161     int max_index =
162         std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
163     if (max_index > (maxs.size() - 1)) {
164       max_index = maxs.size() - 1;
165     }
166     const float min = mins[min_index];
167     const float max = maxs[max_index];
168     range_for_nodes[name] = {min, max};
169   }
170   std::map<string, string> inputs_to_rename;
171   GraphDef frozen_graph_def;
172   for (const NodeDef& node : input_graph_def.node()) {
173     if (range_for_nodes.count(node.name())) {
174       if (node.op() != "RequantizationRange") {
175         return errors::InvalidArgument(
176             "Node is expected to be a RequantizationRange op: ", node.name(),
177             ", but is: ", node.op());
178       }
179       const float min_value = range_for_nodes.at(node.name()).first;
180       NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
181       min_node->set_op("Const");
182       min_node->set_name(node.name() + "/frozen_min");
183       SetNodeAttr("dtype", DT_FLOAT, min_node);
184       Tensor min_tensor(DT_FLOAT, {});
185       min_tensor.flat<float>()(0) = min_value;
186       SetNodeTensorAttr<float>("value", min_tensor, min_node);
187       inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
188 
189       const float max_value = range_for_nodes.at(node.name()).second;
190       NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
191       max_node->set_op("Const");
192       max_node->set_name(node.name() + "/frozen_max");
193       SetNodeAttr("dtype", DT_FLOAT, max_node);
194       Tensor max_tensor(DT_FLOAT, {});
195       max_tensor.flat<float>()(0) = max_value;
196       SetNodeTensorAttr<float>("value", max_tensor, max_node);
197       inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
198     } else {
199       NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
200       *new_node = node;
201     }
202   }
203   return RenameNodeInputs(frozen_graph_def, inputs_to_rename,
204                           std::unordered_set<string>(), output_graph_def);
205 }
206 
// Registers FreezeRequantizationRanges under the transform name
// "freeze_requantization_ranges". Presumably this is the name callers use to
// select it from the graph transform tool — confirm against transform_utils.h.
REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
                         FreezeRequantizationRanges);
209 
210 }  // namespace graph_transforms
211 }  // namespace tensorflow
212