• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/quantize_training.h"
17 
18 #include <algorithm>
19 #include <atomic>
20 #include <set>
21 #include <unordered_map>
22 #include <vector>
23 
24 #include "tensorflow/core/common_runtime/graph_constructor.h"
25 #include "tensorflow/core/common_runtime/memory_types.h"
26 #include "tensorflow/core/framework/log_memory.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/graph/algorithm.h"
29 #include "tensorflow/core/graph/node_builder.h"
30 #include "tensorflow/core/graph/subgraph.h"
31 #include "tensorflow/core/lib/strings/strcat.h"
32 #include "tensorflow/core/public/session_options.h"
33 
34 namespace tensorflow {
35 namespace {
36 
37 // TODO(suharshs): If desired, make these values configurable.
38 const uint32 kAllowedInputs = 2;
39 const float kEMADecay = 0.999;
40 
41 // Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
42 const auto* nodes_to_rewrite =
43     new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};
44 
45 // Contains necessary parameters to convert an edge.
46 struct EdgeToConvert {
47   // edge is not owned here.
48   const Edge* edge;
49   int32 num_bits;
50   bool signed_input;
51   bool range_given;
52   float input_min;
53   float input_max;
54 
EdgeToConverttensorflow::__anoneef3810f0111::EdgeToConvert55   EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
56                 float max)
57       : edge(e),
58         num_bits(bits),
59         signed_input(sign),
60         range_given(range),
61         input_min(min),
62         input_max(max) {}
63 };
64 
65 // Decide if a node is in backward pass by checking if its name is led by
66 // "gradients".
67 // TODO(jmchen): Make this check more robust as it is not guaranteed that the
68 // forward node will not be named with a leading "gradients".
IsGradientNode(const Graph * graph,const Node * node)69 inline bool IsGradientNode(const Graph* graph, const Node* node) {
70   static const string tag = "gradients";
71   return (node->name().compare(0, tag.size(), tag) == 0);
72 }
73 
74 // Find the type of the input to set the parameters for the
75 // quantize_and_dequantize op.
76 // Returns true if the root tensor op type is known, false otherwise.
FindType(const Graph * graph,const Node * node,bool * signed_input,bool * range_given,float * input_min,float * input_max)77 bool FindType(const Graph* graph, const Node* node, bool* signed_input,
78               bool* range_given, float* input_min, float* input_max) {
79   const string& src_op = node->type_string();
80   if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
81     *signed_input = true;
82     *range_given = false;
83   } else if (src_op == "Relu") {
84     // Range is not given for Relu.
85     *signed_input = false;
86     *range_given = false;
87   } else if (src_op == "Relu6") {
88     // TODO(suharshs): Also the theoretical min and max is 0 and 6, if the
89     // actual activations are somewhere in within this range, we can quantize
90     // this even further. This is true for other activations like Sigmoid6 too.
91     *signed_input = false;
92     *range_given = true;
93     *input_min = 0;
94     *input_max = 6;
95   } else if (src_op == "Sigmoid") {
96     *signed_input = false;
97     *range_given = true;
98     *input_min = 0;
99     *input_max = 1;
100   } else if (src_op == "Tanh") {
101     *signed_input = true;
102     *range_given = true;
103     *input_min = -1;
104     *input_max = 1;
105   } else if (src_op == "Reshape" || src_op == "ConcatV2") {
106     // Reshape has 2 inputs and the first one is the tensor.
107     // ConcatV2 has many inputs but they should all have the same activation
108     // function (i.e. Inception). So we just recurse on the first input.
109     for (const Edge* edge : node->in_edges()) {
110       if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
111         FindType(graph, edge->src(), signed_input, range_given, input_min,
112                  input_max);
113       }
114     }
115   } else if (src_op == "Identity" || src_op == "MaxPool" ||
116              src_op == "AvgPool" || src_op == "MaxPool3D" ||
117              src_op == "AvgPool3D") {
118     // All these Ops only have 1 data input.
119     for (const Edge* edge : node->in_edges()) {
120       if (edge->src_output() != Graph::kControlSlot) {
121         FindType(graph, edge->src(), signed_input, range_given, input_min,
122                  input_max);
123       }
124     }
125   } else {
126     // Unknown type, could be the model input examples.
127     // TODO(jmchen): Set the params for input with user's hint.
128     *signed_input = true;
129     *range_given = false;
130     return false;
131   }
132 
133   return true;
134 }
135 
136 // Find the Save op and inputs.
FindSaveOp(const Graph * graph,Node ** save_op,std::vector<const Edge * > * in_edges,bool * found)137 Status FindSaveOp(const Graph* graph, Node** save_op,
138                   std::vector<const Edge*>* in_edges, bool* found) {
139   *found = false;
140   for (Node* node : graph->op_nodes()) {
141     if (node->type_string() == "SaveV2") {
142       // We found multiple save ops.
143       if (*found) {
144         return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
145       }
146       *save_op = node;
147       *found = true;
148       TF_RETURN_IF_ERROR(node->input_edges(in_edges));
149     }
150   }
151   return Status::OK();
152 }
153 
FindRestoreAllOp(const Graph * graph,StringPiece save_prefix)154 Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
155   for (Node* node : graph->op_nodes()) {
156     // The restore_all op should have the same prefix of the save_op.
157     if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
158       return node;
159     }
160   }
161   return nullptr;
162 }
163 
164 // Strips the last "/suffix" from a name.
165 // We use this to construct the name of restore ops in the same way they are
166 // constructed by the Saver.
GetNodeNamePrefix(const Node * node)167 StringPiece GetNodeNamePrefix(const Node* node) {
168   StringPiece name = node->name();
169   return name.substr(0, name.rfind('/'));
170 }
171 
FillStringTensor(Tensor * dst,const Tensor & src)172 void FillStringTensor(Tensor* dst, const Tensor& src) {
173   auto dst_flat = dst->flat<tstring>();
174   auto src_flat = src.flat<tstring>();
175   for (int i = 0; i < src.NumElements(); i++) {
176     dst_flat(i) = src_flat(i);
177   }
178 }
179 
180 // Add the added_variables as an inputs to the Save op.
181 // We change the inputs of the SaveV2 op to include the names of the added
182 // variables. We also add the variables as inputs to the save op.
ConnectVariablesToSaveOp(Graph * graph,Node * save_op,const std::vector<const Edge * > & in_edges,const std::vector<Node * > & added_variables)183 Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
184                                 const std::vector<const Edge*>& in_edges,
185                                 const std::vector<Node*>& added_variables) {
186   Node* tensor_names_op = in_edges[1]->src();
187   Node* shape_and_slices_op = in_edges[2]->src();
188 
189   // Get the tensor_names and shape_and_slices tensors from the const op.
190   Tensor tensor_names;
191   Tensor shape_and_slices;
192   TF_RETURN_IF_ERROR(
193       GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
194   TF_RETURN_IF_ERROR(
195       GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
196 
197   int tn_size = tensor_names.NumElements();
198   int var_size = added_variables.size();
199 
200   // Create a new save_op that has inputs to all the new variables.
201   NodeBuilder save_op_builder =
202       NodeBuilder(save_op->name(), save_op->type_string());
203   // The first three inputs are prefix, tensor_names, and shapes_and_slices.
204   for (int i = 0; i < 3; i++) {
205     save_op_builder = save_op_builder.Input(in_edges[i]->src());
206   }
207   std::vector<NodeBuilder::NodeOut> var_nodeouts;
208   var_nodeouts.reserve(tn_size + var_size);
209   // The rest of the inputs need to be used the construct the tensor list arg.
210   for (int i = 3; i < in_edges.size(); i++) {
211     var_nodeouts.emplace_back(in_edges[i]->src());
212   }
213 
214   // Add the new values to the tensors and the op input.
215   Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
216   Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
217   FillStringTensor(&new_tensor_names, tensor_names);
218   FillStringTensor(&new_shape_and_slices, shape_and_slices);
219   for (int i = 0; i < var_size; i++) {
220     Node* var = added_variables[i];
221     new_tensor_names.flat<tstring>()(tn_size + i) = var->name();
222     new_shape_and_slices.flat<tstring>()(tn_size + i) = "";
223     var_nodeouts.emplace_back(var);
224   }
225   save_op_builder = save_op_builder.Input(var_nodeouts);
226 
227   // Update the attrs.
228   tensor_names_op->AddAttr("value", new_tensor_names);
229   shape_and_slices_op->AddAttr("value", new_shape_and_slices);
230 
231   // Remove the old save_op and add the new one.
232   Node* new_save_op;
233   TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
234   // Add outputs to the new_save_op, all outputs are control edges.
235   for (const Edge* edge : save_op->out_edges()) {
236     graph->AddControlEdge(new_save_op, edge->dst());
237   }
238   graph->RemoveNode(save_op);
239 
240   return Status::OK();
241 }
242 
243 // Add a restore subgraph for each variable and connect to the restore_all op.
244 // For each variable we add the following subgraph:
245 //           Assign----restore_all
246 //          |      |
247 //   RestoreV2    Variable
AddRestoreVariableSubgraphs(Graph * graph,Node * save_op,const std::vector<const Edge * > & in_edges,const std::vector<Node * > & variables)248 Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
249                                    const std::vector<const Edge*>& in_edges,
250                                    const std::vector<Node*>& variables) {
251   Node* prefix_op = in_edges[0]->src();
252   StringPiece name_prefix = GetNodeNamePrefix(save_op);
253   Node* restore_all = FindRestoreAllOp(graph, name_prefix);
254   if (restore_all == nullptr) {
255     return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
256   }
257   const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
258   const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
259   for (Node* var : variables) {
260     // Add an extra prefix after calling graph->NewName because the "unique"
261     // name may conflict with names generated for Send nodes.
262     // TODO(b/77547936): fix this more generally and get rid of the extra prefix
263     // here.
264     string new_restore_op_name =
265         strings::StrCat(graph->NewName(restore_op_name), "_qt");
266     string new_assign_op_name =
267         strings::StrCat(graph->NewName(assign_op_name), "_qt");
268     string tensor_names_op_name =
269         strings::StrCat(new_restore_op_name, "/tensor_names");
270     string shape_and_slices_op_name =
271         strings::StrCat(new_restore_op_name, "/shape_and_slices");
272 
273     // Construct the tensor_names input with the variable name.
274     Node* tensor_names;
275     Tensor tensor_names_val(DT_STRING, TensorShape({1}));
276     tensor_names_val.flat<tstring>()(0) = var->name();
277     TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
278                            .Attr("dtype", DT_STRING)
279                            .Attr("value", tensor_names_val)
280                            .Finalize(graph, &tensor_names));
281 
282     // Construct the shape_and_slices input with empty string.
283     Node* shape_and_slices;
284     Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
285     shape_and_slices_val.flat<tstring>()(0) = "";
286     TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
287                            .Attr("dtype", DT_STRING)
288                            .Attr("value", shape_and_slices_val)
289                            .Finalize(graph, &shape_and_slices));
290 
291     // Build the new Restore op for this variable.
292     Node* restore_op;
293     TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
294                            .Input(prefix_op)
295                            .Input(tensor_names)
296                            .Input(shape_and_slices)
297                            .Attr("dtypes", {DT_FLOAT})
298                            .Finalize(graph, &restore_op));
299 
300     // Create Assign op, attaching the variable and Restore op to it.
301     Node* assign_op;
302     TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
303                            .Input(var)
304                            .Input(restore_op)
305                            .Finalize(graph, &assign_op));
306 
307     // Add a control edge from the assign op to restore_all op.
308     graph->AddControlEdge(assign_op, restore_all);
309   }
310   return Status::OK();
311 }
312 
313 // Adds new variables to save and restore ops matching the Save and Restore
314 // graphs created in tensorflow/python/training/saver.py.
AddSaveAndRestore(Graph * graph,const std::vector<Node * > & variables)315 Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
316   Node* save_op = nullptr;
317   std::vector<const Edge*> in_edges;
318   bool found = false;
319   TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
320   if (found) {
321     TF_RETURN_IF_ERROR(
322         AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
323     TF_RETURN_IF_ERROR(
324         ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
325   }
326   return Status::OK();
327 }
328 
329 // Sets output to the Node that computes reduction axes corresponding to all
330 // dimensions of input and return.
MakeReductionAxes(Graph * graph,string name_prefix,Node * input,Node ** output)331 Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
332                          Node** output) {
333   name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
334   Node* start;
335   Tensor zero_tensor(DT_INT32, TensorShape());
336   zero_tensor.flat<int32>()(0) = 0;
337   TF_RETURN_IF_ERROR(
338       NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
339           .Attr("dtype", DT_INT32)
340           .Attr("value", zero_tensor)
341           .Finalize(graph, &start));
342   Node* delta;
343   Tensor one_tensor(DT_INT32, TensorShape());
344   one_tensor.flat<int32>()(0) = 1;
345   TF_RETURN_IF_ERROR(
346       NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
347           .Attr("dtype", DT_INT32)
348           .Attr("value", one_tensor)
349           .Finalize(graph, &delta));
350   Node* rank;
351   TF_RETURN_IF_ERROR(
352       NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
353           .Input(input)
354           .Finalize(graph, &rank));
355   TF_RETURN_IF_ERROR(
356       NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
357           .Input(start)
358           .Input(rank)
359           .Input(delta)
360           .Finalize(graph, output));
361   return Status::OK();
362 }
363 
364 // Computes the exponential moving average of input, updated in update_variable.
MakeExponentialMovingAverage(Graph * graph,string name_prefix,const NodeBuilder::NodeOut & input,Node * decay,Node * update_variable,Node ** assign_value)365 Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
366                                     const NodeBuilder::NodeOut& input,
367                                     Node* decay, Node* update_variable,
368                                     Node** assign_value) {
369   // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
370   name_prefix = strings::StrCat(name_prefix, "/EMA");
371   Node* one;
372   Tensor one_tensor(DT_FLOAT, TensorShape());
373   one_tensor.flat<float>()(0) = 1.0;
374   TF_RETURN_IF_ERROR(
375       NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
376           .Attr("dtype", DT_FLOAT)
377           .Attr("value", one_tensor)
378           .Finalize(graph, &one));
379   Node* decay_complement;
380   TF_RETURN_IF_ERROR(
381       NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
382           .Input(one)
383           .Input(decay)
384           .Finalize(graph, &decay_complement));
385 
386   Node* value_diff;
387   TF_RETURN_IF_ERROR(
388       NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
389           .Input(update_variable)
390           .Input(input)
391           .Finalize(graph, &value_diff));
392   Node* update_value;
393   TF_RETURN_IF_ERROR(
394       NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
395           .Input(value_diff)
396           .Input(decay_complement)
397           .Finalize(graph, &update_value));
398 
399   TF_RETURN_IF_ERROR(
400       NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
401           .Input(update_variable)
402           .Input(update_value)
403           .Finalize(graph, assign_value));
404   return Status::OK();
405 }
406 
407 // Creates an automatically initialized exponential moving average variable.
408 // This uses a switch op to assign a value to the variable on the first run,
409 // and update with the moving average for all other runs:
410 //                   init_val
411 //                      |
412 //      var--is_init--switch
413 //       |      true /      \ false
414 //       |          |        |
415 //       |         EMA    init_val
416 //       |           \      /
417 //       +----------- assign
MakeInitializedEMAVariable(Graph * graph,const string & name,Node * decay,Node * init_val,std::vector<Node * > * added_variables,Node ** var)418 Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
419                                   Node* init_val,
420                                   std::vector<Node*>* added_variables,
421                                   Node** var) {
422   // TODO(suharshs): Update this to use ResourceVariables when they are ready.
423   TF_RETURN_IF_ERROR(
424       NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
425           .Attr("shape", TensorShape())
426           .Attr("dtype", DT_FLOAT)
427           .Finalize(graph, var));
428   added_variables->push_back(*var);
429 
430   Node* is_initialized;
431   TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
432                                  "IsVariableInitialized")
433                          .Input(*var)
434                          .Finalize(graph, &is_initialized));
435   Node* switch_node;
436   TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
437                          .Input(init_val)
438                          .Input(is_initialized)
439                          .Finalize(graph, &switch_node));
440   NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
441   NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);
442 
443   Node* ema_value;
444   TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
445                                                   decay, *var, &ema_value));
446 
447   Node* assign_value;
448   TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
449                          .Input({output_false, ema_value})
450                          .Finalize(graph, &assign_value));
451 
452   TF_RETURN_IF_ERROR(
453       NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
454           .Input(*var)
455           .Input(assign_value)
456           .Finalize(graph, var));
457   return Status::OK();
458 }
459 
460 // Computes the min and max EMA of input and stores them in min_var and max_var.
MakeEMAMinMaxVars(Graph * graph,const string & name_prefix,Node * input,std::vector<Node * > * added_variables,Node ** min_var,Node ** max_var)461 Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
462                          std::vector<Node*>* added_variables, Node** min_var,
463                          Node** max_var) {
464   // TODO(suharshs): The decay will be constant, so we could make only one for
465   // all quantize_and_dequantize ops to share, this would have to live outside
466   // this function.
467   Tensor decay_tensor(DT_FLOAT, TensorShape());
468   decay_tensor.flat<float>()(0) = kEMADecay;
469   Node* decay;
470   TF_RETURN_IF_ERROR(
471       NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
472           .Attr("dtype", DT_FLOAT)
473           .Attr("value", decay_tensor)
474           .Finalize(graph, &decay));
475 
476   Node* reduction_axes;
477   TF_RETURN_IF_ERROR(
478       MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
479   Node* min;
480   string min_name = strings::StrCat(name_prefix, "/Min");
481   TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
482                          .Input(input)
483                          .Input(reduction_axes)
484                          .Finalize(graph, &min));
485   Node* max;
486   string max_name = strings::StrCat(name_prefix, "/Max");
487   TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
488                          .Input(input)
489                          .Input(reduction_axes)
490                          .Finalize(graph, &max));
491   TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
492                                                 added_variables, min_var));
493   TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
494                                                 added_variables, max_var));
495   return Status::OK();
496 }
497 
498 // Makes an input min and max constant if the range is given. Otherwise, makes
499 // min and max variables that are updated by an EMA.
MakeInputMinMax(Graph * graph,const string & name_prefix,const EdgeToConvert & edge,std::vector<Node * > * added_variables,Node ** input_min,Node ** input_max)500 Status MakeInputMinMax(Graph* graph, const string& name_prefix,
501                        const EdgeToConvert& edge,
502                        std::vector<Node*>* added_variables, Node** input_min,
503                        Node** input_max) {
504   if (edge.range_given) {
505     // Make constant nodes for the input_min and input_max if the range is
506     // provided.
507     Tensor input_min_tensor(DT_FLOAT, TensorShape());
508     input_min_tensor.flat<float>()(0) = edge.input_min;
509     TF_RETURN_IF_ERROR(
510         NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
511             .Attr("dtype", DT_FLOAT)
512             .Attr("value", input_min_tensor)
513             .Finalize(graph, input_min));
514     Tensor input_max_tensor(DT_FLOAT, TensorShape());
515     input_max_tensor.flat<float>()(0) = edge.input_max;
516     TF_RETURN_IF_ERROR(
517         NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
518             .Attr("dtype", DT_FLOAT)
519             .Attr("value", input_max_tensor)
520             .Finalize(graph, input_max));
521   } else {
522     // If the range is not given, estimate the range with EMA variables.
523     TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
524                                          added_variables, input_min,
525                                          input_max));
526   }
527 
528   return Status::OK();
529 }
530 
531 // Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op
532 // (and required input nodes) based on edge.
533 // The result is stored in convert_node.
MakeQuantizeOp(Graph * graph,const string & name_prefix,const string & quant_op_type,const EdgeToConvert & edge,std::vector<Node * > * added_variables,Node ** convert_node)534 Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
535                       const string& quant_op_type, const EdgeToConvert& edge,
536                       std::vector<Node*>* added_variables,
537                       Node** convert_node) {
538   Node* input_min;
539   Node* input_max;
540   TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
541                                      &input_min, &input_max));
542   string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
543   if (quant_op_type == "QuantizeAndDequantizeV2") {
544     TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
545                            .Input(edge.edge->src())
546                            .Input(input_min)
547                            .Input(input_max)
548                            .Attr("signed_input", edge.signed_input)
549                            .Attr("num_bits", edge.num_bits)
550                            .Attr("range_given", true)
551                            .Finalize(graph, convert_node));
552   } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
553     TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
554                            .Input(edge.edge->src())
555                            .Input(input_min)
556                            .Input(input_max)
557                            .Attr("num_bits", edge.num_bits)
558                            .Finalize(graph, convert_node));
559   } else {
560     return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
561   }
562   return Status::OK();
563 }
564 
565 // Insert conversion op, connect it to the graph and remove the old edge.
ProcessTargetEdges(Graph * graph,const string & quant_op_type,const std::vector<EdgeToConvert> & target_edges)566 Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
567                           const std::vector<EdgeToConvert>& target_edges) {
568   // Remember previously converted ops to avoid duplicated conversion on the
569   // same input.
570   std::unordered_map<string, Node*, StringPieceHasher> name_index;
571   std::vector<Node*> added_variables;
572   for (const EdgeToConvert edge : target_edges) {
573     Node* convert_node;
574     string name_prefix = edge.edge->src()->name();
575 
576     auto iter = name_index.find(name_prefix);
577     if (iter == name_index.end()) {
578       TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
579                                         &added_variables, &convert_node));
580       name_index[name_prefix] = convert_node;
581     } else {
582       convert_node = iter->second;
583     }
584 
585     graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
586     graph->RemoveEdge(edge.edge);
587   }
588 
589   TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));
590 
591   return Status::OK();
592 }
593 
594 }  // namespace
595 
DoQuantizeTraining(int32 num_bits,const string & quant_op_type,Graph * graph)596 Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
597                           Graph* graph) {
598   if (graph == nullptr) {
599     return errors::InvalidArgument("Cannot accept empty graph pointer.");
600   }
601 
602   if (num_bits < 1 || num_bits > 63) {
603     return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
604                               num_bits);
605   }
606   int potential_input = 0;
607   std::vector<EdgeToConvert> target_edges;
608   for (Node* node : graph->nodes()) {
609     if (nodes_to_rewrite->find(node->type_string()) !=
610             nodes_to_rewrite->end() &&
611         !IsGradientNode(graph, node)) {
612       // Find out which types are the inputs and convert them accordingly.
613       // 1. Const/Variable OP: This is quantized as signed tensors with no given
614       // range.
615       // 2. Activation OP: Set the range accordingly for different types of
616       // activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}
617       // 3. Identity OP: The quantization parameters depend on its input.
618       // 4. Pooling OPs: various pooling ops. Also depends on its input.
619       // 5. Reshape OP: Also depends on the first input to this op.
620       // 6. Not-Listed-Above OP: If there is only 1 such op, consider it as the
621       // model input. However, if there are >1 unknown ops, then returns an
622       // error for now to avoid unexpected behavior.
623       // Note: The list above might not be a complete list. Please let us
624       // know if you see the error so we can handle your case.
625       for (const Edge* edge : node->in_edges()) {
626         if (edge->src_output() == Graph::kControlSlot) {
627           // Skip the control dependency input.
628           continue;
629         } else {
630           bool signed_input = false;
631           bool range_given = false;
632           float input_min = 0;
633           float input_max = 0;
634           bool known_op = FindType(graph, edge->src(), &signed_input,
635                                    &range_given, &input_min, &input_max);
636           if (!known_op) {
637             // Unknown op is considered as input.
638             potential_input++;
639             if (potential_input > kAllowedInputs) {
640               return errors::Unimplemented(
641                   "Found an unknown op: ", edge->src()->name(),
642                   " with type: ", edge->src()->type_string(),
643                   "; Unknown ops are considered as model input for now and "
644                   "only ",
645                   kAllowedInputs, " inputs are supported currently.");
646             }
647           }
648 
649           target_edges.emplace_back(EdgeToConvert(
650               edge, num_bits, signed_input, range_given, input_min, input_max));
651         }
652       }
653     }
654   }
655 
656   TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));
657 
658   return Status::OK();
659 }
660 
DoQuantizeTrainingOnGraphDef(const GraphDef & input_graphdef,int32 num_bits,const string & quant_op_type,GraphDef * result_graphdef)661 Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
662                                     int32 num_bits, const string& quant_op_type,
663                                     GraphDef* result_graphdef) {
664   Graph graph(OpRegistry::Global());
665   GraphConstructorOptions opts;
666   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));
667 
668   // Call the rewriter on the graph.
669   TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));
670 
671   // Convert the result graph back to a GraphDef.
672   graph.ToGraphDef(result_graphdef);
673   return Status::OK();
674 }
675 
DoQuantizeTrainingOnSerializedGraphDef(const string & input_graph_string,int32 num_bits,const string & quant_op_type,string * result_graph_string)676 Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
677                                               int32 num_bits,
678                                               const string& quant_op_type,
679                                               string* result_graph_string) {
680   // First create the graph from the GraphDef.
681   GraphDef input_graphdef;
682   if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
683     return errors::InvalidArgument(
684         "input_graph_string is not a serialized GraphDef protocol buffer");
685   }
686   GraphDef output_graphdef;
687   TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
688       input_graphdef, num_bits, quant_op_type, &output_graphdef));
689 
690   if (!output_graphdef.SerializeToString(result_graph_string)) {
691     return errors::Internal(
692         "quantize training transformation resulted in invalid GraphDef");
693   }
694   return Status::OK();
695 }
696 
697 }  // namespace tensorflow
698