/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

// Holds the information we need to translate from a float version of this op
// into the quantized equivalent.
struct QuantizedOpInfo {
  // The name of the float op.
  string float_name;
  // Which attributes to copy directly over.
  std::vector<string> attrs_to_copy;
  // Extra data type attributes we need to set.
  std::vector<std::pair<string, DataType>> dtypes_to_set;
  // What depth of inputs the op can read in.
  DataType input_bit_depth;
  // The depth of the op's quantized outputs.
  DataType output_bit_depth;
  // Which inputs (e.g. shapes) aren't involved in the quantization process.
  std::set<int32> unquantized_inputs;
  // How the outputs are arranged, either
  // [input0, input1, min0, max0, min1, max1] for contiguous, or
  // [input0, input1, min0, min1, max0, max1] for separate.
  // The separate order is needed because it's the only way to specify unknown
  // numbers of inputs for ops like Concat.
  enum { CONTIGUOUS_MIN_MAX, SEPARATE_MIN_MAX } min_max_order;
};
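
// As a worked example of how these fields are used, the Conv2D entry in
// GetQuantizedOpList() below decodes as:
//   float_name         = "Conv2D"
//   attrs_to_copy      = {"strides", "padding"}
//   dtypes_to_set      = {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8},
//                         {"out_type", DT_QINT32}}
//   input_bit_depth    = DT_QUINT8
//   output_bit_depth   = DT_QINT32 (so QuantizeNodes() appends a Requantize
//                        step to shrink the result back down to eight bits)
//   unquantized_inputs = {} (both the input and the filter are quantized)
//   min_max_order      = CONTIGUOUS_MIN_MAX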

// Every op that has a quantized equivalent should be listed here, so that the
// conversion process can transform them.
const std::vector<QuantizedOpInfo>& GetQuantizedOpList() {
  static const std::vector<QuantizedOpInfo> op_list = {
      {"Add",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"AvgPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"BiasAdd",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Concat",
       {"N"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {0},
       QuantizedOpInfo::SEPARATE_MIN_MAX},
      {"Conv2D",
       {"strides", "padding"},
       {{"Tinput", DT_QUINT8}, {"Tfilter", DT_QUINT8}, {"out_type", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MatMul",
       {"transpose_a", "transpose_b"},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"MaxPool",
       {"ksize", "strides", "padding"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Mul",
       {},
       {{"T1", DT_QUINT8}, {"T2", DT_QUINT8}, {"Toutput", DT_QINT32}},
       DT_QUINT8,
       DT_QINT32,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"ResizeBilinear",
       {"align_corners"},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Relu6",
       {},
       {{"Tinput", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
      {"Reshape",
       {},
       {{"T", DT_QUINT8}},
       DT_QUINT8,
       DT_QUINT8,
       {1},
       QuantizedOpInfo::CONTIGUOUS_MIN_MAX},
  };
  return op_list;
}

namespace {
// Replaces invalid characters in input names to get a unique node name.
string UniqueNodeNameFromInput(const string& input_name) {
  string prefix;
  string node_name;
  string suffix;
  NodeNamePartsFromInput(input_name, &prefix, &node_name, &suffix);
  string result;
  if (prefix == "^") {
    result += "__hat__";
  }
  result += node_name;
  if (!suffix.empty()) {
    result += "__port__" + suffix.substr(1, suffix.size() - 1);
  }
  return result;
}
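
// For example, an input reference like "conv1:2" maps to "conv1__port__2",
// and a control input like "^conv1" maps to "__hat__conv1", so the result is
// safe to embed in a new node name.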

// Pulls two float values from the named parameters, with a lot of checking.
Status ExtractRangeFromParams(const TransformFuncContext& context,
                              const string& min_name, const string& max_name,
                              float* min_value, float* max_value,
                              bool* has_range) {
  // See if we've been given quantized inputs with a known range.
  const bool has_min = (context.params.count(min_name) != 0);
  const bool has_max = (context.params.count(max_name) != 0);
  *has_range = (has_min || has_max);
  if (!*has_range) {
    return Status::OK();
  }
  if (!has_min || !has_max) {
    return errors::InvalidArgument("You must pass both ", min_name, " and ",
                                   max_name, " into quantize_nodes");
  }
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(min_name, 0.0f, min_value));
  TF_RETURN_IF_ERROR(context.GetOneFloatParameter(max_name, 0.0f, max_value));
  return Status::OK();
}
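
// For example, the input range consumed by QuantizePlaceholders() below would
// typically arrive as transform parameters, along the lines of (a sketch of
// the usual command-line syntax):
//   --transforms='quantize_nodes(input_min=-1, input_max=1)'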

}  // namespace

// Analyzes all the nodes in the graph to figure out which ones are duplicates
// apart from their names. This commonly includes identical Const nodes, but can
// also be simple operations that are repeated on multiple outputs of a
// particular node. The complexity is managed using a hash function that avoids
// the need for any O(n^2) algorithms when identifying duplicates.
Status MergeDuplicateNodes(const GraphDef& input_graph_def,
                           const TransformFuncContext& context,
                           GraphDef* output_graph_def) {
  // Make sure we can look up inputs and outputs quickly.
  std::set<string> input_names(context.input_names.begin(),
                               context.input_names.end());
  std::set<string> output_names(context.output_names.begin(),
                                context.output_names.end());
  GraphDef current_graph_def = input_graph_def;
  // Keep running the merging until no more duplicates are found.
  bool any_duplicates_found;
  do {
    any_duplicates_found = false;
    // First arrange all of the nodes by a hash of their contents.
    std::map<uint64, std::vector<const NodeDef*>> hashed_nodes;
    for (const NodeDef& node : current_graph_def.node()) {
      NodeDef nameless_node = node;
      // The name matters if it's being used as an input or output node,
      // otherwise ignore it when looking for duplicates.
      if (!input_names.count(node.name()) && !output_names.count(node.name())) {
        nameless_node.set_name("");
      }
      const uint64 hash = HashNodeDef(nameless_node);
      hashed_nodes[hash].push_back(&node);
    }
    // If we have multiple nodes with the same hash, then we know they're
    // duplicates and can be removed, unless they're stateful.
    std::map<string, string> inputs_to_rename;
    GraphDef merged_graph_def;
    for (const auto& hashed_node_info : hashed_nodes) {
      const std::vector<const NodeDef*>& hash_node_list =
          hashed_node_info.second;
      for (int i = 0; i < hash_node_list.size(); ++i) {
        const NodeDef* current_node = hash_node_list[i];
        const OpDef* op_def = nullptr;
        TF_RETURN_IF_ERROR(
            OpRegistry::Global()->LookUpOpDef(current_node->op(), &op_def));
        const bool is_duplicate = ((!op_def->is_stateful()) && (i > 0));
        if (is_duplicate) {
          const string original_name = hash_node_list[0]->name();
          inputs_to_rename[current_node->name() + ":*"] = original_name;
          any_duplicates_found = true;
        } else {
          NodeDef* new_node = merged_graph_def.mutable_node()->Add();
          *new_node = *current_node;
        }
      }
    }
    // Update the graph so that any nodes that referred to removed inputs now
    // pull from the remaining duplicate.
    TF_RETURN_IF_ERROR(RenameNodeInputs(merged_graph_def, inputs_to_rename,
                                        std::unordered_set<string>(),
                                        &current_graph_def));
  } while (any_duplicates_found);

  *output_graph_def = current_graph_def;

  return Status::OK();
}
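
// For example, if two branches of a graph each contain an identical Const
// weights node, the second hashes to the same value as the first, all of its
// consumers are rewired to the first through the "name:*" wildcard rename,
// and the node itself is dropped from the output graph.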

// Looks for the patterns that indicate there are two eight-bit ops feeding into
// each other, separated by a conversion up to float and back again. These occur
// during the initial conversion of ops to their quantized forms. Because we're
// only looking at an individual op in that phase and don't know if its inputs
// and outputs are eight-bit-capable, we start by converting the actual op into
// quantized form, but add float conversions before and after. This pass gets
// rid of those conversions if it turns out we do have adjacent ops capable of
// eight-bit processing.
Status RemoveRedundantQuantizations(const GraphDef& input_graph_def,
                                    const TransformFuncContext& context,
                                    GraphDef* output_graph_def) {
  std::set<string> graph_outputs;
  for (const string& output_name : context.output_names) {
    graph_outputs.insert(NodeNameFromInput(output_name));
  }
  std::map<string, string> inputs_to_rename;
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"QuantizeV2",
        {
          {"Dequantize"},
          {"Min"},
          {"Max"},
        }
      },  // clang-format on
      [&inputs_to_rename, &graph_outputs](const NodeMatch& match,
                                          const std::set<string>& input_nodes,
                                          const std::set<string>& output_nodes,
                                          std::vector<NodeDef>* new_nodes) {
        const NodeDef& quantize_node = match.node;
        const NodeDef& dequantize_node = match.inputs[0].node;
        inputs_to_rename[quantize_node.name() + ":0"] =
            dequantize_node.input(0);
        inputs_to_rename[quantize_node.name() + ":1"] =
            dequantize_node.input(1);
        inputs_to_rename[quantize_node.name() + ":2"] =
            dequantize_node.input(2);

        // Are other sub-graphs using the float intermediate result? If so,
        // preserve it, but the input renaming still rewires the eight-bit ops
        // so they don't go through float.
        if (output_nodes.count(dequantize_node.name()) ||
            graph_outputs.count(dequantize_node.name())) {
          CopyOriginalMatch(match, new_nodes);
        }

        return Status::OK();
      },
      {true}, &replaced_graph_def));

  return RenameNodeInputs(replaced_graph_def, inputs_to_rename,
                          std::unordered_set<string>(), output_graph_def);
}
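
// In other words, where one quantized op's output has been converted back to
// float and immediately re-quantized for the next op, for example:
//   ... -> Requantize -> Dequantize -> QuantizeV2 -> QuantizedRelu -> ...
// the Dequantize/QuantizeV2 pair is dropped and the downstream op reads the
// eight-bit value and its min/max range directly:
//   ... -> Requantize -> QuantizedRelu -> ...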

// If the user has passed in the input_min and input_max args, then we need to
// convert any input placeholders from float to eight bit, so quantized inputs
// can be fed directly into the graph.
Status QuantizePlaceholders(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  float input_min;
  float input_max;
  bool has_input_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(context, "input_min", "input_max",
                                            &input_min, &input_max,
                                            &has_input_range));
  if (!has_input_range) {
    *output_graph_def = input_graph_def;
    return Status::OK();
  }
  std::map<string, string> inputs_to_rename_first_pass;
  std::map<string, string> inputs_to_rename_second_pass;
  GraphDef placeholder_graph_def;
  placeholder_graph_def.Clear();
  for (const NodeDef& node : input_graph_def.node()) {
    if (node.op() != "Placeholder") {
      *(placeholder_graph_def.mutable_node()->Add()) = node;
    } else {
      string namespace_prefix = node.name() + "_eightbit";

      NodeDef quantized_placeholder = node;
      SetNodeAttr("dtype", DT_QUINT8, &quantized_placeholder);
      *(placeholder_graph_def.mutable_node()->Add()) = quantized_placeholder;

      NodeDef min_node;
      min_node.set_op("Const");
      min_node.set_name(namespace_prefix + "/min");
      SetNodeAttr("dtype", DT_FLOAT, &min_node);
      Tensor min_tensor(DT_FLOAT, {});
      min_tensor.flat<float>()(0) = input_min;
      SetNodeTensorAttr<float>("value", min_tensor, &min_node);
      *(placeholder_graph_def.mutable_node()->Add()) = min_node;

      NodeDef max_node;
      max_node.set_op("Const");
      max_node.set_name(namespace_prefix + "/max");
      SetNodeAttr("dtype", DT_FLOAT, &max_node);
      Tensor max_tensor(DT_FLOAT, {});
      max_tensor.flat<float>()(0) = input_max;
      SetNodeTensorAttr<float>("value", max_tensor, &max_node);
      *(placeholder_graph_def.mutable_node()->Add()) = max_node;

      const string rename_suffix = "__RENAMED_PLACEHOLDER__";
      NodeDef dequantize_node;
      dequantize_node.set_op("Dequantize");
      dequantize_node.set_name(namespace_prefix + "/dequantize");
      SetNodeAttr("T", DT_QUINT8, &dequantize_node);
      SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
      AddNodeInput(node.name() + rename_suffix, &dequantize_node);
      AddNodeInput(min_node.name(), &dequantize_node);
      AddNodeInput(max_node.name(), &dequantize_node);
      *(placeholder_graph_def.mutable_node()->Add()) = dequantize_node;

      // First make sure that any internal references to the old placeholder
      // now point to the dequantize result.
      inputs_to_rename_first_pass[node.name()] = dequantize_node.name();
      // Then fix up the dequantize op so that it really points to the
      // placeholder.
      inputs_to_rename_second_pass[node.name() + rename_suffix] = node.name();
    }
  }

  GraphDef first_pass_graph_def;
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(placeholder_graph_def, inputs_to_rename_first_pass,
                       std::unordered_set<string>(), &first_pass_graph_def));
  TF_RETURN_IF_ERROR(
      RenameNodeInputs(first_pass_graph_def, inputs_to_rename_second_pass,
                       std::unordered_set<string>(), output_graph_def));

  return Status::OK();
}
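
// The result for each placeholder is a quint8 Placeholder feeding a
// Dequantize node, with the supplied range as Const inputs, and all of the
// placeholder's former consumers reading from the Dequantize instead:
//   Placeholder(quint8) -> Dequantize(min, max) -> former consumers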

// During training, FakeQuantWithMinMaxVars ops capture a good min/max range for
// an activation layer. To use these during inference, this pass converts those
// ops into Requantizes with the trained min/maxes as constant inputs.
Status ConvertFakeQuantsToRequantize(const GraphDef& input_graph_def,
                                     const TransformFuncContext& context,
                                     GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"FakeQuantWithMinMaxVars",
        {
          {"*"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_quant_node = match.node;
        const NodeDef& original_op_node = match.inputs[0].node;
        const NodeDef& fake_quant_min_node = match.inputs[1].node;
        const NodeDef& fake_quant_max_node = match.inputs[2].node;

        string namespace_prefix = fake_quant_node.name() + "_eightbit";

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_quant_min_node);
        new_nodes->push_back(fake_quant_max_node);

        NodeDef quantize_node;
        quantize_node.set_op("QuantizeV2");
        quantize_node.set_name(namespace_prefix + "/quantize");
        SetNodeAttr("T", DT_QINT32, &quantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
        AddNodeInput(fake_quant_node.input(0), &quantize_node);
        AddNodeInput(fake_quant_min_node.name(), &quantize_node);
        AddNodeInput(fake_quant_max_node.name(), &quantize_node);
        new_nodes->push_back(quantize_node);

        NodeDef requantize_node;
        requantize_node.set_op("Requantize");
        requantize_node.set_name(namespace_prefix + "/requantize");
        SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
        SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
        AddNodeInput(quantize_node.name() + ":0", &requantize_node);
        AddNodeInput(quantize_node.name() + ":1", &requantize_node);
        AddNodeInput(quantize_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_quant_min_node.name(), &requantize_node);
        AddNodeInput(fake_quant_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(fake_quant_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":0", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":1", &dequantize_node);
        AddNodeInput(requantize_node.name() + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}
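
// So each FakeQuantWithMinMaxVars(input, min, max) is rewritten as:
//   input -> QuantizeV2(qint32) -> Requantize(min, max) -> Dequantize(float)
// where the Dequantize takes over the FakeQuant's name, keeping all of its
// downstream consumers intact.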

// We always generate Requantize ops driven by dynamic RequantizationRange
// calculations when we produce quantized ops like Conv2D or BiasAdd with
// 32-bit results. If there were FakeQuant ops already for those activation
// layers, then there will be a later Requantize op with constant min/max
// inputs, which is preferable for fast inference. This pass looks for those
// later Requantize ops, and replaces the dynamic version with them.
Status MergeAdjacentRequantizes(const GraphDef& input_graph_def,
                                const TransformFuncContext& context,
                                GraphDef* output_graph_def) {
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Requantize",
        {
          {"QuantizeV2",
            {
              {"Dequantize",
                {
                  {"Requantize",
                    {
                      {"*"},
                      {"*"},
                      {"*"},
                      {"RequantizationRange"},
                      {"RequantizationRange"},
                    }
                  },
                  {"Requantize"},
                  {"Requantize"},
                }
              },
              {"Const"},
              {"Const"},
            },
          },
          {"QuantizeV2"},
          {"QuantizeV2"},
          {"Const"},
          {"Const"},
        }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        const NodeDef& fake_requantize_node = match.node;
        const NodeDef& original_op_node =
            match.inputs[0].inputs[0].inputs[0].inputs[0].node;
        const NodeDef& fake_requantize_min_node = match.inputs[3].node;
        const NodeDef& fake_requantize_max_node = match.inputs[4].node;

        new_nodes->push_back(original_op_node);
        new_nodes->push_back(fake_requantize_min_node);
        new_nodes->push_back(fake_requantize_max_node);

        NodeDef requantize_node = fake_requantize_node;
        requantize_node.mutable_input()->Clear();
        AddNodeInput(original_op_node.name() + ":0", &requantize_node);
        AddNodeInput(original_op_node.name() + ":1", &requantize_node);
        AddNodeInput(original_op_node.name() + ":2", &requantize_node);
        AddNodeInput(fake_requantize_min_node.name(), &requantize_node);
        AddNodeInput(fake_requantize_max_node.name(), &requantize_node);
        new_nodes->push_back(requantize_node);

        return Status::OK();
      },
      {}, output_graph_def));

  return Status::OK();
}
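
// The net effect is that a chain produced by quantization followed by a
// converted FakeQuant:
//   QuantizedOp -> Requantize(RequantizationRange) -> Dequantize
//     -> QuantizeV2 -> Requantize(const min/max)
// collapses to a single Requantize that reads the quantized op's 32-bit
// output directly and uses the trained constant range.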

// Sometimes FakeQuantWithMinMaxVars ops are added at the end of a chain of
// linear ops like Relu, MaxPool, etc, several steps from the Conv2D or BiasAdd
// op that we want to apply the trained constant conversions to. This pass tries
// to move FakeQuant ops up the input chain, so they're as close as possible to
// the 32-bit conversion, and so can be easily merged into the automatic dynamic
// Requantizes.
Status HoistFakeQuants(const GraphDef& input_graph_def,
                       const TransformFuncContext& context,
                       GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  const int max_depth = 3;
  for (int depth = max_depth; depth > 0; --depth) {
    OpTypePattern pattern = {"*"};
    for (int i = 0; i < depth; ++i) {
      pattern = {"*", {pattern}};
    }
    pattern = {"FakeQuantWithMinMaxVars", {pattern, {"Const"}, {"Const"}}};
    GraphDef hoisted_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def, pattern,
        [depth](const NodeMatch& match, const std::set<string>& input_nodes,
                const std::set<string>& output_nodes,
                std::vector<NodeDef>* new_nodes) {
          const NodeDef& fake_quant_node = match.node;
          const NodeDef& fake_quant_min_node = match.inputs[1].node;
          const NodeDef& fake_quant_max_node = match.inputs[2].node;
          std::vector<NodeDef> linear_nodes;
          NodeMatch current_match = match;
          for (int i = 0; i <= depth; ++i) {
            linear_nodes.push_back(current_match.inputs[0].node);
            current_match = current_match.inputs[0];
          }
          NodeDef new_fake_quant_node = fake_quant_node;
          new_fake_quant_node.set_name(fake_quant_node.name() + "_hoisted");
          new_fake_quant_node.set_input(
              0, linear_nodes[linear_nodes.size() - 2].input(0));
          new_nodes->push_back(new_fake_quant_node);

          new_nodes->push_back(fake_quant_min_node);
          new_nodes->push_back(fake_quant_max_node);

          linear_nodes[linear_nodes.size() - 2].set_input(
              0, new_fake_quant_node.name());
          linear_nodes.front().set_name(fake_quant_node.name());
          for (const NodeDef& linear_node : linear_nodes) {
            new_nodes->push_back(linear_node);
          }

          return Status::OK();
        },
        {}, &hoisted_graph_def));
    current_graph_def = hoisted_graph_def;
  }
  *output_graph_def = current_graph_def;

  return Status::OK();
}
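
// For instance, with a match at depth 1 a chain like:
//   Conv2D -> Relu -> FakeQuantWithMinMaxVars
// becomes:
//   Conv2D -> FakeQuant copy (named "<original>_hoisted") -> Relu
// with the Relu renamed to the FakeQuant's original name so that downstream
// consumers are unaffected.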

// Converts any float ops that have eight-bit equivalents into their quantized
// forms, so that as much calculation as possible is done in the lower-precision
// format.
Status QuantizeNodes(const GraphDef& input_graph_def,
                     const TransformFuncContext& context,
                     GraphDef* output_graph_def) {
  // Loop through all of the quantizable op types, and replace any occurrences
  // with equivalent sub-graphs with quantized ops at their core. For example
  // this one-input operation:
  //
  //            Input(float)
  //                |
  //                v
  //            Operation
  //                |
  //                v
  //             (float)
  //
  // Will be turned into its quantized equivalent:
  //
  //      Input(float)          ReshapeDims
  //         +------v v-------------+
  //         |    Reshape
  //         |      |
  //         |      |          ReductionDims
  //         |      +-----+         |
  //         |      | +---c---------+
  //         |      v v   v v-------+
  //         |      Min   Max
  //         |  +----+      |
  //         v  v  v--------+
  //        Quantize
  //            |
  //            v
  //     QuantizedOperation
  //        |   |   |
  //        v   v   v
  //        Dequantize
  //            |
  //            v
  //         (float)
  //
  // This keeps the inputs and outputs visible to the rest of the graph in
  // float, and converts them down to quantized buffers internally for the
  // computation. The result will end up with a lot of redundant
  // dequantize/quantize pairs between adjacent quantized ops, but a later pass
  // removes these where it can.

  std::set<string> ops_to_ignore;
  if (context.params.count("ignore_op") > 0) {
    for (const string& name : context.params.at("ignore_op")) {
      ops_to_ignore.insert(name);
    }
  }

  const std::vector<QuantizedOpInfo>& op_list = GetQuantizedOpList();
  string op_pattern;
  bool is_first = true;
  std::map<string, QuantizedOpInfo> op_map;
  for (const QuantizedOpInfo& op_info : op_list) {
    if (ops_to_ignore.count(op_info.float_name) == 0) {
      strings::StrAppend(&op_pattern, (is_first ? "" : "|"),
                         op_info.float_name);
      op_map.insert({op_info.float_name, op_info});
      is_first = false;
    }
  }

  // If input_min and input_max have been passed in, then we convert all float
  // Placeholder nodes into quantized versions, with the supplied values as
  // their range.
  GraphDef placeholder_graph_def;
  TF_RETURN_IF_ERROR(
      QuantizePlaceholders(input_graph_def, context, &placeholder_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(placeholder_graph_def));

  // If there are any FakeQuantWithMinMaxVars at the end of a chain of linear
  // operations like Relu or MaxPool, move them up so that they're as close as
  // possible to ops with 32-bit outputs like BiasAdd or Conv2D.
  GraphDef hoisted_graph_def;
  TF_RETURN_IF_ERROR(
      HoistFakeQuants(placeholder_graph_def, context, &hoisted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(hoisted_graph_def));

  // Convert any FakeQuantWithMinMaxVars, which hold the trained ranges of
  // activation layers, into Requantize ops with those ranges instead. This
  // makes it easier to replace the dynamic range calculations that are used
  // by default.
  GraphDef converted_graph_def;
  TF_RETURN_IF_ERROR(ConvertFakeQuantsToRequantize(hoisted_graph_def, context,
                                                   &converted_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(converted_graph_def));

  // If fallback_min and fallback_max are set, then we'll use hardwired ranges
  // for all the 32-bit to 8-bit requantizations.
  float fallback_min;
  float fallback_max;
  bool has_fallback_range;
  TF_RETURN_IF_ERROR(ExtractRangeFromParams(
      context, "fallback_min", "fallback_max", &fallback_min, &fallback_max,
      &has_fallback_range));

  // Replace all occurrences of the current float op with its quantized
  // equivalent.
  GraphDef quantized_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      converted_graph_def, {op_pattern},
      [&op_map, fallback_min, fallback_max, has_fallback_range](
          const NodeMatch& match, const std::set<string>& input_nodes,
          const std::set<string>& output_nodes,
          std::vector<NodeDef>* new_nodes) {
        const NodeDef& float_node = match.node;
        const QuantizedOpInfo& op_info = op_map[float_node.op()];

        DataTypeVector input_types;
        DataTypeVector output_types;
        TF_RETURN_IF_ERROR(
            GetInOutTypes(float_node, &input_types, &output_types));
        bool are_all_float = true;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any known non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }
          if (input_types[i] != DT_FLOAT) {
            are_all_float = false;
          }
        }
        for (const DataType& output_type : output_types) {
          if (output_type != DT_FLOAT) {
            are_all_float = false;
          }
        }
        // This isn't a float op, so don't quantize it.
        if (!are_all_float) {
          CopyOriginalMatch(match, new_nodes);
          return Status::OK();
        }

        string namespace_prefix = float_node.name() + "_eightbit";

        // Quantize all of the inputs.
        std::vector<string> quantized_input_names;
        for (int i = 0; i < float_node.input_size(); ++i) {
          // Skip any non-float inputs.
          if (op_info.unquantized_inputs.count(i)) {
            continue;
          }

          const string& input_name = float_node.input(i);
          string unique_input_name =
              namespace_prefix + "/" + UniqueNodeNameFromInput(input_name);

          // Add some common constants we need for reshaping inputs.
          NodeDef reshape_dims;
          reshape_dims.set_op("Const");
          reshape_dims.set_name(unique_input_name + "/reshape_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reshape_dims);
          SetNodeAttr("dtype", DT_INT32, &reshape_dims);
          Tensor reshape_dims_tensor(DT_INT32, {1});
          reshape_dims_tensor.flat<int32>()(0) = -1;
          SetNodeTensorAttr<int32>("value", reshape_dims_tensor, &reshape_dims);
          new_nodes->push_back(reshape_dims);

          NodeDef reduction_dims;
          reduction_dims.set_op("Const");
          reduction_dims.set_name(unique_input_name + "/reduction_dims");
          AddNodeInput("^" + NodeNameFromInput(input_name), &reduction_dims);
          SetNodeAttr("dtype", DT_INT32, &reduction_dims);
          Tensor reduction_dims_tensor(DT_INT32, {1});
          reduction_dims_tensor.flat<int32>()(0) = 0;
          SetNodeTensorAttr<int32>("value", reduction_dims_tensor,
                                   &reduction_dims);
          new_nodes->push_back(reduction_dims);

          NodeDef reshape_node;
          reshape_node.set_op("Reshape");
          reshape_node.set_name(unique_input_name + "/reshape");
          SetNodeAttr("T", DT_FLOAT, &reshape_node);
          AddNodeInput(input_name, &reshape_node);
          AddNodeInput(reshape_dims.name(), &reshape_node);
          new_nodes->push_back(reshape_node);

          NodeDef min_node;
          min_node.set_op("Min");
          min_node.set_name(unique_input_name + "/min");
          SetNodeAttr("T", DT_FLOAT, &min_node);
          SetNodeAttr("keep_dims", false, &min_node);
          AddNodeInput(reshape_node.name(), &min_node);
          AddNodeInput(reduction_dims.name(), &min_node);
          new_nodes->push_back(min_node);

          NodeDef max_node;
          max_node.set_op("Max");
          max_node.set_name(unique_input_name + "/max");
          SetNodeAttr("T", DT_FLOAT, &max_node);
          SetNodeAttr("keep_dims", false, &max_node);
          AddNodeInput(reshape_node.name(), &max_node);
          AddNodeInput(reduction_dims.name(), &max_node);
          new_nodes->push_back(max_node);

          NodeDef quantize_node;
          quantize_node.set_op("QuantizeV2");
          quantize_node.set_name(unique_input_name + "/quantize");
          SetNodeAttr("T", DT_QUINT8, &quantize_node);
          SetNodeAttr("mode", "MIN_FIRST", &quantize_node);
          AddNodeInput(input_name, &quantize_node);
          AddNodeInput(min_node.name(), &quantize_node);
          AddNodeInput(max_node.name(), &quantize_node);
          new_nodes->push_back(quantize_node);
          quantized_input_names.push_back(quantize_node.name());
        }

        // Set up the quantized version of the current op.
        NodeDef quantized_main_node;
        quantized_main_node.set_op("Quantized" + float_node.op());
        quantized_main_node.set_name(float_node.name() + "/eightbit");
        for (const string& attr_to_copy : op_info.attrs_to_copy) {
          CopyNodeAttr(float_node, attr_to_copy, attr_to_copy,
                       &quantized_main_node);
        }
        for (const std::pair<string, DataType>& dtype_to_set :
             op_info.dtypes_to_set) {
          SetNodeAttr(dtype_to_set.first, dtype_to_set.second,
                      &quantized_main_node);
        }
        int quantized_input_index = 0;
        for (int i = 0; i < float_node.input_size(); ++i) {
          if (op_info.unquantized_inputs.count(i)) {
            AddNodeInput(float_node.input(i), &quantized_main_node);
          } else {
            const string& quantized_input_name =
                quantized_input_names[quantized_input_index];
            AddNodeInput(quantized_input_name + ":0", &quantized_main_node);
            ++quantized_input_index;
          }
        }
        if (op_info.min_max_order == QuantizedOpInfo::CONTIGUOUS_MIN_MAX) {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        } else {
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":1", &quantized_main_node);
          }
          for (const string& quantized_input_name : quantized_input_names) {
            AddNodeInput(quantized_input_name + ":2", &quantized_main_node);
          }
        }
        new_nodes->push_back(quantized_main_node);

        string eight_bit_node_name;
        if (op_info.output_bit_depth == DT_QINT32) {
          // Shrink the range of the output down from 32 bits to 8.
          string requantize_min_input;
          string requantize_max_input;
          if (has_fallback_range) {
            // Use constant values for the min/max range if they were given.
            NodeDef fallback_min_node;
            fallback_min_node.set_op("Const");
            fallback_min_node.set_name(quantized_main_node.name() +
                                       "/fallback_min");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_min_node);
            Tensor fallback_min_tensor(DT_FLOAT, {});
            fallback_min_tensor.flat<float>()(0) = fallback_min;
            SetNodeTensorAttr<float>("value", fallback_min_tensor,
                                     &fallback_min_node);
            new_nodes->push_back(fallback_min_node);

            NodeDef fallback_max_node;
            fallback_max_node.set_op("Const");
            fallback_max_node.set_name(quantized_main_node.name() +
                                       "/fallback_max");
            SetNodeAttr("dtype", DT_FLOAT, &fallback_max_node);
            Tensor fallback_max_tensor(DT_FLOAT, {});
            fallback_max_tensor.flat<float>()(0) = fallback_max;
            SetNodeTensorAttr<float>("value", fallback_max_tensor,
                                     &fallback_max_node);
            new_nodes->push_back(fallback_max_node);

            requantize_min_input = fallback_min_node.name();
            requantize_max_input = fallback_max_node.name();
          } else {
            // Otherwise dynamically measure the range each time.
            NodeDef requant_range_node;
            requant_range_node.set_op("RequantizationRange");
            requant_range_node.set_name(quantized_main_node.name() +
                                        "/requant_range");
            SetNodeAttr("Tinput", DT_QINT32, &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":0",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":1",
                         &requant_range_node);
            AddNodeInput(quantized_main_node.name() + ":2",
                         &requant_range_node);
            new_nodes->push_back(requant_range_node);

            requantize_min_input = requant_range_node.name() + ":0";
            requantize_max_input = requant_range_node.name() + ":1";
          }
          NodeDef requantize_node;
          requantize_node.set_op("Requantize");
          requantize_node.set_name(quantized_main_node.name() + "/requantize");
          SetNodeAttr("Tinput", DT_QINT32, &requantize_node);
          SetNodeAttr("out_type", DT_QUINT8, &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":0", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":1", &requantize_node);
          AddNodeInput(quantized_main_node.name() + ":2", &requantize_node);
          AddNodeInput(requantize_min_input, &requantize_node);
          AddNodeInput(requantize_max_input, &requantize_node);
          new_nodes->push_back(requantize_node);
          eight_bit_node_name = requantize_node.name();
        } else {
          eight_bit_node_name = quantized_main_node.name();
        }

        // Convert the 8-bit result back into float for the final output.
        NodeDef dequantize_node;
        dequantize_node.set_op("Dequantize");
        dequantize_node.set_name(float_node.name());
        SetNodeAttr("T", DT_QUINT8, &dequantize_node);
        SetNodeAttr("mode", "MIN_FIRST", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":0", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":1", &dequantize_node);
        AddNodeInput(eight_bit_node_name + ":2", &dequantize_node);
        new_nodes->push_back(dequantize_node);

        return Status::OK();
      },
      {}, &quantized_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(quantized_graph_def));

  // If we've ended up with two Requantize ops in a row (for example if there
  // was a Conv2D feeding into a FakeQuantWithMinMaxVars) merge them together,
  // using the trained range from the second op.
  GraphDef merged_graph_def;
  TF_RETURN_IF_ERROR(MergeAdjacentRequantizes(quantized_graph_def, context,
                                              &merged_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(merged_graph_def));

  // There can be duplicate quantize nodes if multiple ops pull from a single
  // input, which makes it harder to remove redundant ones, so strip them out.
  GraphDef deduped_graph_def;
  TF_RETURN_IF_ERROR(
      MergeDuplicateNodes(merged_graph_def, context, &deduped_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(deduped_graph_def));

  // Look for Dequantizes that immediately go into Quantizes, and remove them
  // since the two together cancel each other out. This allows us to keep the
  // data flow in eight bit where two adjacent ops are in eight bit, but still
  // keep interoperability with float ops.
  TF_RETURN_IF_ERROR(RemoveRedundantQuantizations(deduped_graph_def, context,
                                                  output_graph_def));
  TF_RETURN_IF_ERROR(IsGraphValid(*output_graph_def));

  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("quantize_nodes", QuantizeNodes);

REGISTER_GRAPH_TRANSFORM("merge_duplicate_nodes", MergeDuplicateNodes);
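
// Example invocation (a sketch, assuming the standard transform_graph tool
// and the flag names from the graph_transforms README; the graph and node
// names here are illustrative):
//   bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
//     --in_graph=frozen_graph.pb --out_graph=quantized_graph.pb \
//     --inputs=input --outputs=softmax \
//     --transforms='quantize_nodes(input_min=-1, input_max=1)'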

}  // namespace graph_transforms
}  // namespace tensorflow