/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over.
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // The bit depth of quantized inputs the op can read.
  DataType input_bit_depth;
  // The bit depth of the op's quantized outputs.
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process.
  std::set<int32> unquantized_inputs;
  // How the quantized op's inputs are arranged: either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify an
  // unknown number of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};

// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
  static const std::vector<QuantizedOpInfo> op_list = {
      {"Add",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"AvgPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"BiasAdd",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Concat",
       {"N"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {0},
       QuantizedOpInfo::SEPARATE_MIN_MAX},
      {"Conv2D",
       {"strides", "padding"},
       {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MatMul",
       {"transpose_a", "transpose_b"},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MaxPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Mul",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"ResizeBilinear",
       {"align_corners"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu6",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Reshape",
       {},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
  };
  return op_list;
}

namespace {
// Replaces invalid characters in input names to get a unique node name.
string UniqueNodeNameFromInput(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  string result;
  if (prefix == "^") {
    result += "__hat__";
  }
  result += node_name;
  if (!suffix.empty()) {
    result += "__port__" + suffix.substr(1, suffix.size() - 1);
  }
  return result;
}
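// For example, the control input "^foo" maps to "__hat__foo", and the port
// reference "bar:1" maps to "bar__port__1", so the result is always a legal
// node name that is unique per distinct input.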

// Pulls two float values from the named parameters, with a lot of checking.
Status ExtractRangeFromParams(const TransformFuncContext& context,
                              const string& min_name, const string& max_name,
                              float* min_value, float* max_value,
                              bool* has_range) {
  // See if we've been given quantized inputs with a known range.
  const bool has_min = (context.params.count(min_name) != 0);
  const bool has_max = (context.params.count(max_name) != 0);
  *has_range = (has_min || has_max);
  if (!*has_range) {
    return Status::OK();
  }
  if (!has_min || !has_max) {
    return errors::InvalidArgument("You must pass both ", min_name, " and ",
                                   max_name, " into quantize_nodes");
  }
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
  return Status::OK();
}
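// For example, a transform invocation that sets input_min=-1 and input_max=1
// yields *min_value == -1.0f, *max_value == 1.0f, and *has_range == true.
// Supplying only one of the pair is an InvalidArgument error, while supplying
// neither simply leaves *has_range == false.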

}  // namespace

// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but
// can also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that
// avoids the need for any O(n^2) algorithms when identifying duplicates.
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def) {
  // Make sure we can look up inputs and outputs quickly.
  std::set<string> input_names(context.input_names.begin(),
                               context.input_names.end());
  std::set<string> output_names(context.output_names.begin(),
                                context.output_names.end());
  GraphDef current_graph_def = input_graph_def;
  // Keep running the merging until no more duplicates are found.
  bool any_duplicates_found;
  do {
    any_duplicates_found = false;
    // First arrange all of the nodes by a hash of their contents.
    std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    for (const NodeDef& node : current_graph_def.node()) {
      NodeDef nameless_node = node;
      // The name matters if it's being used as an input or output node,
      // otherwise ignore it when looking for duplicates.
      if (!input_names.count(node.name()) &&
          !output_names.count(node.name())) {
        nameless_node.set_name("");
      }
      const uint64 hash = HashNodeDef(nameless_node);
      hashed_nodes[hash].push_back(&node);
    }
    // If we have multiple nodes with the same hash, then we know they're
    // duplicates and can be removed, unless they're stateful.
    std::map<string, string> inputs_to_rename;
    GraphDef merged_graph_def;
    for (const std::pair<uint64, std::vector<const NodeDef*>> hashed_node_info :
         hashed_nodes) {
      const std::vector<const NodeDef*>& hash_node_list =
          hashed_node_info.second;
      for (int i = 0; i < hash_node_list.size(); ++i) {
        const NodeDef* current_node = hash_node_list[i];
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(
            OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
        const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
        if (is_duplicate) {
          const string original_name = hash_node_list[0]->name();
          inputs_to_rename[current_node->name() + ":*"] = original_name;
          any_duplicates_found = true;
        } else {
          NodeDef* new_node = merged_graph_def.mutable_node()->Add();
          *new_node = *current_node;
        }
      }
    }
    // Update the graph so that any nodes that referred to removed inputs now
    // pull from the remaining duplicate.
    TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
                                        std::unordered_set<string>(),
                                        &current_graph_def));
  } while (any_duplicates_found);

  *output_graph_def = current_graph_def;

  return Status::OK();
}
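// As a concrete illustration (node names here are hypothetical): if Const
// nodes "w_copy1" and "w_copy2" hold identical tensors, they hash to the same
// value, the second is dropped from the graph, and every input that read
// "w_copy2:*" is renamed to pull from "w_copy1" instead. Stateful ops such as
// variables are never merged, even when their definitions match.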

// Looks for the patterns that indicate there are two eight-bit ops feeding
// into each other, separated by a conversion up to float and back again.
// These occur during the initial conversion of ops to their quantized forms.
// Because we're only looking at an individual op in that phase and don't know
// if its inputs and outputs are eight-bit-capable, we start by converting the
// actual op into quantized form, but add float conversions before and after.
// This pass gets rid of those conversions if it turns out we do have adjacent
// ops capable of eight-bit processing.
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
                                    const TransformFuncContext& context,
                                    GraphDef* output_graph_def) {
  std::set<string> graph_outputs;
  for (const string& output_name : context.output_names) {
    graph_outputs.insert(NodeNameFromInput(output_name));
  }
  std::map<string, string> inputs_to_rename;
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"QuantizeV2",
        {
          {"Dequantize"},
          {"Min"},
          {"Max"},
        }
      },  // clang-format on
      [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
                                          const std::set<string>& input_nodes,
                                          const std::set<string>& output_nodes,
                                          std::vector<NodeDef>* new_nodes) {
        const NodeDef& quantize_node = match.node;
        const NodeDef& dequantize_node = match.inputs[0].node;
        inputs_to_rename[quantize_node.name() + ":0"] =
            dequantize_node.input(0);
        inputs_to_rename[quantize_node.name() + ":1"] =
            dequantize_node.input(1);
        inputs_to_rename[quantize_node.name() + ":2"] =
            dequantize_node.input(2);

        // Are other sub-graphs using the float intermediate result? If so,
        // preserve it, but the input renaming still rewires the eight-bit ops
        // so they don't go through float.
        if (output_nodes.count(dequantize_node.name()) ||
            graph_outputs.count(dequantize_node.name())) {
          CopyOriginalMatch(match, new_nodes);
        }

        return Status::OK();
      },
      {true}, &replaced_graph_def));

  return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
                          std::unordered_set<string>(), output_graph_def);
}
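// For example, given a chain like QuantizedConv2D -> Dequantize "d" ->
// QuantizeV2 "q" -> QuantizedBiasAdd (op names as produced by this
// transform), every consumer of "q:0", "q:1", and "q:2" is rewired to d's
// quantized input tensor and its min/max, so the float round trip between
// the two eight-bit ops disappears entirely.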

// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
Status QuantizePlaceholders(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  float input_min;
  float input_max;
  bool has_input_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
                                            &input_min, &input_max,
                                            &has_input_range));
  if (!has_input_range) {
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  std::map<string, string> inputs_to_rename_first_pass;
  std::map<string, string> inputs_to_rename_second_pass;
  GraphDef placeholder_graph_def;
  placeholder_graph_def.Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    if (node.op() != "Placeholder") {
      *(placeholder_graph_def.mutable_node()->Add()) = node;
    } else {
      string namespace_prefix = node.name() + "_eightbit";

      NodeDef quantized_placeholder;
      quantized_placeholder = node;
      SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
      *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;

      NodeDef min_node;
      min_node.set_op("Const");
      min_node.set_name(namespace_prefix + "/min");
      SetNodeAttr("dtype", DT_FLOAT, &min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = input_min;
      SetNodeTensorAttr<float>("value", min_tensor, &min_node);
      *(placeholder_graph_def.mutable_node()->Add()) = min_node;

      NodeDef max_node;
      max_node.set_op("Const");
      max_node.set_name(namespace_prefix + "/max");
      SetNodeAttr("dtype", DT_FLOAT, &max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = input_max;
      SetNodeTensorAttr<float>("value", max_tensor, &max_node);
      *(placeholder_graph_def.mutable_node()->Add()) = max_node;

      const string rename_suffix = "__RENAMED_PLACEHOLDER__";
      NodeDef dequantize_node;
      dequantize_node.set_op("Dequantize");
      dequantize_node.set_name(namespace_prefix + "/dequantize");
      SetNodeAttr("T", DT_QUINT8, &dequantize_node);
      SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
      AddNodeInput(node.name() + rename_suffix, &dequantize_node);
      AddNodeInput(min_node.name(), &dequantize_node);
      AddNodeInput(max_node.name(), &dequantize_node);
      *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;

      // First make sure that any internal references to the old placeholder
      // now point to the dequantize result.
      inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
      // Then fix up the dequantize op so that it really points to the
      // placeholder.
      inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    }
  }

  GraphDef first_pass_graph_def;
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
                       std::unordered_set<string>(), &first_pass_graph_def));
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
                       std::unordered_set<string>(), output_graph_def));

  return Status::OK();
}
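// To make the two-pass rename concrete: for a placeholder named "input", the
// dequantize op initially reads from the sentinel name
// "input__RENAMED_PLACEHOLDER__". The first pass redirects every existing
// consumer of "input" to "input_eightbit/dequantize"; if the dequantize op
// read "input" directly, that pass would also rewire it to its own output
// and create a cycle. The second pass then rewrites the sentinel name back
// to "input".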

// During training, FakeQuantWithMinMaxVars ops capture a good min/max range
// for an activation layer. To use these during inference, this pass converts
// those ops into Requantizes with the trained min/maxes as constant inputs.
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
                                     const TransformFuncContext& context,
                                     GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"FakeQuantWithMinMaxVars",
        {
          {"*"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_quant_node = match.node;
        const NodeDef& original_op_node = match.inputs[0].node;
        const NodeDef& fake_quant_min_node = match.inputs[1].node;
        const NodeDef& fake_quant_max_node = match.inputs[2].node;

        string namespace_prefix = fake_quant_node.name() + "_eightbit";

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_quant_min_node);
        new_nodes->push_back(fake_quant_max_node);

        NodeDef quantize_node;
        quantize_node.set_op("QuantizeV2");
        quantize_node.set_name(namespace_prefix + "/quantize");
        SetNodeAttr("T", DT_QINT32, &quantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
        AddNodeInput(fake_quant_node.input(0), &quantize_node);
        AddNodeInput(fake_quant_min_node.name(), &quantize_node);
        AddNodeInput(fake_quant_max_node.name(), &quantize_node);
        new_nodes->push_back(quantize_node);

        NodeDef requantize_node;
        requantize_node.set_op("Requantize");
        requantize_node.set_name(namespace_prefix + "/requantize");
        SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
        SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
        AddNodeInput(quantize_node.name() + ":0", &requantize_node);
        AddNodeInput(quantize_node.name() + ":1", &requantize_node);
        AddNodeInput(quantize_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_quant_min_node.name(), &requantize_node);
        AddNodeInput(fake_quant_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(fake_quant_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}
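// The net effect is that a FakeQuantWithMinMaxVars named "fq" becomes the
// chain:
//
//   original op -> QuantizeV2 (qint32, MIN_FIRST)
//               -> Requantize (trained min/max Consts, down to quint8)
//               -> Dequantize (named "fq", so downstream consumers are
//                  untouched)
//
// reusing fq's trained range at both the quantize and requantize steps.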

// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
                                const TransformFuncContext& context,
                                GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Requantize",
        {
          {"QuantizeV2",
            {
              {"Dequantize",
                {
                  {"Requantize",
                    {
                      {"*"},
                      {"*"},
                      {"*"},
                      {"RequantizationRange"},
                      {"RequantizationRange"},
                    }
                  },
                  {"Requantize"},
                  {"Requantize"},
                }
              },
              {"Const"},
              {"Const"},
            },
          },
          {"QuantizeV2"},
          {"QuantizeV2"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_requantize_node = match.node;
        const NodeDef& original_op_node =
            match.inputs[0].inputs[0].inputs[0].inputs[0].node;
        const NodeDef& fake_requantize_min_node = match.inputs[3].node;
        const NodeDef& fake_requantize_max_node = match.inputs[4].node;

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_requantize_min_node);
        new_nodes->push_back(fake_requantize_max_node);

        NodeDef requantize_node;
        requantize_node = fake_requantize_node;
        requantize_node.mutable_input()->Clear();
        AddNodeInput(original_op_node.name() + ":0", &requantize_node);
        AddNodeInput(original_op_node.name() + ":1", &requantize_node);
        AddNodeInput(original_op_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
        AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}

// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu, MaxPool, etc, several steps from the Conv2D or
// BiasAdd op that we want to apply the trained constant conversions to. This
// pass tries to move FakeQuant ops up the input chain, so they're as close as
// possible to the 32-bit conversion, and so can be easily merged into the
// automatic dynamic Requantizes.
Status HoistFakeQuants(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  const int max_depth = 3;
  for (int depth = max_depth; depth > 0; --depth) {
    OpTypePattern pattern = {"*"};
    for (int i = 0; i < depth; ++i) {
      pattern = {"*", {pattern}};
    }
    pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    GraphDef hoisted_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [depth](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          const NodeDef& fake_quant_node = match.node;
          const NodeDef& fake_quant_min_node = match.inputs[1].node;
          const NodeDef& fake_quant_max_node = match.inputs[2].node;
          std::vector<NodeDef> linear_nodes;
          NodeMatch current_match = match;
          for (int i = 0; i <= depth; ++i) {
            linear_nodes.push_back(current_match.inputs[0].node);
            current_match = current_match.inputs[0];
          }
          NodeDef new_fake_quant_node;
          new_fake_quant_node = fake_quant_node;
          new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
          new_fake_quant_node.set_input(
              0, linear_nodes[linear_nodes.size() - 2].input(0));
          new_nodes->push_back(new_fake_quant_node);

          new_nodes->push_back(fake_quant_min_node);
          new_nodes->push_back(fake_quant_max_node);

          linear_nodes[linear_nodes.size() - 2].set_input(
              0, new_fake_quant_node.name());
          linear_nodes.front().set_name(fake_quant_node.name());
          for (const NodeDef& linear_node : linear_nodes) {
            new_nodes->push_back(linear_node);
          }

          return Status::OK();
        },
        {}, &hoisted_graph_def));
    current_graph_def = hoisted_graph_def;
  }
  *output_graph_def = current_graph_def;

  return Status::OK();
}
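// As an example at depth == 1, the generated pattern is
//   {"FakeQuantWithMinMaxVars", {{"*", {{"*"}}}, {"Const"}, {"Const"}}}
// so a chain like Conv2D -> Relu -> FakeQuant is rewritten to
// Conv2D -> FakeQuant_hoisted -> Relu, with the Relu taking over the
// FakeQuant's original name so downstream consumers are unaffected.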

// Converts any float ops that have eight-bit equivalents into their quantized
// forms, so that as much calculation as possible is done in the
// lower-precision format.
Status QuantizeNodes(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  // Loop through all of the quantizable op types, and replace any occurrences
  // with equivalent sub-graphs with quantized ops at their core. For example
  // this one-input operation:
  //
  //      Input(float)
  //          |
  //          v
  //      Operation
  //          |
  //          v
  //       (float)
  //
  // Will be turned into its quantized equivalent:
  //
  //   Input(float)          ReshapeDims
  //      +------v v-------------+
  //      |    Reshape
  //      |       |
  //      |       |         ReductionDims
  //      |       +-----+        |
  //      |       |  +---c---------+
  //      |       v  v   v  v-------+
  //      |        Min     Max
  //      |   +----+        |
  //      v   v    v--------+
  //       Quantize
  //           |
  //           v
  //    QuantizedOperation
  //       |   |   |
  //       v   v   v
  //      Dequantize
  //           |
  //           v
  //        (float)
  //
  // This keeps the inputs and outputs visible to the rest of the graph in
  // float, and converts them down to quantized buffers internally for the
  // computation. The result will end up with a lot of redundant
  // dequantize/quantize pairs between adjacent quantized ops, but a later
  // pass removes these where it can.

  std::set<string> ops_to_ignore;
  if (context.params.count("ignore_op") > 0) {
    for (const string& name : context.params.at("ignore_op")) {
      ops_to_ignore.insert(name);
    }
  }

  const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
  string op_pattern;
  bool is_first = true;
  std::map<string, QuantizedOpInfo> op_map;
  for (const QuantizedOpInfo& op_info : op_list) {
    if (ops_to_ignore.count(op_info.float_name) == 0) {
      strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
                         op_info.float_name);
      op_map.insert({op_info.float_name, op_info});
      is_first = false;
    }
  }

  // If input_min and input_max have been passed in, then we convert all float
  // Placeholder nodes into quantized versions, with the supplied values as
  // their range.
  GraphDef placeholder_graph_def;
  TF_RETURN_IF_ERROR(
      QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));

  // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
  // operations like Relu or MaxPool, move them up so that they're as close as
  // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
  GraphDef hoisted_graph_def;
  TF_RETURN_IF_ERROR(
      HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));

  // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
  // activation layers, into Requantize ops with those ranges instead. This
  // makes it easier to replace the dynamic range calculations that are used
  // by default.
  GraphDef converted_graph_def;
  TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
                                                   &converted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));

  // If fallback_min and fallback_max are set, then we'll use hardwired ranges
  // for all the 32-bit to 8-bit requantizations.
  float fallback_min;
  float fallback_max;
  bool has_fallback_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(
      context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
      &has_fallback_range));

  // Replace all occurrences of the current float op with its quantized
  // equivalent.
  GraphDef quantized_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      converted_graph_def, {op_pattern},
      [&op_map, fallback_min, fallback_max, has_fallback_range](
          const NodeMatch& match, const std::set<string>& input_nodes,
          const std::set<string>& output_nodes,
          std::vector<NodeDef>* new_nodes) {
        const NodeDef& float_node = match.node;
        const QuantizedOpInfo& op_info = op_map[float_node.op()];

        DataTypeVector input_types;
        DataTypeVector output_types;
        TF_RETURN_IF_ERROR(
            GetInOutTypes(float_node, &input_types, &output_types));
        bool are_all_float = true;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any known non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          if (input_types[i] != DT_FLOAT) {
            are_all_float = false;
          }
        }
        for (const DataType& output_type : output_types) {
          if (output_type != DT_FLOAT) {
            are_all_float = false;
          }
        }
        // This isn't a float op, so don't quantize it.
        if (!are_all_float) {
          CopyOriginalMatch(match, new_nodes);
          return Status::OK();
        }

        string namespace_prefix = float_node.name() + "_eightbit";

        // Quantize all of the inputs.
        std::vector<string> quantized_input_names;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }

          const string& input_name = float_node.input(i);
          string unique_input_name =
              namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);

          // Add some common constants we need for reshaping inputs.
          NodeDef reshape_dims;
          reshape_dims.set_op("Const");
          reshape_dims.set_name(unique_input_name + "/reshape_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
          SetNodeAttr("dtype", DT_INT32, &reshape_dims);
          Tensor reshape_dims_tensor(DT_INT32, {1});
          reshape_dims_tensor.flat<int32>()(0) = -1;
          SetNodeTensorAttr<int32>("value", reshape_dims_tensor,
                                   &reshape_dims);
          new_nodes->push_back(reshape_dims);

          NodeDef reduction_dims;
          reduction_dims.set_op("Const");
          reduction_dims.set_name(unique_input_name + "/reduction_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
          SetNodeAttr("dtype", DT_INT32, &reduction_dims);
          Tensor reduction_dims_tensor(DT_INT32, {1});
          reduction_dims_tensor.flat<int32>()(0) = 0;
          SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
                                   &reduction_dims);
          new_nodes->push_back(reduction_dims);

          NodeDef reshape_node;
          reshape_node.set_op("Reshape");
          reshape_node.set_name(unique_input_name + "/reshape");
          SetNodeAttr("T", DT_FLOAT, &reshape_node);
          AddNodeInput(input_name, &reshape_node);
          AddNodeInput(reshape_dims.name(), &reshape_node);
          new_nodes->push_back(reshape_node);

          NodeDef min_node;
          min_node.set_op("Min");
          min_node.set_name(unique_input_name + "/min");
          SetNodeAttr("T", DT_FLOAT, &min_node);
          SetNodeAttr("keep_dims", false, &min_node);
          AddNodeInput(reshape_node.name(), &min_node);
          AddNodeInput(reduction_dims.name(), &min_node);
          new_nodes->push_back(min_node);

          NodeDef max_node;
          max_node.set_op("Max");
          max_node.set_name(unique_input_name + "/max");
          SetNodeAttr("T", DT_FLOAT, &max_node);
          SetNodeAttr("keep_dims", false, &max_node);
          AddNodeInput(reshape_node.name(), &max_node);
          AddNodeInput(reduction_dims.name(), &max_node);
          new_nodes->push_back(max_node);

          NodeDef quantize_node;
          quantize_node.set_op("QuantizeV2");
          quantize_node.set_name(unique_input_name + "/quantize");
          SetNodeAttr("T", DT_QUINT8, &quantize_node);
          SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
          AddNodeInput(input_name, &quantize_node);
          AddNodeInput(min_node.name(), &quantize_node);
          AddNodeInput(max_node.name(), &quantize_node);
          new_nodes->push_back(quantize_node);
          quantized_input_names.push_back(quantize_node.name());
        }

        // Set up the quantized version of the current op.
        NodeDef quantized_main_node;
        quantized_main_node.set_op("Quantized" + float_node.op());
        quantized_main_node.set_name(float_node.name() + "/eightbit");
        for (const string& attr_to_copy : op_info.attrs_to_copy) {
          CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
                       &quantized_main_node);
        }
        for (const std::pair<string, DataType>& dtype_to_set :
             op_info.dtypes_to_set) {
          SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
                      &quantized_main_node);
        }
        int quantized_input_index = 0;
        for (int i = 0; i < float_node.input_size(); ++i) {
          if (op_info.unquantized_inputs.count(i)) {
            AddNodeInput(float_node.input(i), &quantized_main_node);
          } else {
            const string& quantized_input_name =
                quantized_input_names[quantized_input_index];
            AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
            ++quantized_input_index;
          }
        }
        if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        } else {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
          }
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        }
        new_nodes->push_back(quantized_main_node);

        string eight_bit_node_name;
        if (op_info.output_bit_depth == DT_QINT32) {
          // Shrink the range of the output down from 32 bits to 8.
          string requantize_min_input;
          string requantize_max_input;
          if (has_fallback_range) {
            // Use constant values for the min/max range if they were given.
            NodeDef fallback_min_node;
            fallback_min_node.set_op("Const");
            fallback_min_node.set_name(quantized_main_node.name() +
                                       "/fallback_min");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
            Tensor fallback_min_tensor(DT_FLOAT, {});
            fallback_min_tensor.flat<float>()(0) = fallback_min;
            SetNodeTensorAttr<float>("value", fallback_min_tensor,
                                     &fallback_min_node);
            new_nodes->push_back(fallback_min_node);

            NodeDef fallback_max_node;
            fallback_max_node.set_op("Const");
            fallback_max_node.set_name(quantized_main_node.name() +
                                       "/fallback_max");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
            Tensor fallback_max_tensor(DT_FLOAT, {});
            fallback_max_tensor.flat<float>()(0) = fallback_max;
            SetNodeTensorAttr<float>("value", fallback_max_tensor,
                                     &fallback_max_node);
            new_nodes->push_back(fallback_max_node);

            requantize_min_input = fallback_min_node.name();
            requantize_max_input = fallback_max_node.name();
          } else {
            // Otherwise dynamically measure the range each time.
            NodeDef requant_range_node;
            requant_range_node.set_op("RequantizationRange");
            requant_range_node.set_name(quantized_main_node.name() +
                                        "/requant_range");
            SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":0",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":1",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":2",
                         &requant_range_node);
            new_nodes->push_back(requant_range_node);

            requantize_min_input = requant_range_node.name() + ":0";
            requantize_max_input = requant_range_node.name() + ":1";
          }
          NodeDef requantize_node;
          requantize_node.set_op("Requantize");
          requantize_node.set_name(quantized_main_node.name() +
                                   "/requantize");
          SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
          SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
          AddNodeInput(requantize_min_input, &requantize_node);
          AddNodeInput(requantize_max_input, &requantize_node);
          new_nodes->push_back(requantize_node);
          eight_bit_node_name = requantize_node.name();
        } else {
          eight_bit_node_name = quantized_main_node.name();
        }

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(float_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, &quantized_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));

  // If we've ended up with two Requantize ops in a row (for example if there
  // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
  // using the trained range from the second op.
  GraphDef merged_graph_def;
  TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
                                              &merged_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));

  // There can be duplicate quantize nodes if multiple ops pull from a single
  // input, which makes it harder to remove redundant ones, so strip them out.
  GraphDef deduped_graph_def;
  TF_RETURN_IF_ERROR(
      MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));

  // Look for Dequantizes that immediately go into Quantizes, and remove them
  // since the two together cancel each other out. This allows us to keep the
  // data flow in eight bit where two adjacent ops are in eight bit, but still
  // keep interoperability with float ops.
  TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
                                                  output_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));

  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);

REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);
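
// A sketch of how these transforms are typically invoked through the
// transform_graph tool (the graph paths, node names, and parameter values
// here are illustrative, not prescriptive):
//
//   bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
//     --in_graph=frozen_graph.pb \
//     --out_graph=quantized_graph.pb \
//     --inputs=input \
//     --outputs=softmax \
//     --transforms='quantize_nodes(input_min=-1, input_max=1,
//                                  fallback_min=-10, fallback_max=10)'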

}  // namespace graph_transforms
}  // namespace tensorflow