• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <algorithm>
17 #include <atomic>
18 #include <set>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "tensorflow/core/graph/quantize_training.h"
23 
24 #include "tensorflow/core/common_runtime/executor.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/common_runtime/memory_types.h"
27 #include "tensorflow/core/framework/log_memory.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/graph/algorithm.h"
30 #include "tensorflow/core/graph/graph_constructor.h"
31 #include "tensorflow/core/graph/node_builder.h"
32 #include "tensorflow/core/graph/subgraph.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/public/session_options.h"
35 
36 namespace tensorflow {
37 namespace {
38 
39 // TODO(suharshs): If desired, make these values configurable.
40 const uint32 kAllowedInputs = 2;
41 const float kEMADecay = 0.999;
42 
43 // Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
44 const auto* nodes_to_rewrite =
45     new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};
46 
47 // Contains necessary parameters to convert an edge.
48 struct EdgeToConvert {
49   // edge is not owned here.
50   const Edge* edge;
51   int32 num_bits;
52   bool signed_input;
53   bool range_given;
54   float input_min;
55   float input_max;
56 
EdgeToConverttensorflow::__anon414e2d190111::EdgeToConvert57   EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
58                 float max)
59       : edge(e),
60         num_bits(bits),
61         signed_input(sign),
62         range_given(range),
63         input_min(min),
64         input_max(max) {}
65 };
66 
67 // Decide if a node is in backward pass by checking if its name is led by
68 // "gradients".
69 // TODO(jmchen): Make this check more robust as it is not guaranteed that the
70 // forward node will not be named with a leading "gradients".
IsGradientNode(const Graph * graph,const Node * node)71 inline bool IsGradientNode(const Graph* graph, const Node* node) {
72   static const string tag = "gradients";
73   return (node->name().compare(0, tag.size(), tag) == 0);
74 }
75 
76 // Find the type of the input to set the parameters for the
77 // quantize_and_dequantize op.
78 // Returns true if the root tensor op type is known, false otherwise.
FindType(const Graph * graph,const Node * node,bool * signed_input,bool * range_given,float * input_min,float * input_max)79 bool FindType(const Graph* graph, const Node* node, bool* signed_input,
80               bool* range_given, float* input_min, float* input_max) {
81   const string& src_op = node->type_string();
82   if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
83     *signed_input = true;
84     *range_given = false;
85   } else if (src_op == "Relu") {
86     // Range is not given for Relu.
87     *signed_input = false;
88     *range_given = false;
89   } else if (src_op == "Relu6") {
90     // TODO(suharshs): Also the theoretical min and max is 0 and 6, if the
91     // actual activations are somewhere in within this range, we can quantize
92     // this even further. This is true for other activations like Sigmoid6 too.
93     *signed_input = false;
94     *range_given = true;
95     *input_min = 0;
96     *input_max = 6;
97   } else if (src_op == "Sigmoid") {
98     *signed_input = false;
99     *range_given = true;
100     *input_min = 0;
101     *input_max = 1;
102   } else if (src_op == "Tanh") {
103     *signed_input = true;
104     *range_given = true;
105     *input_min = -1;
106     *input_max = 1;
107   } else if (src_op == "Reshape" || src_op == "ConcatV2") {
108     // Reshape has 2 inputs and the first one is the tensor.
109     // ConcatV2 has many inputs but they should all have the same activation
110     // function (i.e. Inception). So we just recurse on the first input.
111     for (const Edge* edge : node->in_edges()) {
112       if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
113         FindType(graph, edge->src(), signed_input, range_given, input_min,
114                  input_max);
115       }
116     }
117   } else if (src_op == "Identity" || src_op == "MaxPool" ||
118              src_op == "AvgPool" || src_op == "MaxPool3D" ||
119              src_op == "AvgPool3D") {
120     // All these Ops only have 1 data input.
121     for (const Edge* edge : node->in_edges()) {
122       if (edge->src_output() != Graph::kControlSlot) {
123         FindType(graph, edge->src(), signed_input, range_given, input_min,
124                  input_max);
125       }
126     }
127   } else {
128     // Unknown type, could be the model input examples.
129     // TODO(jmchen): Set the params for input with user's hint.
130     *signed_input = true;
131     *range_given = false;
132     return false;
133   }
134 
135   return true;
136 }
137 
138 // Find the Save op and inputs.
FindSaveOp(const Graph * graph,Node ** save_op,std::vector<const Edge * > * in_edges,bool * found)139 Status FindSaveOp(const Graph* graph, Node** save_op,
140                   std::vector<const Edge*>* in_edges, bool* found) {
141   *found = false;
142   for (Node* node : graph->op_nodes()) {
143     if (node->type_string() == "SaveV2") {
144       // We found multiple save ops.
145       if (*found) {
146         return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
147       }
148       *save_op = node;
149       *found = true;
150       TF_RETURN_IF_ERROR(node->input_edges(in_edges));
151     }
152   }
153   return Status::OK();
154 }
155 
FindRestoreAllOp(const Graph * graph,StringPiece save_prefix)156 Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
157   for (Node* node : graph->op_nodes()) {
158     // The restore_all op should have the same prefix of the save_op.
159     if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
160       return node;
161     }
162   }
163   return nullptr;
164 }
165 
166 // Strips the last "/suffix" from a name.
167 // We use this to construct the name of restore ops in the same way they are
168 // constructed by the Saver.
GetNodeNamePrefix(const Node * node)169 StringPiece GetNodeNamePrefix(const Node* node) {
170   StringPiece name = node->name();
171   return name.substr(0, name.rfind('/'));
172 }
173 
FillStringTensor(Tensor * dst,const Tensor & src)174 void FillStringTensor(Tensor* dst, const Tensor& src) {
175   auto dst_flat = dst->flat<tstring>();
176   auto src_flat = src.flat<tstring>();
177   for (int i = 0; i < src.NumElements(); i++) {
178     dst_flat(i) = src_flat(i);
179   }
180 }
181 
182 // Add the added_variables as an inputs to the Save op.
183 // We change the inputs of the SaveV2 op to include the names of the added
184 // variables. We also add the variables as inputs to the save op.
ConnectVariablesToSaveOp(Graph * graph,Node * save_op,const std::vector<const Edge * > & in_edges,const std::vector<Node * > & added_variables)185 Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
186                                 const std::vector<const Edge*>& in_edges,
187                                 const std::vector<Node*>& added_variables) {
188   Node* tensor_names_op = in_edges[1]->src();
189   Node* shape_and_slices_op = in_edges[2]->src();
190 
191   // Get the tensor_names and shape_and_slices tensors from the const op.
192   Tensor tensor_names;
193   Tensor shape_and_slices;
194   TF_RETURN_IF_ERROR(
195       GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
196   TF_RETURN_IF_ERROR(
197       GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
198 
199   int tn_size = tensor_names.NumElements();
200   int var_size = added_variables.size();
201 
202   // Create a new save_op that has inputs to all the new variables.
203   NodeBuilder save_op_builder =
204       NodeBuilder(save_op->name(), save_op->type_string());
205   // The first three inputs are prefix, tensor_names, and shapes_and_slices.
206   for (int i = 0; i < 3; i++) {
207     save_op_builder = save_op_builder.Input(in_edges[i]->src());
208   }
209   std::vector<NodeBuilder::NodeOut> var_nodeouts;
210   var_nodeouts.reserve(tn_size + var_size);
211   // The rest of the inputs need to be used the construct the tensor list arg.
212   for (int i = 3; i < in_edges.size(); i++) {
213     var_nodeouts.emplace_back(in_edges[i]->src());
214   }
215 
216   // Add the new values to the tensors and the op input.
217   Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
218   Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
219   FillStringTensor(&new_tensor_names, tensor_names);
220   FillStringTensor(&new_shape_and_slices, shape_and_slices);
221   for (int i = 0; i < var_size; i++) {
222     Node* var = added_variables[i];
223     new_tensor_names.flat<tstring>()(tn_size + i) = var->name();
224     new_shape_and_slices.flat<tstring>()(tn_size + i) = "";
225     var_nodeouts.emplace_back(var);
226   }
227   save_op_builder = save_op_builder.Input(var_nodeouts);
228 
229   // Update the attrs.
230   tensor_names_op->AddAttr("value", new_tensor_names);
231   shape_and_slices_op->AddAttr("value", new_shape_and_slices);
232 
233   // Remove the old save_op and add the new one.
234   Node* new_save_op;
235   TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
236   // Add outputs to the new_save_op, all outputs are control edges.
237   for (const Edge* edge : save_op->out_edges()) {
238     graph->AddControlEdge(new_save_op, edge->dst());
239   }
240   graph->RemoveNode(save_op);
241 
242   return Status::OK();
243 }
244 
245 // Add a restore subgraph for each variable and connect to the restore_all op.
246 // For each variable we add the following subgraph:
247 //           Assign----restore_all
248 //          |      |
249 //   RestoreV2    Variable
AddRestoreVariableSubgraphs(Graph * graph,Node * save_op,const std::vector<const Edge * > & in_edges,const std::vector<Node * > & variables)250 Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
251                                    const std::vector<const Edge*>& in_edges,
252                                    const std::vector<Node*>& variables) {
253   Node* prefix_op = in_edges[0]->src();
254   StringPiece name_prefix = GetNodeNamePrefix(save_op);
255   Node* restore_all = FindRestoreAllOp(graph, name_prefix);
256   if (restore_all == nullptr) {
257     return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
258   }
259   const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
260   const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
261   for (Node* var : variables) {
262     // Add an extra prefix after calling graph->NewName because the "unique"
263     // name may conflict with names generated for Send nodes.
264     // TODO(b/77547936): fix this more generally and get rid of the extra prefix
265     // here.
266     string new_restore_op_name =
267         strings::StrCat(graph->NewName(restore_op_name), "_qt");
268     string new_assign_op_name =
269         strings::StrCat(graph->NewName(assign_op_name), "_qt");
270     string tensor_names_op_name =
271         strings::StrCat(new_restore_op_name, "/tensor_names");
272     string shape_and_slices_op_name =
273         strings::StrCat(new_restore_op_name, "/shape_and_slices");
274 
275     // Construct the tensor_names input with the variable name.
276     Node* tensor_names;
277     Tensor tensor_names_val(DT_STRING, TensorShape({1}));
278     tensor_names_val.flat<tstring>()(0) = var->name();
279     TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
280                            .Attr("dtype", DT_STRING)
281                            .Attr("value", tensor_names_val)
282                            .Finalize(graph, &tensor_names));
283 
284     // Construct the shape_and_slices input with empty string.
285     Node* shape_and_slices;
286     Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
287     shape_and_slices_val.flat<tstring>()(0) = "";
288     TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
289                            .Attr("dtype", DT_STRING)
290                            .Attr("value", shape_and_slices_val)
291                            .Finalize(graph, &shape_and_slices));
292 
293     // Build the new Restore op for this variable.
294     Node* restore_op;
295     TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
296                            .Input(prefix_op)
297                            .Input(tensor_names)
298                            .Input(shape_and_slices)
299                            .Attr("dtypes", {DT_FLOAT})
300                            .Finalize(graph, &restore_op));
301 
302     // Create Assign op, attaching the variable and Restore op to it.
303     Node* assign_op;
304     TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
305                            .Input(var)
306                            .Input(restore_op)
307                            .Finalize(graph, &assign_op));
308 
309     // Add a control edge from the assign op to restore_all op.
310     graph->AddControlEdge(assign_op, restore_all);
311   }
312   return Status::OK();
313 }
314 
315 // Adds new variables to save and restore ops matching the Save and Restore
316 // graphs created in tensorflow/python/training/saver.py.
AddSaveAndRestore(Graph * graph,const std::vector<Node * > & variables)317 Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
318   Node* save_op = nullptr;
319   std::vector<const Edge*> in_edges;
320   bool found = false;
321   TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
322   if (found) {
323     TF_RETURN_IF_ERROR(
324         AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
325     TF_RETURN_IF_ERROR(
326         ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
327   }
328   return Status::OK();
329 }
330 
331 // Sets output to the Node that computes reduction axes corresponding to all
332 // dimensions of input and return.
MakeReductionAxes(Graph * graph,string name_prefix,Node * input,Node ** output)333 Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
334                          Node** output) {
335   name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
336   Node* start;
337   Tensor zero_tensor(DT_INT32, TensorShape());
338   zero_tensor.flat<int32>()(0) = 0;
339   TF_RETURN_IF_ERROR(
340       NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
341           .Attr("dtype", DT_INT32)
342           .Attr("value", zero_tensor)
343           .Finalize(graph, &start));
344   Node* delta;
345   Tensor one_tensor(DT_INT32, TensorShape());
346   one_tensor.flat<int32>()(0) = 1;
347   TF_RETURN_IF_ERROR(
348       NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
349           .Attr("dtype", DT_INT32)
350           .Attr("value", one_tensor)
351           .Finalize(graph, &delta));
352   Node* rank;
353   TF_RETURN_IF_ERROR(
354       NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
355           .Input(input)
356           .Finalize(graph, &rank));
357   TF_RETURN_IF_ERROR(
358       NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
359           .Input(start)
360           .Input(rank)
361           .Input(delta)
362           .Finalize(graph, output));
363   return Status::OK();
364 }
365 
366 // Computes the exponential moving average of input, updated in update_variable.
MakeExponentialMovingAverage(Graph * graph,string name_prefix,const NodeBuilder::NodeOut & input,Node * decay,Node * update_variable,Node ** assign_value)367 Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
368                                     const NodeBuilder::NodeOut& input,
369                                     Node* decay, Node* update_variable,
370                                     Node** assign_value) {
371   // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
372   name_prefix = strings::StrCat(name_prefix, "/EMA");
373   Node* one;
374   Tensor one_tensor(DT_FLOAT, TensorShape());
375   one_tensor.flat<float>()(0) = 1.0;
376   TF_RETURN_IF_ERROR(
377       NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
378           .Attr("dtype", DT_FLOAT)
379           .Attr("value", one_tensor)
380           .Finalize(graph, &one));
381   Node* decay_complement;
382   TF_RETURN_IF_ERROR(
383       NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
384           .Input(one)
385           .Input(decay)
386           .Finalize(graph, &decay_complement));
387 
388   Node* value_diff;
389   TF_RETURN_IF_ERROR(
390       NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
391           .Input(update_variable)
392           .Input(input)
393           .Finalize(graph, &value_diff));
394   Node* update_value;
395   TF_RETURN_IF_ERROR(
396       NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
397           .Input(value_diff)
398           .Input(decay_complement)
399           .Finalize(graph, &update_value));
400 
401   TF_RETURN_IF_ERROR(
402       NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
403           .Input(update_variable)
404           .Input(update_value)
405           .Finalize(graph, assign_value));
406   return Status::OK();
407 }
408 
409 // Creates an automatically initialized exponential moving average variable.
410 // This uses a switch op to assign a value to the variable on the first run,
411 // and update with the moving average for all other runs:
412 //                   init_val
413 //                      |
414 //      var--is_init--switch
415 //       |      true /      \ false
416 //       |          |        |
417 //       |         EMA    init_val
418 //       |           \      /
419 //       +----------- assign
MakeInitializedEMAVariable(Graph * graph,const string & name,Node * decay,Node * init_val,std::vector<Node * > * added_variables,Node ** var)420 Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
421                                   Node* init_val,
422                                   std::vector<Node*>* added_variables,
423                                   Node** var) {
424   // TODO(suharshs): Update this to use ResourceVariables when they are ready.
425   TF_RETURN_IF_ERROR(
426       NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
427           .Attr("shape", TensorShape())
428           .Attr("dtype", DT_FLOAT)
429           .Finalize(graph, var));
430   added_variables->push_back(*var);
431 
432   Node* is_initialized;
433   TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
434                                  "IsVariableInitialized")
435                          .Input(*var)
436                          .Finalize(graph, &is_initialized));
437   Node* switch_node;
438   TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
439                          .Input(init_val)
440                          .Input(is_initialized)
441                          .Finalize(graph, &switch_node));
442   NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
443   NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);
444 
445   Node* ema_value;
446   TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
447                                                   decay, *var, &ema_value));
448 
449   Node* assign_value;
450   TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
451                          .Input({output_false, ema_value})
452                          .Finalize(graph, &assign_value));
453 
454   TF_RETURN_IF_ERROR(
455       NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
456           .Input(*var)
457           .Input(assign_value)
458           .Finalize(graph, var));
459   return Status::OK();
460 }
461 
462 // Computes the min and max EMA of input and stores them in min_var and max_var.
MakeEMAMinMaxVars(Graph * graph,const string & name_prefix,Node * input,std::vector<Node * > * added_variables,Node ** min_var,Node ** max_var)463 Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
464                          std::vector<Node*>* added_variables, Node** min_var,
465                          Node** max_var) {
466   // TODO(suharshs): The decay will be constant, so we could make only one for
467   // all quantize_and_dequantize ops to share, this would have to live outside
468   // this function.
469   Tensor decay_tensor(DT_FLOAT, TensorShape());
470   decay_tensor.flat<float>()(0) = kEMADecay;
471   Node* decay;
472   TF_RETURN_IF_ERROR(
473       NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
474           .Attr("dtype", DT_FLOAT)
475           .Attr("value", decay_tensor)
476           .Finalize(graph, &decay));
477 
478   Node* reduction_axes;
479   TF_RETURN_IF_ERROR(
480       MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
481   Node* min;
482   string min_name = strings::StrCat(name_prefix, "/Min");
483   TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
484                          .Input(input)
485                          .Input(reduction_axes)
486                          .Finalize(graph, &min));
487   Node* max;
488   string max_name = strings::StrCat(name_prefix, "/Max");
489   TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
490                          .Input(input)
491                          .Input(reduction_axes)
492                          .Finalize(graph, &max));
493   TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
494                                                 added_variables, min_var));
495   TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
496                                                 added_variables, max_var));
497   return Status::OK();
498 }
499 
500 // Makes an input min and max constant if the range is given. Otherwise, makes
501 // min and max variables that are updated by an EMA.
MakeInputMinMax(Graph * graph,const string & name_prefix,const EdgeToConvert & edge,std::vector<Node * > * added_variables,Node ** input_min,Node ** input_max)502 Status MakeInputMinMax(Graph* graph, const string& name_prefix,
503                        const EdgeToConvert& edge,
504                        std::vector<Node*>* added_variables, Node** input_min,
505                        Node** input_max) {
506   if (edge.range_given) {
507     // Make constant nodes for the input_min and input_max if the range is
508     // provided.
509     Tensor input_min_tensor(DT_FLOAT, TensorShape());
510     input_min_tensor.flat<float>()(0) = edge.input_min;
511     TF_RETURN_IF_ERROR(
512         NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
513             .Attr("dtype", DT_FLOAT)
514             .Attr("value", input_min_tensor)
515             .Finalize(graph, input_min));
516     Tensor input_max_tensor(DT_FLOAT, TensorShape());
517     input_max_tensor.flat<float>()(0) = edge.input_max;
518     TF_RETURN_IF_ERROR(
519         NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
520             .Attr("dtype", DT_FLOAT)
521             .Attr("value", input_max_tensor)
522             .Finalize(graph, input_max));
523   } else {
524     // If the range is not given, estimate the range with EMA variables.
525     TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
526                                          added_variables, input_min,
527                                          input_max));
528   }
529 
530   return Status::OK();
531 }
532 
533 // Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op
534 // (and required input nodes) based on edge.
535 // The result is stored in convert_node.
MakeQuantizeOp(Graph * graph,const string & name_prefix,const string & quant_op_type,const EdgeToConvert & edge,std::vector<Node * > * added_variables,Node ** convert_node)536 Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
537                       const string& quant_op_type, const EdgeToConvert& edge,
538                       std::vector<Node*>* added_variables,
539                       Node** convert_node) {
540   Node* input_min;
541   Node* input_max;
542   TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
543                                      &input_min, &input_max));
544   string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
545   if (quant_op_type == "QuantizeAndDequantizeV2") {
546     TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
547                            .Input(edge.edge->src())
548                            .Input(input_min)
549                            .Input(input_max)
550                            .Attr("signed_input", edge.signed_input)
551                            .Attr("num_bits", edge.num_bits)
552                            .Attr("range_given", true)
553                            .Finalize(graph, convert_node));
554   } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
555     TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
556                            .Input(edge.edge->src())
557                            .Input(input_min)
558                            .Input(input_max)
559                            .Attr("num_bits", edge.num_bits)
560                            .Finalize(graph, convert_node));
561   } else {
562     return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
563   }
564   return Status::OK();
565 }
566 
567 // Insert conversion op, connect it to the graph and remove the old edge.
ProcessTargetEdges(Graph * graph,const string & quant_op_type,const std::vector<EdgeToConvert> & target_edges)568 Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
569                           const std::vector<EdgeToConvert>& target_edges) {
570   // Remember previously converted ops to avoid duplicated conversion on the
571   // same input.
572   std::unordered_map<string, Node*, StringPieceHasher> name_index;
573   std::vector<Node*> added_variables;
574   for (const EdgeToConvert edge : target_edges) {
575     Node* convert_node;
576     string name_prefix = edge.edge->src()->name();
577 
578     auto iter = name_index.find(name_prefix);
579     if (iter == name_index.end()) {
580       TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
581                                         &added_variables, &convert_node));
582       name_index[name_prefix] = convert_node;
583     } else {
584       convert_node = iter->second;
585     }
586 
587     graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
588     graph->RemoveEdge(edge.edge);
589   }
590 
591   TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));
592 
593   return Status::OK();
594 }
595 
596 }  // namespace
597 
DoQuantizeTraining(int32 num_bits,const string & quant_op_type,Graph * graph)598 Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
599                           Graph* graph) {
600   if (graph == nullptr) {
601     return errors::InvalidArgument("Cannot accept empty graph pointer.");
602   }
603 
604   if (num_bits < 1 || num_bits > 63) {
605     return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
606                               num_bits);
607   }
608   int potential_input = 0;
609   std::vector<EdgeToConvert> target_edges;
610   for (Node* node : graph->nodes()) {
611     if (nodes_to_rewrite->find(node->type_string()) !=
612             nodes_to_rewrite->end() &&
613         !IsGradientNode(graph, node)) {
614       // Find out which types are the inputs and convert them accordingly.
615       // 1. Const/Variable OP: This is quantized as signed tensors with no given
616       // range.
617       // 2. Activation OP: Set the range accordingly for different types of
618       // activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}
619       // 3. Identity OP: The quantization parameters depend on its input.
620       // 4. Pooling OPs: various pooling ops. Also depends on its input.
621       // 5. Reshape OP: Also depends on the first input to this op.
622       // 6. Not-Listed-Above OP: If there is only 1 such op, consider it as the
623       // model input. However, if there are >1 unknown ops, then returns an
624       // error for now to avoid unexpected behavior.
625       // Note: The list above might not be a complete list. Please let us
626       // know if you see the error so we can handle your case.
627       for (const Edge* edge : node->in_edges()) {
628         if (edge->src_output() == Graph::kControlSlot) {
629           // Skip the control dependency input.
630           continue;
631         } else {
632           bool signed_input = false;
633           bool range_given = false;
634           float input_min = 0;
635           float input_max = 0;
636           bool known_op = FindType(graph, edge->src(), &signed_input,
637                                    &range_given, &input_min, &input_max);
638           if (!known_op) {
639             // Unknown op is considered as input.
640             potential_input++;
641             if (potential_input > kAllowedInputs) {
642               return errors::Unimplemented(
643                   "Found an unknown op: ", edge->src()->name(),
644                   " with type: ", edge->src()->type_string(),
645                   "; Unknown ops are considered as model input for now and "
646                   "only ",
647                   kAllowedInputs, " inputs are supported currently.");
648             }
649           }
650 
651           target_edges.emplace_back(EdgeToConvert(
652               edge, num_bits, signed_input, range_given, input_min, input_max));
653         }
654       }
655     }
656   }
657 
658   TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));
659 
660   return Status::OK();
661 }
662 
DoQuantizeTrainingOnGraphDef(const GraphDef & input_graphdef,int32 num_bits,const string & quant_op_type,GraphDef * result_graphdef)663 Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
664                                     int32 num_bits, const string& quant_op_type,
665                                     GraphDef* result_graphdef) {
666   Graph graph(OpRegistry::Global());
667   GraphConstructorOptions opts;
668   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));
669 
670   // Call the rewriter on the graph.
671   TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));
672 
673   // Convert the result graph back to a GraphDef.
674   graph.ToGraphDef(result_graphdef);
675   return Status::OK();
676 }
677 
DoQuantizeTrainingOnSerializedGraphDef(const string & input_graph_string,int32 num_bits,const string & quant_op_type,string * result_graph_string)678 Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
679                                               int32 num_bits,
680                                               const string& quant_op_type,
681                                               string* result_graph_string) {
682   // First create the graph from the GraphDef.
683   GraphDef input_graphdef;
684   if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
685     return errors::InvalidArgument(
686         "input_graph_string is not a serialized GraphDef protocol buffer");
687   }
688   GraphDef output_graphdef;
689   TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
690       input_graphdef, num_bits, quant_op_type, &output_graphdef));
691 
692   if (!output_graphdef.SerializeToString(result_graph_string)) {
693     return errors::Internal(
694         "quantize training transformation resulted in invalid GraphDef");
695   }
696   return Status::OK();
697 }
698 
699 }  // namespace tensorflow
700