• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/node_def.pb.h"
17 #include "tensorflow/core/lib/strings/str_util.h"
18 #include "tensorflow/core/platform/env.h"
19 #include "tensorflow/tools/graph_transforms/transform_utils.h"
20 
21 namespace tensorflow {
22 namespace graph_transforms {
23 
24 struct MinMaxRecord {
25   string name;
26   float min;
27   float max;
28 };
29 
30 // Try to parse a log file containing loosely-structured lines, some of which
31 // are the min/max logs we want.
ExtractMinMaxRecords(const string & log_file_name,std::vector<MinMaxRecord> * records)32 Status ExtractMinMaxRecords(const string& log_file_name,
33                             std::vector<MinMaxRecord>* records) {
34   string file_data;
35   TF_RETURN_IF_ERROR(
36       ReadFileToString(Env::Default(), log_file_name, &file_data));
37   const string print_suffix("__print__");
38   const string requant_prefix("__requant_min_max:");
39   std::vector<string> file_lines = str_util::Split(file_data, '\n');
40   for (const string& file_line : file_lines) {
41     // We expect to find a line with components separated by semicolons, so to
42     // start make sure that the basic structure is in place/
43     if (!absl::StrContains(file_line, print_suffix + ";" + requant_prefix)) {
44       continue;
45     }
46     std::vector<string> line_parts = str_util::Split(file_line, ';');
47     if (line_parts.size() < 2) {
48       continue;
49     }
50     // Now we want to figure out which components have the name and min max
51     // values by scanning for the prefix we expect.
52     bool min_max_found = false;
53     int min_max_index;
54     for (int i = 1; i < line_parts.size(); ++i) {
55       if (absl::StartsWith(line_parts[i], requant_prefix)) {
56         min_max_found = true;
57         min_max_index = i;
58       }
59     }
60     if (!min_max_found) {
61       continue;
62     }
63     // Finally we need to break out the values from the strings, and parse them
64     // into a form we can use.
65     string min_max_string = line_parts[min_max_index];
66     std::vector<string> min_max_parts = str_util::Split(min_max_string, '[');
67     if ((min_max_parts.size() != 3) || (min_max_parts[0] != requant_prefix)) {
68       continue;
69     }
70     string min_string = min_max_parts[1];
71     std::vector<string> min_string_parts = str_util::Split(min_string, ']');
72     if (min_string_parts.size() != 2) {
73       continue;
74     }
75     string min_number_string = min_string_parts[0];
76     float min;
77     if (!strings::safe_strtof(min_number_string.c_str(), &min)) {
78       continue;
79     }
80     string max_string = min_max_parts[2];
81     std::vector<string> max_string_parts = str_util::Split(max_string, ']');
82     if (max_string_parts.size() != 2) {
83       continue;
84     }
85     string max_number_string = max_string_parts[0];
86     float max;
87     if (!strings::safe_strtof(max_number_string.c_str(), &max)) {
88       continue;
89     }
90     StringPiece name_string = line_parts[min_max_index - 1];
91     if (!str_util::EndsWith(name_string, print_suffix)) {
92       continue;
93     }
94     string name(
95         name_string.substr(0, name_string.size() - print_suffix.size()));
96     records->push_back({name, min, max});
97   }
98   return Status::OK();
99 }
100 
101 // Uses the observed min/max values for requantization captured in a log file to
102 // replace costly RequantizationRange ops with simple Consts.
FreezeRequantizationRanges(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)103 Status FreezeRequantizationRanges(const GraphDef& input_graph_def,
104                                   const TransformFuncContext& context,
105                                   GraphDef* output_graph_def) {
106   string min_max_log_file;
107   TF_RETURN_IF_ERROR(
108       context.GetOneStringParameter("min_max_log_file", "", &min_max_log_file));
109   if (min_max_log_file.empty()) {
110     return errors::InvalidArgument(
111         "You must pass a file name to min_max_log_file");
112   }
113   float min_percentile;
114   TF_RETURN_IF_ERROR(
115       context.GetOneFloatParameter("min_percentile", 5.0f, &min_percentile));
116   float max_percentile;
117   TF_RETURN_IF_ERROR(
118       context.GetOneFloatParameter("max_percentile", 5.0f, &max_percentile));
119 
120   std::vector<MinMaxRecord> records;
121   TF_RETURN_IF_ERROR(ExtractMinMaxRecords(min_max_log_file, &records));
122   if (records.empty()) {
123     return errors::InvalidArgument(
124         "No min/max range logs were found in the log file");
125   }
126 
127   std::map<string, const NodeDef*> node_map;
128   MapNamesToNodes(input_graph_def, &node_map);
129   bool any_missing_nodes = false;
130   std::map<string, std::vector<MinMaxRecord>> records_by_node;
131   for (const MinMaxRecord& record : records) {
132     records_by_node[record.name].push_back(record);
133     if (!node_map.count(record.name)) {
134       any_missing_nodes = true;
135       LOG(WARNING) << "Node from log not found in graph: " << record.name;
136     }
137   }
138   if (any_missing_nodes) {
139     return errors::InvalidArgument(
140         "Nodes were found in the log file that aren't present in the graph");
141   }
142 
143   // Now find out the largest and smallest min/max values for the node.
144   std::map<string, std::pair<float, float>> range_for_nodes;
145   for (const auto& record_info : records_by_node) {
146     const string& name = record_info.first;
147     const std::vector<MinMaxRecord> records = record_info.second;
148     std::vector<float> mins;
149     std::vector<float> maxs;
150     for (const MinMaxRecord& record : records) {
151       mins.push_back(record.min);
152       maxs.push_back(record.max);
153     }
154     std::sort(mins.begin(), mins.end());
155     std::sort(maxs.begin(), maxs.end());
156     int min_index = std::round(mins.size() * (min_percentile / 100.0f));
157     if (min_index < 0) {
158       min_index = 0;
159     }
160     int max_index =
161         std::round(maxs.size() * (1.0f - (max_percentile / 100.0f)));
162     if (max_index > (maxs.size() - 1)) {
163       max_index = maxs.size() - 1;
164     }
165     const float min = mins[min_index];
166     const float max = maxs[max_index];
167     range_for_nodes[name] = {min, max};
168   }
169   std::map<string, string> inputs_to_rename;
170   GraphDef frozen_graph_def;
171   for (const NodeDef& node : input_graph_def.node()) {
172     if (range_for_nodes.count(node.name())) {
173       if (node.op() != "RequantizationRange") {
174         return errors::InvalidArgument(
175             "Node is expected to be a RequantizationRange op: ", node.name(),
176             ", but is: ", node.op());
177       }
178       const float min_value = range_for_nodes.at(node.name()).first;
179       NodeDef* min_node = frozen_graph_def.mutable_node()->Add();
180       min_node->set_op("Const");
181       min_node->set_name(node.name() + "/frozen_min");
182       SetNodeAttr("dtype", DT_FLOAT, min_node);
183       Tensor min_tensor(DT_FLOAT, {});
184       min_tensor.flat<float>()(0) = min_value;
185       SetNodeTensorAttr<float>("value", min_tensor, min_node);
186       inputs_to_rename[node.name() + ":0"] = min_node->name() + ":0";
187 
188       const float max_value = range_for_nodes.at(node.name()).second;
189       NodeDef* max_node = frozen_graph_def.mutable_node()->Add();
190       max_node->set_op("Const");
191       max_node->set_name(node.name() + "/frozen_max");
192       SetNodeAttr("dtype", DT_FLOAT, max_node);
193       Tensor max_tensor(DT_FLOAT, {});
194       max_tensor.flat<float>()(0) = max_value;
195       SetNodeTensorAttr<float>("value", max_tensor, max_node);
196       inputs_to_rename[node.name() + ":1"] = max_node->name() + ":0";
197     } else {
198       NodeDef* new_node = frozen_graph_def.mutable_node()->Add();
199       *new_node = node;
200     }
201   }
202   return RenameNodeInputs(frozen_graph_def, inputs_to_rename,
203                           std::unordered_set<string>(), output_graph_def);
204 }
205 
206 REGISTER_GRAPH_TRANSFORM("freeze_requantization_ranges",
207                          FreezeRequantizationRanges);
208 
209 }  // namespace graph_transforms
210 }  // namespace tensorflow
211