• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/common_runtime/constant_folding.h"
19 #include "tensorflow/core/common_runtime/graph_constructor.h"
20 #include "tensorflow/core/common_runtime/threadpool_device.h"
21 #include "tensorflow/core/graph/node_builder.h"
22 #include "tensorflow/core/graph/subgraph.h"
23 #include "tensorflow/core/kernels/quantization_utils.h"
24 #include "tensorflow/core/platform/init_main.h"
25 #include "tensorflow/core/public/session.h"
26 #include "tensorflow/tools/graph_transforms/transform_utils.h"
27 
28 namespace tensorflow {
29 namespace graph_transforms {
30 
// Holds the information we need to translate from a float version of this op
// into the quantized equivalent. Instances are brace-initialized in
// GetQuantizedOpList(), so the order of the fields below matters.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over.
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set, as (attribute name, type) pairs.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // What depth of inputs the op can read in.
  DataType input_bit_depth;
  // The depth of the op's quantized outputs.
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process,
  // identified by their input index.
  std::set<int32> unquantized_inputs;
  // How the inputs and their min/max values are arranged, either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify unknown
  // numbers of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};
53 
54 // Every op that has a quantized equivalent should be listed here, so that the
55 // conversion process can transform them.
GetQuantizedOpList()56 const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
57   static const std::vector<QuantizedOpInfo> op_list = {
58       {"Add",
59        {},
60        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
61        DT_QUINT8,
62        DT_QINT32,
63        {},
64        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
65       {"AvgPool",
66        {"ksize", "strides", "padding"},
67        {{"T", DT_QUINT8}},
68        DT_QUINT8,
69        DT_QUINT8,
70        {},
71        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
72       {"BiasAdd",
73        {},
74        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
75        DT_QUINT8,
76        DT_QINT32,
77        {},
78        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
79       {"Concat",
80        {"N"},
81        {{"T", DT_QUINT8}},
82        DT_QUINT8,
83        DT_QUINT8,
84        {0},
85        QuantizedOpInfo::SEPARATE_MIN_MAX},
86       {"Conv2D",
87        {"strides", "padding"},
88        {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
89        DT_QUINT8,
90        DT_QINT32,
91        {},
92        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
93       {"MatMul",
94        {"transpose_a", "transpose_b"},
95        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
96        DT_QUINT8,
97        DT_QINT32,
98        {},
99        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
100       {"MaxPool",
101        {"ksize", "strides", "padding"},
102        {{"T", DT_QUINT8}},
103        DT_QUINT8,
104        DT_QUINT8,
105        {},
106        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
107       {"Mul",
108        {},
109        {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
110        DT_QUINT8,
111        DT_QINT32,
112        {},
113        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
114       {"Relu",
115        {},
116        {{"Tinput", DT_QUINT8}},
117        DT_QUINT8,
118        DT_QUINT8,
119        {},
120        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
121       {"ResizeBilinear",
122        {"align_corners"},
123        {{"T", DT_QUINT8}},
124        DT_QUINT8,
125        DT_QUINT8,
126        {1},
127        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
128       {"Relu6",
129        {},
130        {{"Tinput", DT_QUINT8}},
131        DT_QUINT8,
132        DT_QUINT8,
133        {},
134        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
135       {"Reshape",
136        {},
137        {{"T", DT_QUINT8}},
138        DT_QUINT8,
139        DT_QUINT8,
140        {1},
141        QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
142   };
143   return op_list;
144 }
145 
146 namespace {
147 // Replaces invalid characters in input names to get a unique node name.
UniqueNodeNameFromInput(const string & input_name)148 string UniqueNodeNameFromInput(const string& input_name) {
149   string prefix;
150   string node_name;
151   string suffix;
152   NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
153   string result;
154   if (prefix == "^") {
155     result += "__hat__";
156   }
157   result += node_name;
158   if (!suffix.empty()) {
159     result += "__port__" + suffix.substr(1, suffix.size() - 1);
160   }
161   return result;
162 }
163 
164 // Pulls two float values from the named parameters, with a lot of checking.
ExtractRangeFromParams(const TransformFuncContext & context,const string & min_name,const string & max_name,float * min_value,float * max_value,bool * has_range)165 Status ExtractRangeFromParams(const TransformFuncContext& context,
166                               const string& min_name, const string& max_name,
167                               float* min_value, float* max_value,
168                               bool* has_range) {
169   // See if we've been given quantized inputs with a known range.
170   const bool has_min = (context.params.count(min_name) != 0);
171   const bool has_max = (context.params.count(max_name) != 0);
172   *has_range = (has_min || has_max);
173   if (!*has_range) {
174     return OkStatus();
175   }
176   if (!has_min || !has_max) {
177     return errors::InvalidArgument("You must pass both ", min_name, " and ",
178                                    max_name, " into quantize_nodes");
179   }
180   TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
181   TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
182   return OkStatus();
183 }
184 
185 }  // namespace
186 
187 // Analyzes all the nodes in the graph to figure out which ones are duplicates
188 // apart from their names. This commonly includes identical Const nodes, but can
189 // also be simple operations that are repeated on multiple outputs of a
190 // particular node. The complexity is managed using a hash function that avoids
191 // the need for any O(n^2) algorithms when identifying duplicates.
MergeDuplicateNodes(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)192 Status MergeDuplicateNodes(const GraphDef& input_graph_def,
193                            const TransformFuncContext& context,
194                            GraphDef* output_graph_def) {
195   // Make sure we can look up inputs and outputs quickly.
196   std::set<string> input_names(context.input_names.begin(),
197                                context.input_names.end());
198   std::set<string> output_names(context.output_names.begin(),
199                                 context.output_names.end());
200   GraphDef current_graph_def = input_graph_def;
201   // Keep running the merging until no more duplicates are found.
202   bool any_duplicates_found;
203   do {
204     any_duplicates_found = false;
205     // First arrange all of the nodes by a hash of their contents.
206     std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
207     for (const NodeDef& node : current_graph_def.node()) {
208       NodeDef nameless_node = node;
209       // The name matters if it's being used as an input or output node,
210       // otherwise ignore it when looking for duplicates.
211       if (!input_names.count(node.name()) && !output_names.count(node.name())) {
212         nameless_node.set_name("");
213       }
214       const uint64 hash = HashNodeDef(nameless_node);
215       hashed_nodes[hash].push_back(&node);
216     }
217     // If we have multiple nodes with the same hash, then we know they're
218     // duplicates and can be removed, unless they're stateful.
219     std::map<string, string> inputs_to_rename;
220     GraphDef merged_graph_def;
221     for (const std::pair<const uint64, std::vector<const NodeDef*>>&
222              hashed_node_info : hashed_nodes) {
223       const std::vector<const NodeDef*>& hash_node_list =
224           hashed_node_info.second;
225       for (int i = 0; i < hash_node_list.size(); ++i) {
226         const NodeDef* current_node = hash_node_list[i];
227         const OpDef* op_def = nullptr;
228         TF_RETURN_IF_ERROR(
229             OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
230         const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
231         if (is_duplicate) {
232           const string original_name = hash_node_list[0]->name();
233           inputs_to_rename[current_node->name() + ":*"] = original_name;
234           any_duplicates_found = true;
235         } else {
236           NodeDef* new_node = merged_graph_def.mutable_node()->Add();
237           *new_node = *current_node;
238         }
239       }
240     }
241     // Update the graph so that any nodes that referred to removed inputs now
242     // pull from the remaining duplicate.
243     TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
244                                         std::unordered_set<string>(),
245                                         &current_graph_def));
246   } while (any_duplicates_found);
247 
248   *output_graph_def = current_graph_def;
249 
250   return OkStatus();
251 }
252 
253 // Looks for the patterns that indicate there are two eight-bit ops feeding into
254 // each other, separated by a conversion up to float and back again. These occur
255 // during the initial conversion of ops to their quantized forms. Because we're
256 // only looking at an individual op in that phase and don't know if its inputs
257 // and outputs are eight-bit-capable, we start by converting the actual op into
258 // quantized form, but add float conversions before and after. This pass gets
259 // rid of those conversions if it turns out we do have adjacent ops capable of
260 // eight-bit processing.
RemoveRedundantQuantizations(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)261 Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
262                                     const TransformFuncContext& context,
263                                     GraphDef* output_graph_def) {
264   std::set<string> graph_outputs;
265   for (const string& output_name : context.output_names) {
266     graph_outputs.insert(NodeNameFromInput(output_name));
267   }
268   std::map<string, string> inputs_to_rename;
269   GraphDef replaced_graph_def;
270   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
271       input_graph_def,  // clang-format off
272       {"QuantizeV2",
273         {
274           {"Dequantize"},
275           {"Min"},
276           {"Max"},
277         }
278       },  // clang-format on
279       [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
280                                           const std::set<string>& input_nodes,
281                                           const std::set<string>& output_nodes,
282                                           std::vector<NodeDef>* new_nodes) {
283         const NodeDef& quantize_node = match.node;
284         const NodeDef& dequantize_node = match.inputs[0].node;
285         inputs_to_rename[quantize_node.name() + ":0"] =
286             dequantize_node.input(0);
287         inputs_to_rename[quantize_node.name() + ":1"] =
288             dequantize_node.input(1);
289         inputs_to_rename[quantize_node.name() + ":2"] =
290             dequantize_node.input(2);
291 
292         // Are other sub-graphs using the float intermediate result? If so,
293         // preserve it, but the input renaming still rewires the eight-bit ops
294         // so they don't go through float.
295         if (output_nodes.count(dequantize_node.name()) ||
296             graph_outputs.count(dequantize_node.name())) {
297           CopyOriginalMatch(match, new_nodes);
298         }
299 
300         return OkStatus();
301       },
302       {true}, &replaced_graph_def));
303 
304   return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
305                           std::unordered_set<string>(), output_graph_def);
306 }
307 
308 // If the user has passed in the input_min and input_max args, then we need to
309 // convert any input placeholders from float to eight bit, so quantized inputs
310 // can be fed directly into the graph.
QuantizePlaceholders(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)311 Status QuantizePlaceholders(const GraphDef& input_graph_def,
312                             const TransformFuncContext& context,
313                             GraphDef* output_graph_def) {
314   float input_min;
315   float input_max;
316   bool has_input_range;
317   TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
318                                             &input_min, &input_max,
319                                             &has_input_range));
320   if (!has_input_range) {
321     *output_graph_def = input_graph_def;
322     return OkStatus();
323   }
324   std::map<string, string> inputs_to_rename_first_pass;
325   std::map<string, string> inputs_to_rename_second_pass;
326   GraphDef placeholder_graph_def;
327   placeholder_graph_def.Clear();
328   for (const NodeDef& node : input_graph_def.node()) {
329     if (node.op() != "Placeholder") {
330       *(placeholder_graph_def.mutable_node()->Add()) = node;
331     } else {
332       string namespace_prefix = node.name() + "_eightbit";
333 
334       NodeDef quantized_placeholder;
335       quantized_placeholder = node;
336       SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
337       *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;
338 
339       NodeDef min_node;
340       min_node.set_op("Const");
341       min_node.set_name(namespace_prefix + "/min");
342       SetNodeAttr("dtype", DT_FLOAT, &min_node);
343       Tensor min_tensor(DT_FLOAT, {});
344       min_tensor.flat<float>()(0) = input_min;
345       SetNodeTensorAttr<float>("value", min_tensor, &min_node);
346       *(placeholder_graph_def.mutable_node()->Add()) = min_node;
347 
348       NodeDef max_node;
349       max_node.set_op("Const");
350       max_node.set_name(namespace_prefix + "/max");
351       SetNodeAttr("dtype", DT_FLOAT, &max_node);
352       Tensor max_tensor(DT_FLOAT, {});
353       max_tensor.flat<float>()(0) = input_max;
354       SetNodeTensorAttr<float>("value", max_tensor, &max_node);
355       *(placeholder_graph_def.mutable_node()->Add()) = max_node;
356 
357       const string rename_suffix = "__RENAMED_PLACEHOLDER__";
358       NodeDef dequantize_node;
359       dequantize_node.set_op("Dequantize");
360       dequantize_node.set_name(namespace_prefix + "/dequantize");
361       SetNodeAttr("T", DT_QUINT8, &dequantize_node);
362       SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
363       AddNodeInput(node.name() + rename_suffix, &dequantize_node);
364       AddNodeInput(min_node.name(), &dequantize_node);
365       AddNodeInput(max_node.name(), &dequantize_node);
366       *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;
367 
368       // First make sure that any internal references to the old placeholder
369       // now point to the dequantize result.
370       inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
371       // Then fix up the dequantize op so that it really points to the
372       // placeholder.
373       inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
374     }
375   }
376 
377   GraphDef first_pass_graph_def;
378   TF_RETURN_IF_ERROR(
379       RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
380                        std::unordered_set<string>(), &first_pass_graph_def));
381   TF_RETURN_IF_ERROR(
382       RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
383                        std::unordered_set<string>(), output_graph_def));
384 
385   return OkStatus();
386 }
387 
388 // During training, FakeQuantWithMinMaxVars ops capture a good min/max range for
389 // an activation layer. To use these during inference, this pass converts those
390 // ops into Requantizes with the trained min/maxes as constant inputs.
ConvertFakeQuantsToRequantize(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)391 Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
392                                      const TransformFuncContext& context,
393                                      GraphDef* output_graph_def) {
394   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
395       input_graph_def,  // clang-format off
396       {"FakeQuantWithMinMaxVars",
397         {
398           {"*"},
399           {"Const"},
400           {"Const"},
401         }
402       },  // clang-format on
403       [](const NodeMatch& match, const std::set<string>& input_nodes,
404          const std::set<string>& output_nodes,
405          std::vector<NodeDef>* new_nodes) {
406         const NodeDef& fake_quant_node = match.node;
407         const NodeDef& original_op_node = match.inputs[0].node;
408         const NodeDef& fake_quant_min_node = match.inputs[1].node;
409         const NodeDef& fake_quant_max_node = match.inputs[2].node;
410 
411         string namespace_prefix = fake_quant_node.name() + "_eightbit";
412 
413         new_nodes->push_back(original_op_node);
414         new_nodes->push_back(fake_quant_min_node);
415         new_nodes->push_back(fake_quant_max_node);
416 
417         NodeDef quantize_node;
418         quantize_node.set_op("QuantizeV2");
419         quantize_node.set_name(namespace_prefix + "/quantize");
420         SetNodeAttr("T", DT_QINT32, &quantize_node);
421         SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
422         AddNodeInput(fake_quant_node.input(0), &quantize_node);
423         AddNodeInput(fake_quant_min_node.name(), &quantize_node);
424         AddNodeInput(fake_quant_max_node.name(), &quantize_node);
425         new_nodes->push_back(quantize_node);
426 
427         NodeDef requantize_node;
428         requantize_node.set_op("Requantize");
429         requantize_node.set_name(namespace_prefix + "/requantize");
430         SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
431         SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
432         AddNodeInput(quantize_node.name() + ":0", &requantize_node);
433         AddNodeInput(quantize_node.name() + ":1", &requantize_node);
434         AddNodeInput(quantize_node.name() + ":2", &requantize_node);
435         AddNodeInput(fake_quant_min_node.name(), &requantize_node);
436         AddNodeInput(fake_quant_max_node.name(), &requantize_node);
437         new_nodes->push_back(requantize_node);
438 
439         // Convert the 8-bit result back into float for the final output.
440         NodeDef dequantize_node;
441         dequantize_node.set_op("Dequantize");
442         dequantize_node.set_name(fake_quant_node.name());
443         SetNodeAttr("T", DT_QUINT8, &dequantize_node);
444         SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
445         AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
446         AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
447         AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
448         new_nodes->push_back(dequantize_node);
449 
450         return OkStatus();
451       },
452       {}, output_graph_def));
453 
454   return OkStatus();
455 }
456 
457 // We always generate Requantize ops driven by dynamic RequantizationRange
458 // calculations when we produce quantized ops like Conv2D or BiasAdd with
459 // 32-bit results. If there were FakeQuant ops already for those activation
460 // layers, then there will be a later Requantize op with constant min/max
461 // inputs, which is preferable for fast inference. This pass looks for those
462 // later Requantize ops, and replaces the dynamic version with them.
MergeAdjacentRequantizes(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)463 Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
464                                 const TransformFuncContext& context,
465                                 GraphDef* output_graph_def) {
466   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
467       input_graph_def,  // clang-format off
468       {"Requantize",
469         {
470           {"QuantizeV2",
471             {
472               {"Dequantize",
473                 {
474                   {"Requantize",
475                     {
476                       {"*"},
477                       {"*"},
478                       {"*"},
479                       {"RequantizationRange"},
480                       {"RequantizationRange"},
481                     }
482                   },
483                   {"Requantize"},
484                   {"Requantize"},
485                 }
486               },
487               {"Const"},
488               {"Const"},
489             },
490           },
491           {"QuantizeV2"},
492           {"QuantizeV2"},
493           {"Const"},
494           {"Const"},
495         }
496       },  // clang-format on
497       [](const NodeMatch& match, const std::set<string>& input_nodes,
498          const std::set<string>& output_nodes,
499          std::vector<NodeDef>* new_nodes) {
500         const NodeDef& fake_requantize_node = match.node;
501         const NodeDef& original_op_node =
502             match.inputs[0].inputs[0].inputs[0].inputs[0].node;
503         const NodeDef& fake_requantize_min_node = match.inputs[3].node;
504         const NodeDef& fake_requantize_max_node = match.inputs[4].node;
505 
506         new_nodes->push_back(original_op_node);
507         new_nodes->push_back(fake_requantize_min_node);
508         new_nodes->push_back(fake_requantize_max_node);
509 
510         NodeDef requantize_node;
511         requantize_node = fake_requantize_node;
512         requantize_node.mutable_input()->Clear();
513         AddNodeInput(original_op_node.name() + ":0", &requantize_node);
514         AddNodeInput(original_op_node.name() + ":1", &requantize_node);
515         AddNodeInput(original_op_node.name() + ":2", &requantize_node);
516         AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
517         AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
518         new_nodes->push_back(requantize_node);
519 
520         return OkStatus();
521       },
522       {}, output_graph_def));
523 
524   return OkStatus();
525 }
526 
527 // Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
528 // linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd
529 // op that we want to apply the trained constant conversions to. This pass tries
530 // to move FakeQuant ops up the input chain, so they're as close as possible to
531 // the 32-bit conversion, and so can be easily merged into the automatic dynamic
532 // Requantizes.
HoistFakeQuants(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)533 Status HoistFakeQuants(const GraphDef& input_graph_def,
534                        const TransformFuncContext& context,
535                        GraphDef* output_graph_def) {
536   GraphDef current_graph_def = input_graph_def;
537   const int max_depth = 3;
538   for (int depth = max_depth; depth > 0; --depth) {
539     OpTypePattern pattern = {"*"};
540     for (int i = 0; i < depth; ++i) {
541       pattern = {"*", {pattern}};
542     }
543     pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
544     GraphDef hoisted_graph_def;
545     TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
546         current_graph_def, pattern,
547         [depth](const NodeMatch& match, const std::set<string>& input_nodes,
548                 const std::set<string>& output_nodes,
549                 std::vector<NodeDef>* new_nodes) {
550           const NodeDef& fake_quant_node = match.node;
551           const NodeDef& fake_quant_min_node = match.inputs[1].node;
552           const NodeDef& fake_quant_max_node = match.inputs[2].node;
553           std::vector<NodeDef> linear_nodes;
554           NodeMatch current_match = match;
555           for (int i = 0; i <= depth; ++i) {
556             linear_nodes.push_back(current_match.inputs[0].node);
557             current_match = current_match.inputs[0];
558           }
559           NodeDef new_fake_quant_node;
560           new_fake_quant_node = fake_quant_node;
561           new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
562           new_fake_quant_node.set_input(
563               0, linear_nodes[linear_nodes.size() - 2].input(0));
564           new_nodes->push_back(new_fake_quant_node);
565 
566           new_nodes->push_back(fake_quant_min_node);
567           new_nodes->push_back(fake_quant_max_node);
568 
569           linear_nodes[linear_nodes.size() - 2].set_input(
570               0, new_fake_quant_node.name());
571           linear_nodes.front().set_name(fake_quant_node.name());
572           for (const NodeDef& linear_node : linear_nodes) {
573             new_nodes->push_back(linear_node);
574           }
575 
576           return OkStatus();
577         },
578         {}, &hoisted_graph_def));
579     current_graph_def = hoisted_graph_def;
580   }
581   *output_graph_def = current_graph_def;
582 
583   return OkStatus();
584 }
585 
586 // Converts any float ops that have eight-bit equivalents into their quantized
587 // forms, so that as much calculation as possible is done in the lower-precision
588 // format.
QuantizeNodes(const GraphDef & input_graph_def,const TransformFuncContext & context,GraphDef * output_graph_def)589 Status QuantizeNodes(const GraphDef& input_graph_def,
590                      const TransformFuncContext& context,
591                      GraphDef* output_graph_def) {
592   // Loop through all of the quantizable op types, and replace any occurrences
593   // with equivalent sub-graphs with quantized ops at their core. For example
594   // this one-input operation:
595   //
596   //            Input(float)
597   //                |
598   //                v
599   //            Operation
600   //                |
601   //                v
602   //             (float)
603   //
604   // Will be turned into its quantized equivalent:
605   //
606   //      Input(float)          ReshapeDims
607   //         +------v v-------------+
608   //         |    Reshape
609   //         |      |
610   //         |      |          ReductionDims
611   //         |      +-----+         |
612   //         |      | +---c---------+
613   //         |      v v   v v-------+
614   //         |      Min   Max
615   //         |  +----+      |
616   //         v  v  v--------+
617   //        Quantize
618   //            |
619   //            v
620   //     QuantizedOperation
621   //        |   |   |
622   //        v   v   v
623   //        Dequantize
624   //            |
625   //            v
626   //         (float)
627   //
628   // This keeps the inputs and outputs visible to the rest of the graph in
629   // float
630   // and converts them down to quantized buffers internally for the
631   // computation.
632   // The result will end up with a lot of redundant dequantize/quantize pairs
633   // between adjacent quantized ops, but a later pass removes these where it
634   // can.
635 
636   std::set<string> ops_to_ignore;
637   if (context.params.count("ignore_op") > 0) {
638     for (const string& name : context.params.at("ignore_op")) {
639       ops_to_ignore.insert(name);
640     }
641   }
642 
643   const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
644   string op_pattern;
645   bool is_first = true;
646   std::map<string, QuantizedOpInfo> op_map;
647   for (const QuantizedOpInfo& op_info : op_list) {
648     if (ops_to_ignore.count(op_info.float_name) == 0) {
649       strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
650                          op_info.float_name);
651       op_map.insert({op_info.float_name, op_info});
652       is_first = false;
653     }
654   }
655 
656   // If input_min and input_max have been passed in, then we convert all float
657   // Placeholder nodes into quantized versions, with the supplied values as
658   // their range.
659   GraphDef placeholder_graph_def;
660   TF_RETURN_IF_ERROR(
661       QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
662   TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));
663 
664   // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
665   // operations like Relu or MaxPool, move them up so that they're as close as
666   // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
667   GraphDef hoisted_graph_def;
668   TF_RETURN_IF_ERROR(
669       HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
670   TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));
671 
672   // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
673   // activation layers, into Requantize ops with those ranges instead. This
674   // makes it easier to replace the dynamic range calculations that are used
675   // by default.
676   GraphDef converted_graph_def;
677   TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
678                                                    &converted_graph_def));
679   TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));
680 
681   // If fallback_min and fallback_max are set, then we'll use hardwired ranges
682   // for all the 32-bit to 8-bit requantizations.
683   float fallback_min;
684   float fallback_max;
685   bool has_fallback_range;
686   TF_RETURN_IF_ERROR(ExtractRangeFromParams(
687       context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
688       &has_fallback_range));
689 
690   // Replace all occurrences of the current float op with its quantized
691   // equivalent.
692   GraphDef quantized_graph_def;
693   TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
694       converted_graph_def, {op_pattern},
695       [&op_map, fallback_min, fallback_max, has_fallback_range](
696           const NodeMatch& match, const std::set<string>& input_nodes,
697           const std::set<string>& output_nodes,
698           std::vector<NodeDef>* new_nodes) {
699         const NodeDef& float_node = match.node;
700         const QuantizedOpInfo& op_info = op_map[float_node.op()];
701 
702         DataTypeVector input_types;
703         DataTypeVector output_types;
704         TF_RETURN_IF_ERROR(
705             GetInOutTypes(float_node, &input_types, &output_types));
706         bool are_all_float = true;
707         for (int i = 0; i < float_node.input_size(); ++i) {
708           // Skip any known non-float inputs.
709           if (op_info.unquantized_inputs.count(i)) {
710             continue;
711           }
712           if (i >= input_types.size()) {
713             LOG(ERROR) << "input_types has incorrect size "
714                        << input_types.size() << " <= " << i
715                        << ". Assuming everything else is floats.";
716           }
717           if (i < input_types.size() && input_types[i] != DT_FLOAT) {
718             are_all_float = false;
719           }
720         }
721         for (const DataType& output_type : output_types) {
722           if (output_type != DT_FLOAT) {
723             are_all_float = false;
724           }
725         }
726         // This isn't a float op, so don't quantize it.
727         if (!are_all_float) {
728           CopyOriginalMatch(match, new_nodes);
729           return OkStatus();
730         }
731 
732         string namespace_prefix = float_node.name() + "_eightbit";
733 
734         // Quantize all of the inputs.
735         std::vector<string> quantized_input_names;
736         for (int i = 0; i < float_node.input_size(); ++i) {
737           // Skip any non-float inputs.
738           if (op_info.unquantized_inputs.count(i)) {
739             continue;
740           }
741 
742           const string& input_name = float_node.input(i);
743           string unique_input_name =
744               namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);
745 
746           // Add some common constants we need for reshaping inputs.
747           NodeDef reshape_dims;
748           reshape_dims.set_op("Const");
749           reshape_dims.set_name(unique_input_name + "/reshape_dims");
750           AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
751           SetNodeAttr("dtype", DT_INT32, &reshape_dims);
752           Tensor reshape_dims_tensor(DT_INT32, {1});
753           reshape_dims_tensor.flat<int32>()(0) = -1;
754           SetNodeTensorAttr<int32>("value", reshape_dims_tensor, &reshape_dims);
755           new_nodes->push_back(reshape_dims);
756 
757           NodeDef reduction_dims;
758           reduction_dims.set_op("Const");
759           reduction_dims.set_name(unique_input_name + "/reduction_dims");
760           AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
761           SetNodeAttr("dtype", DT_INT32, &reduction_dims);
762           Tensor reduction_dims_tensor(DT_INT32, {1});
763           reduction_dims_tensor.flat<int32>()(0) = 0;
764           SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
765                                    &reduction_dims);
766           new_nodes->push_back(reduction_dims);
767 
768           NodeDef reshape_node;
769           reshape_node.set_op("Reshape");
770           reshape_node.set_name(unique_input_name + "/reshape");
771           SetNodeAttr("T", DT_FLOAT, &reshape_node);
772           AddNodeInput(input_name, &reshape_node);
773           AddNodeInput(reshape_dims.name(), &reshape_node);
774           new_nodes->push_back(reshape_node);
775 
776           NodeDef min_node;
777           min_node.set_op("Min");
778           min_node.set_name(unique_input_name + "/min");
779           SetNodeAttr("T", DT_FLOAT, &min_node);
780           SetNodeAttr("keep_dims", false, &min_node);
781           AddNodeInput(reshape_node.name(), &min_node);
782           AddNodeInput(reduction_dims.name(), &min_node);
783           new_nodes->push_back(min_node);
784 
785           NodeDef max_node;
786           max_node.set_op("Max");
787           max_node.set_name(unique_input_name + "/max");
788           SetNodeAttr("T", DT_FLOAT, &max_node);
789           SetNodeAttr("keep_dims", false, &max_node);
790           AddNodeInput(reshape_node.name(), &max_node);
791           AddNodeInput(reduction_dims.name(), &max_node);
792           new_nodes->push_back(max_node);
793 
794           NodeDef quantize_node;
795           quantize_node.set_op("QuantizeV2");
796           quantize_node.set_name(unique_input_name + "/quantize");
797           SetNodeAttr("T", DT_QUINT8, &quantize_node);
798           SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
799           AddNodeInput(input_name, &quantize_node);
800           AddNodeInput(min_node.name(), &quantize_node);
801           AddNodeInput(max_node.name(), &quantize_node);
802           new_nodes->push_back(quantize_node);
803           quantized_input_names.push_back(quantize_node.name());
804         }
805 
806         // Set up the quantized version of the current op.
807         NodeDef quantized_main_node;
808         quantized_main_node.set_op("Quantized" + float_node.op());
809         quantized_main_node.set_name(float_node.name() + "/eightbit");
810         for (const string& attr_to_copy : op_info.attrs_to_copy) {
811           CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
812                        &quantized_main_node);
813         }
814         for (const std::pair<string, DataType>& dtype_to_set :
815              op_info.dtypes_to_set) {
816           SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
817                       &quantized_main_node);
818         }
819         int quantized_input_index = 0;
820         for (int i = 0; i < float_node.input_size(); ++i) {
821           if (op_info.unquantized_inputs.count(i)) {
822             AddNodeInput(float_node.input(i), &quantized_main_node);
823           } else {
824             const string& quantized_input_name =
825                 quantized_input_names[quantized_input_index];
826             AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
827             ++quantized_input_index;
828           }
829         }
830         if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
831           for (const string& quantized_input_name : quantized_input_names) {
832             AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
833             AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
834           }
835         } else {
836           for (const string& quantized_input_name : quantized_input_names) {
837             AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
838           }
839           for (const string& quantized_input_name : quantized_input_names) {
840             AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
841           }
842         }
843         new_nodes->push_back(quantized_main_node);
844 
845         string eight_bit_node_name;
846         if (op_info.output_bit_depth == DT_QINT32) {
847           // Shrink the range of the output down from 32 bits to 8.
848           string requantize_min_input;
849           string requantize_max_input;
850           if (has_fallback_range) {
851             // Use constant values for the min/max range if they were given.
852             NodeDef fallback_min_node;
853             fallback_min_node.set_op("Const");
854             fallback_min_node.set_name(quantized_main_node.name() +
855                                        "/fallback_min");
856             SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
857             Tensor fallback_min_tensor(DT_FLOAT, {});
858             fallback_min_tensor.flat<float>()(0) = fallback_min;
859             SetNodeTensorAttr<float>("value", fallback_min_tensor,
860                                      &fallback_min_node);
861             new_nodes->push_back(fallback_min_node);
862 
863             NodeDef fallback_max_node;
864             fallback_max_node.set_op("Const");
865             fallback_max_node.set_name(quantized_main_node.name() +
866                                        "/fallback_max");
867             SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
868             Tensor fallback_max_tensor(DT_FLOAT, {});
869             fallback_max_tensor.flat<float>()(0) = fallback_max;
870             SetNodeTensorAttr<float>("value", fallback_max_tensor,
871                                      &fallback_max_node);
872             new_nodes->push_back(fallback_max_node);
873 
874             requantize_min_input = fallback_min_node.name();
875             requantize_max_input = fallback_max_node.name();
876           } else {
877             // Otherwise dynamically measure the range each time.
878             NodeDef requant_range_node;
879             requant_range_node.set_op("RequantizationRange");
880             requant_range_node.set_name(quantized_main_node.name() +
881                                         "/requant_range");
882             SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
883             AddNodeInput(quantized_main_node.name() + ":0",
884                          &requant_range_node);
885             AddNodeInput(quantized_main_node.name() + ":1",
886                          &requant_range_node);
887             AddNodeInput(quantized_main_node.name() + ":2",
888                          &requant_range_node);
889             new_nodes->push_back(requant_range_node);
890 
891             requantize_min_input = requant_range_node.name() + ":0";
892             requantize_max_input = requant_range_node.name() + ":1";
893           }
894           NodeDef requantize_node;
895           requantize_node.set_op("Requantize");
896           requantize_node.set_name(quantized_main_node.name() + "/requantize");
897           SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
898           SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
899           AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
900           AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
901           AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
902           AddNodeInput(requantize_min_input, &requantize_node);
903           AddNodeInput(requantize_max_input, &requantize_node);
904           new_nodes->push_back(requantize_node);
905           eight_bit_node_name = requantize_node.name();
906         } else {
907           eight_bit_node_name = quantized_main_node.name();
908         }
909 
910         // Convert the 8-bit result back into float for the final output.
911         NodeDef dequantize_node;
912         dequantize_node.set_op("Dequantize");
913         dequantize_node.set_name(float_node.name());
914         SetNodeAttr("T", DT_QUINT8, &dequantize_node);
915         SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
916         AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
917         AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
918         AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
919         new_nodes->push_back(dequantize_node);
920 
921         return OkStatus();
922       },
923       {}, &quantized_graph_def));
924   TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));
925 
926   // If we've ended up with two Requantize ops in a row (for example if there
927   // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
928   // using the trained range from the second op.
929   GraphDef merged_graph_def;
930   TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
931                                               &merged_graph_def));
932   TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));
933 
934   // There can be duplicate quantize nodes if multiple ops pull from a single
935   // input, which makes it harder to remove redundant ones, so strip them out.
936   GraphDef deduped_graph_def;
937   TF_RETURN_IF_ERROR(
938       MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
939   TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));
940 
941   // Look for Dequantizes that immediately go into Quantizes, and remove them
942   // since the two together cancel each other out. This allows us to keep the
943   // data flow in eight bit where two adjacent ops are in eight bit, but still
944   // keep interoperability with float ops.
945   TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
946                                                   output_graph_def));
947   TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));
948 
949   return OkStatus();
950 }
951 
// Registers the QuantizeNodes function under the transform name
// "quantize_nodes".
// NOTE(review): presumably this string is the name by which the transform is
// selected at runtime -- confirm against the REGISTER_GRAPH_TRANSFORM
// definition in transform_utils.h.
952 REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);
953 
// Registers MergeDuplicateNodes under its own transform name,
// "merge_duplicate_nodes". QuantizeNodes also calls this function directly as
// one of its cleanup passes; this registration gives it a separate name.
// NOTE(review): presumably so it can be run as a standalone transform --
// confirm against the REGISTER_GRAPH_TRANSFORM definition.
954 REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);
955 
956 }  // namespace graph_transforms
957 }  // namespace tensorflow
958