1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <algorithm>
#include <atomic>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
21
22 #include "tensorflow/core/graph/quantize_training.h"
23
24 #include "tensorflow/core/common_runtime/executor.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/common_runtime/memory_types.h"
27 #include "tensorflow/core/framework/log_memory.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/graph/algorithm.h"
30 #include "tensorflow/core/graph/graph_constructor.h"
31 #include "tensorflow/core/graph/node_builder.h"
32 #include "tensorflow/core/graph/subgraph.h"
33 #include "tensorflow/core/lib/strings/strcat.h"
34 #include "tensorflow/core/public/session_options.h"
35
36 namespace tensorflow {
37 namespace {
38
39 // TODO(suharshs): If desired, make these values configurable.
40 const uint32 kAllowedInputs = 2;
41 const float kEMADecay = 0.999;
42
43 // Node types to rewrite. Insert quantize_and_dequantize op for their inputs.
44 const auto* nodes_to_rewrite =
45 new std::unordered_set<string, StringPieceHasher>{"MatMul", "Conv2D"};
46
47 // Contains necessary parameters to convert an edge.
48 struct EdgeToConvert {
49 // edge is not owned here.
50 const Edge* edge;
51 int32 num_bits;
52 bool signed_input;
53 bool range_given;
54 float input_min;
55 float input_max;
56
EdgeToConverttensorflow::__anonf797c4570111::EdgeToConvert57 EdgeToConvert(const Edge* e, int32 bits, bool sign, bool range, float min,
58 float max)
59 : edge(e),
60 num_bits(bits),
61 signed_input(sign),
62 range_given(range),
63 input_min(min),
64 input_max(max) {}
65 };
66
67 // Decide if a node is in backward pass by checking if its name is led by
68 // "gradients".
69 // TODO(jmchen): Make this check more robust as it is not guaranteed that the
70 // forward node will not be named with a leading "gradients".
IsGradientNode(const Graph * graph,const Node * node)71 inline bool IsGradientNode(const Graph* graph, const Node* node) {
72 static const string tag = "gradients";
73 return (node->name().compare(0, tag.size(), tag) == 0);
74 }
75
76 // Find the type of the input to set the parameters for the
77 // quantize_and_dequantize op.
78 // Returns true if the root tensor op type is known, false otherwise.
FindType(const Graph * graph,const Node * node,bool * signed_input,bool * range_given,float * input_min,float * input_max)79 bool FindType(const Graph* graph, const Node* node, bool* signed_input,
80 bool* range_given, float* input_min, float* input_max) {
81 const string& src_op = node->type_string();
82 if (src_op == "Const" || src_op == "Variable" || src_op == "VariableV2") {
83 *signed_input = true;
84 *range_given = false;
85 } else if (src_op == "Relu") {
86 // Range is not given for Relu.
87 *signed_input = false;
88 *range_given = false;
89 } else if (src_op == "Relu6") {
90 // TODO(suharshs): Also the theoretical min and max is 0 and 6, if the
91 // actual activations are somewhere in within this range, we can quantize
92 // this even further. This is true for other activations like Sigmoid6 too.
93 *signed_input = false;
94 *range_given = true;
95 *input_min = 0;
96 *input_max = 6;
97 } else if (src_op == "Sigmoid") {
98 *signed_input = false;
99 *range_given = true;
100 *input_min = 0;
101 *input_max = 1;
102 } else if (src_op == "Tanh") {
103 *signed_input = true;
104 *range_given = true;
105 *input_min = -1;
106 *input_max = 1;
107 } else if (src_op == "Reshape" || src_op == "ConcatV2") {
108 // Reshape has 2 inputs and the first one is the tensor.
109 // ConcatV2 has many inputs but they should all have the same activation
110 // function (i.e. Inception). So we just recurse on the first input.
111 for (const Edge* edge : node->in_edges()) {
112 if (edge->src_output() != Graph::kControlSlot && edge->dst_input() == 0) {
113 FindType(graph, edge->src(), signed_input, range_given, input_min,
114 input_max);
115 }
116 }
117 } else if (src_op == "Identity" || src_op == "MaxPool" ||
118 src_op == "AvgPool" || src_op == "MaxPool3D" ||
119 src_op == "AvgPool3D") {
120 // All these Ops only have 1 data input.
121 for (const Edge* edge : node->in_edges()) {
122 if (edge->src_output() != Graph::kControlSlot) {
123 FindType(graph, edge->src(), signed_input, range_given, input_min,
124 input_max);
125 }
126 }
127 } else {
128 // Unknown type, could be the model input examples.
129 // TODO(jmchen): Set the params for input with user's hint.
130 *signed_input = true;
131 *range_given = false;
132 return false;
133 }
134
135 return true;
136 }
137
138 // Find the Save op and inputs.
FindSaveOp(const Graph * graph,Node ** save_op,std::vector<const Edge * > * in_edges,bool * found)139 Status FindSaveOp(const Graph* graph, Node** save_op,
140 std::vector<const Edge*>* in_edges, bool* found) {
141 *found = false;
142 for (Node* node : graph->op_nodes()) {
143 if (node->type_string() == "SaveV2") {
144 // We found multiple save ops.
145 if (*found) {
146 return errors::InvalidArgument("Input graph has multiple SaveV2 ops.");
147 }
148 *save_op = node;
149 *found = true;
150 TF_RETURN_IF_ERROR(node->input_edges(in_edges));
151 }
152 }
153 return Status::OK();
154 }
155
FindRestoreAllOp(const Graph * graph,StringPiece save_prefix)156 Node* FindRestoreAllOp(const Graph* graph, StringPiece save_prefix) {
157 for (Node* node : graph->op_nodes()) {
158 // The restore_all op should have the same prefix of the save_op.
159 if (node->name() == strings::StrCat(save_prefix, "/restore_all")) {
160 return node;
161 }
162 }
163 return nullptr;
164 }
165
166 // Strips the last "/suffix" from a name.
167 // We use this to construct the name of restore ops in the same way they are
168 // constructed by the Saver.
GetNodeNamePrefix(const Node * node)169 StringPiece GetNodeNamePrefix(const Node* node) {
170 StringPiece name = node->name();
171 return name.substr(0, name.rfind('/'));
172 }
173
FillStringTensor(Tensor * dst,const Tensor & src)174 void FillStringTensor(Tensor* dst, const Tensor& src) {
175 auto dst_flat = dst->flat<string>();
176 auto src_flat = src.flat<string>();
177 for (int i = 0; i < src.NumElements(); i++) {
178 dst_flat(i) = src_flat(i);
179 }
180 }
181
182 // Add the added_variables as an inputs to the Save op.
183 // We change the inputs of the SaveV2 op to include the names of the added
184 // variables. We also add the variables as inputs to the save op.
ConnectVariablesToSaveOp(Graph * graph,Node * save_op,const std::vector<const Edge * > & in_edges,const std::vector<Node * > & added_variables)185 Status ConnectVariablesToSaveOp(Graph* graph, Node* save_op,
186 const std::vector<const Edge*>& in_edges,
187 const std::vector<Node*>& added_variables) {
188 Node* tensor_names_op = in_edges[1]->src();
189 Node* shape_and_slices_op = in_edges[2]->src();
190
191 // Get the tensor_names and shape_and_slices tensors from the const op.
192 Tensor tensor_names;
193 Tensor shape_and_slices;
194 TF_RETURN_IF_ERROR(
195 GetNodeAttr(tensor_names_op->attrs(), "value", &tensor_names));
196 TF_RETURN_IF_ERROR(
197 GetNodeAttr(shape_and_slices_op->attrs(), "value", &shape_and_slices));
198
199 int tn_size = tensor_names.NumElements();
200 int var_size = added_variables.size();
201
202 // Create a new save_op that has inputs to all the new variables.
203 NodeBuilder save_op_builder =
204 NodeBuilder(save_op->name(), save_op->type_string());
205 // The first three inputs are prefix, tensor_names, and shapes_and_slices.
206 for (int i = 0; i < 3; i++) {
207 save_op_builder = save_op_builder.Input(in_edges[i]->src());
208 }
209 std::vector<NodeBuilder::NodeOut> var_nodeouts;
210 var_nodeouts.reserve(tn_size + var_size);
211 // The rest of the inputs need to be used the construct the tensor list arg.
212 for (int i = 3; i < in_edges.size(); i++) {
213 var_nodeouts.emplace_back(in_edges[i]->src());
214 }
215
216 // Add the new values to the tensors and the op input.
217 Tensor new_tensor_names(DT_STRING, TensorShape({tn_size + var_size}));
218 Tensor new_shape_and_slices(DT_STRING, TensorShape({tn_size + var_size}));
219 FillStringTensor(&new_tensor_names, tensor_names);
220 FillStringTensor(&new_shape_and_slices, shape_and_slices);
221 for (int i = 0; i < var_size; i++) {
222 Node* var = added_variables[i];
223 new_tensor_names.flat<string>()(tn_size + i) = var->name();
224 new_shape_and_slices.flat<string>()(tn_size + i) = "";
225 var_nodeouts.emplace_back(var);
226 }
227 save_op_builder = save_op_builder.Input(var_nodeouts);
228
229 // Update the attrs.
230 tensor_names_op->AddAttr("value", new_tensor_names);
231 shape_and_slices_op->AddAttr("value", new_shape_and_slices);
232
233 // Remove the old save_op and add the new one.
234 Node* new_save_op;
235 TF_RETURN_IF_ERROR(save_op_builder.Finalize(graph, &new_save_op));
236 // Add outputs to the new_save_op, all outputs are control edges.
237 for (const Edge* edge : save_op->out_edges()) {
238 graph->AddControlEdge(new_save_op, edge->dst());
239 }
240 graph->RemoveNode(save_op);
241
242 return Status::OK();
243 }
244
245 // Add a restore subgraph for each variable and connect to the restore_all op.
246 // For each variable we add the following subgraph:
247 // Assign----restore_all
248 // | |
249 // RestoreV2 Variable
AddRestoreVariableSubgraphs(Graph * graph,Node * save_op,const std::vector<const Edge * > & in_edges,const std::vector<Node * > & variables)250 Status AddRestoreVariableSubgraphs(Graph* graph, Node* save_op,
251 const std::vector<const Edge*>& in_edges,
252 const std::vector<Node*>& variables) {
253 Node* prefix_op = in_edges[0]->src();
254 StringPiece name_prefix = GetNodeNamePrefix(save_op);
255 Node* restore_all = FindRestoreAllOp(graph, name_prefix);
256 if (restore_all == nullptr) {
257 return errors::InvalidArgument("graph has SaveOp, but no restore_all NoOp");
258 }
259 const string restore_op_name = strings::StrCat(name_prefix, "/RestoreV2");
260 const string assign_op_name = strings::StrCat(name_prefix, "/Assign");
261 for (Node* var : variables) {
262 // Add an extra prefix after calling graph->NewName because the "unique"
263 // name may conflict with names generated for Send nodes.
264 // TODO(b/77547936): fix this more generally and get rid of the extra prefix
265 // here.
266 string new_restore_op_name =
267 strings::StrCat(graph->NewName(restore_op_name), "_qt");
268 string new_assign_op_name =
269 strings::StrCat(graph->NewName(assign_op_name), "_qt");
270 string tensor_names_op_name =
271 strings::StrCat(new_restore_op_name, "/tensor_names");
272 string shape_and_slices_op_name =
273 strings::StrCat(new_restore_op_name, "/shape_and_slices");
274
275 // Construct the tensor_names input with the variable name.
276 Node* tensor_names;
277 Tensor tensor_names_val(DT_STRING, TensorShape({1}));
278 tensor_names_val.flat<string>()(0) = var->name();
279 TF_RETURN_IF_ERROR(NodeBuilder(tensor_names_op_name, "Const")
280 .Attr("dtype", DT_STRING)
281 .Attr("value", tensor_names_val)
282 .Finalize(graph, &tensor_names));
283
284 // Construct the shape_and_slices input with empty string.
285 Node* shape_and_slices;
286 Tensor shape_and_slices_val(DT_STRING, TensorShape({1}));
287 shape_and_slices_val.flat<string>()(0) = "";
288 TF_RETURN_IF_ERROR(NodeBuilder(shape_and_slices_op_name, "Const")
289 .Attr("dtype", DT_STRING)
290 .Attr("value", shape_and_slices_val)
291 .Finalize(graph, &shape_and_slices));
292
293 // Build the new Restore op for this variable.
294 Node* restore_op;
295 TF_RETURN_IF_ERROR(NodeBuilder(new_restore_op_name, "RestoreV2")
296 .Input(prefix_op)
297 .Input(tensor_names)
298 .Input(shape_and_slices)
299 .Attr("dtypes", {DT_FLOAT})
300 .Finalize(graph, &restore_op));
301
302 // Create Assign op, attaching the variable and Restore op to it.
303 Node* assign_op;
304 TF_RETURN_IF_ERROR(NodeBuilder(new_assign_op_name, "Assign")
305 .Input(var)
306 .Input(restore_op)
307 .Finalize(graph, &assign_op));
308
309 // Add a control edge from the assign op to restore_all op.
310 graph->AddControlEdge(assign_op, restore_all);
311 }
312 return Status::OK();
313 }
314
315 // Adds new variables to save and restore ops matching the Save and Restore
316 // graphs created in tensorflow/python/training/saver.py.
AddSaveAndRestore(Graph * graph,const std::vector<Node * > & variables)317 Status AddSaveAndRestore(Graph* graph, const std::vector<Node*>& variables) {
318 Node* save_op = nullptr;
319 std::vector<const Edge*> in_edges;
320 bool found = false;
321 TF_RETURN_IF_ERROR(FindSaveOp(graph, &save_op, &in_edges, &found));
322 if (found) {
323 TF_RETURN_IF_ERROR(
324 AddRestoreVariableSubgraphs(graph, save_op, in_edges, variables));
325 TF_RETURN_IF_ERROR(
326 ConnectVariablesToSaveOp(graph, save_op, in_edges, variables));
327 }
328 return Status::OK();
329 }
330
331 // Sets output to the Node that computes reduction axes corresponding to all
332 // dimensions of input and return.
MakeReductionAxes(Graph * graph,string name_prefix,Node * input,Node ** output)333 Status MakeReductionAxes(Graph* graph, string name_prefix, Node* input,
334 Node** output) {
335 name_prefix = strings::StrCat(name_prefix, "/ReductionAxes");
336 Node* start;
337 Tensor zero_tensor(DT_INT32, TensorShape());
338 zero_tensor.flat<int32>()(0) = 0;
339 TF_RETURN_IF_ERROR(
340 NodeBuilder(strings::StrCat(name_prefix, "/RangeStart"), "Const")
341 .Attr("dtype", DT_INT32)
342 .Attr("value", zero_tensor)
343 .Finalize(graph, &start));
344 Node* delta;
345 Tensor one_tensor(DT_INT32, TensorShape());
346 one_tensor.flat<int32>()(0) = 1;
347 TF_RETURN_IF_ERROR(
348 NodeBuilder(strings::StrCat(name_prefix, "/RangeDelta"), "Const")
349 .Attr("dtype", DT_INT32)
350 .Attr("value", one_tensor)
351 .Finalize(graph, &delta));
352 Node* rank;
353 TF_RETURN_IF_ERROR(
354 NodeBuilder(strings::StrCat(name_prefix, "/InputRank"), "Rank")
355 .Input(input)
356 .Finalize(graph, &rank));
357 TF_RETURN_IF_ERROR(
358 NodeBuilder(strings::StrCat(name_prefix, "/ReductionAxes"), "Range")
359 .Input(start)
360 .Input(rank)
361 .Input(delta)
362 .Finalize(graph, output));
363 return Status::OK();
364 }
365
366 // Computes the exponential moving average of input, updated in update_variable.
MakeExponentialMovingAverage(Graph * graph,string name_prefix,const NodeBuilder::NodeOut & input,Node * decay,Node * update_variable,Node ** assign_value)367 Status MakeExponentialMovingAverage(Graph* graph, string name_prefix,
368 const NodeBuilder::NodeOut& input,
369 Node* decay, Node* update_variable,
370 Node** assign_value) {
371 // variable_t+1 = variable_t - [(variable_t - value) * (1 - decay)]
372 name_prefix = strings::StrCat(name_prefix, "/EMA");
373 Node* one;
374 Tensor one_tensor(DT_FLOAT, TensorShape());
375 one_tensor.flat<float>()(0) = 1.0;
376 TF_RETURN_IF_ERROR(
377 NodeBuilder(strings::StrCat(name_prefix, "/OneConst"), "Const")
378 .Attr("dtype", DT_FLOAT)
379 .Attr("value", one_tensor)
380 .Finalize(graph, &one));
381 Node* decay_complement;
382 TF_RETURN_IF_ERROR(
383 NodeBuilder(strings::StrCat(name_prefix, "/DecayComplement"), "Sub")
384 .Input(one)
385 .Input(decay)
386 .Finalize(graph, &decay_complement));
387
388 Node* value_diff;
389 TF_RETURN_IF_ERROR(
390 NodeBuilder(strings::StrCat(name_prefix, "/ValueDiff"), "Sub")
391 .Input(update_variable)
392 .Input(input)
393 .Finalize(graph, &value_diff));
394 Node* update_value;
395 TF_RETURN_IF_ERROR(
396 NodeBuilder(strings::StrCat(name_prefix, "/UpdateValue"), "Mul")
397 .Input(value_diff)
398 .Input(decay_complement)
399 .Finalize(graph, &update_value));
400
401 TF_RETURN_IF_ERROR(
402 NodeBuilder(strings::StrCat(name_prefix, "/EMAValue"), "Sub")
403 .Input(update_variable)
404 .Input(update_value)
405 .Finalize(graph, assign_value));
406 return Status::OK();
407 }
408
409 // Creates an automatically initialized exponential moving average variable.
410 // This uses a switch op to assign a value to the variable on the first run,
411 // and update with the moving average for all other runs:
412 // init_val
413 // |
414 // var--is_init--switch
415 // | true / \ false
416 // | | |
417 // | EMA init_val
418 // | \ /
419 // +----------- assign
MakeInitializedEMAVariable(Graph * graph,const string & name,Node * decay,Node * init_val,std::vector<Node * > * added_variables,Node ** var)420 Status MakeInitializedEMAVariable(Graph* graph, const string& name, Node* decay,
421 Node* init_val,
422 std::vector<Node*>* added_variables,
423 Node** var) {
424 // TODO(suharshs): Update this to use ResourceVariables when they are ready.
425 TF_RETURN_IF_ERROR(
426 NodeBuilder(strings::StrCat(name, "/Variable"), "VariableV2")
427 .Attr("shape", TensorShape())
428 .Attr("dtype", DT_FLOAT)
429 .Finalize(graph, var));
430 added_variables->push_back(*var);
431
432 Node* is_initialized;
433 TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/IsInitialized"),
434 "IsVariableInitialized")
435 .Input(*var)
436 .Finalize(graph, &is_initialized));
437 Node* switch_node;
438 TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Switch"), "Switch")
439 .Input(init_val)
440 .Input(is_initialized)
441 .Finalize(graph, &switch_node));
442 NodeBuilder::NodeOut output_false = NodeBuilder::NodeOut(switch_node, 0);
443 NodeBuilder::NodeOut output_true = NodeBuilder::NodeOut(switch_node, 1);
444
445 Node* ema_value;
446 TF_RETURN_IF_ERROR(MakeExponentialMovingAverage(graph, name, output_true,
447 decay, *var, &ema_value));
448
449 Node* assign_value;
450 TF_RETURN_IF_ERROR(NodeBuilder(strings::StrCat(name, "/Merge"), "Merge")
451 .Input({output_false, ema_value})
452 .Finalize(graph, &assign_value));
453
454 TF_RETURN_IF_ERROR(
455 NodeBuilder(strings::StrCat(name, "/AssignValue"), "Assign")
456 .Input(*var)
457 .Input(assign_value)
458 .Finalize(graph, var));
459 return Status::OK();
460 }
461
462 // Computes the min and max EMA of input and stores them in min_var and max_var.
MakeEMAMinMaxVars(Graph * graph,const string & name_prefix,Node * input,std::vector<Node * > * added_variables,Node ** min_var,Node ** max_var)463 Status MakeEMAMinMaxVars(Graph* graph, const string& name_prefix, Node* input,
464 std::vector<Node*>* added_variables, Node** min_var,
465 Node** max_var) {
466 // TODO(suharshs): The decay will be constant, so we could make only one for
467 // all quantize_and_dequantize ops to share, this would have to live outside
468 // this function.
469 Tensor decay_tensor(DT_FLOAT, TensorShape());
470 decay_tensor.flat<float>()(0) = kEMADecay;
471 Node* decay;
472 TF_RETURN_IF_ERROR(
473 NodeBuilder(strings::StrCat(name_prefix, "/Decay"), "Const")
474 .Attr("dtype", DT_FLOAT)
475 .Attr("value", decay_tensor)
476 .Finalize(graph, &decay));
477
478 Node* reduction_axes;
479 TF_RETURN_IF_ERROR(
480 MakeReductionAxes(graph, name_prefix, input, &reduction_axes));
481 Node* min;
482 string min_name = strings::StrCat(name_prefix, "/Min");
483 TF_RETURN_IF_ERROR(NodeBuilder(min_name, "Min")
484 .Input(input)
485 .Input(reduction_axes)
486 .Finalize(graph, &min));
487 Node* max;
488 string max_name = strings::StrCat(name_prefix, "/Max");
489 TF_RETURN_IF_ERROR(NodeBuilder(max_name, "Max")
490 .Input(input)
491 .Input(reduction_axes)
492 .Finalize(graph, &max));
493 TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, min_name, decay, min,
494 added_variables, min_var));
495 TF_RETURN_IF_ERROR(MakeInitializedEMAVariable(graph, max_name, decay, max,
496 added_variables, max_var));
497 return Status::OK();
498 }
499
500 // Makes an input min and max constant if the range is given. Otherwise, makes
501 // min and max variables that are updated by an EMA.
MakeInputMinMax(Graph * graph,const string & name_prefix,const EdgeToConvert & edge,std::vector<Node * > * added_variables,Node ** input_min,Node ** input_max)502 Status MakeInputMinMax(Graph* graph, const string& name_prefix,
503 const EdgeToConvert& edge,
504 std::vector<Node*>* added_variables, Node** input_min,
505 Node** input_max) {
506 if (edge.range_given) {
507 // Make constant nodes for the input_min and input_max if the range is
508 // provided.
509 Tensor input_min_tensor(DT_FLOAT, TensorShape());
510 input_min_tensor.flat<float>()(0) = edge.input_min;
511 TF_RETURN_IF_ERROR(
512 NodeBuilder(strings::StrCat(name_prefix, "/InputMin"), "Const")
513 .Attr("dtype", DT_FLOAT)
514 .Attr("value", input_min_tensor)
515 .Finalize(graph, input_min));
516 Tensor input_max_tensor(DT_FLOAT, TensorShape());
517 input_max_tensor.flat<float>()(0) = edge.input_max;
518 TF_RETURN_IF_ERROR(
519 NodeBuilder(strings::StrCat(name_prefix, "/InputMax"), "Const")
520 .Attr("dtype", DT_FLOAT)
521 .Attr("value", input_max_tensor)
522 .Finalize(graph, input_max));
523 } else {
524 // If the range is not given, estimate the range with EMA variables.
525 TF_RETURN_IF_ERROR(MakeEMAMinMaxVars(graph, name_prefix, edge.edge->src(),
526 added_variables, input_min,
527 input_max));
528 }
529
530 return Status::OK();
531 }
532
533 // Adds a QuantizeAndDequantizeV2 or FakeQuantizeWithMinMaxVars op
534 // (and required input nodes) based on edge.
535 // The result is stored in convert_node.
MakeQuantizeOp(Graph * graph,const string & name_prefix,const string & quant_op_type,const EdgeToConvert & edge,std::vector<Node * > * added_variables,Node ** convert_node)536 Status MakeQuantizeOp(Graph* graph, const string& name_prefix,
537 const string& quant_op_type, const EdgeToConvert& edge,
538 std::vector<Node*>* added_variables,
539 Node** convert_node) {
540 Node* input_min;
541 Node* input_max;
542 TF_RETURN_IF_ERROR(MakeInputMinMax(graph, name_prefix, edge, added_variables,
543 &input_min, &input_max));
544 string quant_name = strings::StrCat(name_prefix, "/", quant_op_type);
545 if (quant_op_type == "QuantizeAndDequantizeV2") {
546 TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
547 .Input(edge.edge->src())
548 .Input(input_min)
549 .Input(input_max)
550 .Attr("signed_input", edge.signed_input)
551 .Attr("num_bits", edge.num_bits)
552 .Attr("range_given", true)
553 .Finalize(graph, convert_node));
554 } else if (quant_op_type == "FakeQuantWithMinMaxVars") {
555 TF_RETURN_IF_ERROR(NodeBuilder(quant_name, quant_op_type)
556 .Input(edge.edge->src())
557 .Input(input_min)
558 .Input(input_max)
559 .Attr("num_bits", edge.num_bits)
560 .Finalize(graph, convert_node));
561 } else {
562 return errors::InvalidArgument("Unknown quant op type: ", quant_op_type);
563 }
564 return Status::OK();
565 }
566
567 // Insert conversion op, connect it to the graph and remove the old edge.
ProcessTargetEdges(Graph * graph,const string & quant_op_type,const std::vector<EdgeToConvert> & target_edges)568 Status ProcessTargetEdges(Graph* graph, const string& quant_op_type,
569 const std::vector<EdgeToConvert>& target_edges) {
570 // Remember previously converted ops to avoid duplicated conversion on the
571 // same input.
572 std::unordered_map<string, Node*, StringPieceHasher> name_index;
573 std::vector<Node*> added_variables;
574 for (const EdgeToConvert edge : target_edges) {
575 Node* convert_node;
576 string name_prefix = edge.edge->src()->name();
577
578 auto iter = name_index.find(name_prefix);
579 if (iter == name_index.end()) {
580 TF_RETURN_IF_ERROR(MakeQuantizeOp(graph, name_prefix, quant_op_type, edge,
581 &added_variables, &convert_node));
582 name_index[name_prefix] = convert_node;
583 } else {
584 convert_node = iter->second;
585 }
586
587 graph->AddEdge(convert_node, 0, edge.edge->dst(), edge.edge->dst_input());
588 graph->RemoveEdge(edge.edge);
589 }
590
591 TF_RETURN_IF_ERROR(AddSaveAndRestore(graph, added_variables));
592
593 return Status::OK();
594 }
595
596 } // namespace
597
DoQuantizeTraining(int32 num_bits,const string & quant_op_type,Graph * graph)598 Status DoQuantizeTraining(int32 num_bits, const string& quant_op_type,
599 Graph* graph) {
600 if (graph == nullptr) {
601 return errors::InvalidArgument("Cannot accept empty graph pointer.");
602 }
603
604 if (num_bits < 1 || num_bits > 63) {
605 return errors::OutOfRange("num_bits should be in range [1, 63] but is: ",
606 num_bits);
607 }
608 int potential_input = 0;
609 std::vector<EdgeToConvert> target_edges;
610 for (Node* node : graph->nodes()) {
611 if (nodes_to_rewrite->find(node->type_string()) !=
612 nodes_to_rewrite->end() &&
613 !IsGradientNode(graph, node)) {
614 // Find out which types are the inputs and convert them accordingly.
615 // 1. Const/Variable OP: This is quantized as signed tensors with no given
616 // range.
617 // 2. Activation OP: Set the range accordingly for different types of
618 // activations. Currently we handle {Relu, Relu6, Sigmoid, Tanh}
619 // 3. Identity OP: The quantization parameters depend on its input.
620 // 4. Pooling OPs: various pooling ops. Also depends on its input.
621 // 5. Reshape OP: Also depends on the first input to this op.
622 // 6. Not-Listed-Above OP: If there is only 1 such op, consider it as the
623 // model input. However, if there are >1 unknown ops, then returns an
624 // error for now to avoid unexpected behavior.
625 // Note: The list above might not be a complete list. Please let us
626 // know if you see the error so we can handle your case.
627 for (const Edge* edge : node->in_edges()) {
628 if (edge->src_output() == Graph::kControlSlot) {
629 // Skip the control dependency input.
630 continue;
631 } else {
632 bool signed_input = false;
633 bool range_given = false;
634 float input_min = 0;
635 float input_max = 0;
636 bool known_op = FindType(graph, edge->src(), &signed_input,
637 &range_given, &input_min, &input_max);
638 if (!known_op) {
639 // Unknown op is considered as input.
640 potential_input++;
641 if (potential_input > kAllowedInputs) {
642 return errors::Unimplemented(
643 "Found an unknown op: ", edge->src()->name(),
644 " with type: ", edge->src()->type_string(),
645 "; Unknown ops are considered as model input for now and "
646 "only ",
647 kAllowedInputs, " inputs are supported currently.");
648 }
649 }
650
651 target_edges.emplace_back(EdgeToConvert(
652 edge, num_bits, signed_input, range_given, input_min, input_max));
653 }
654 }
655 }
656 }
657
658 TF_RETURN_IF_ERROR(ProcessTargetEdges(graph, quant_op_type, target_edges));
659
660 return Status::OK();
661 }
662
DoQuantizeTrainingOnGraphDef(const GraphDef & input_graphdef,int32 num_bits,const string & quant_op_type,GraphDef * result_graphdef)663 Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
664 int32 num_bits, const string& quant_op_type,
665 GraphDef* result_graphdef) {
666 Graph graph(OpRegistry::Global());
667 GraphConstructorOptions opts;
668 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, input_graphdef, &graph));
669
670 // Call the rewriter on the graph.
671 TF_RETURN_IF_ERROR(DoQuantizeTraining(num_bits, quant_op_type, &graph));
672
673 // Convert the result graph back to a GraphDef.
674 graph.ToGraphDef(result_graphdef);
675 return Status::OK();
676 }
677
DoQuantizeTrainingOnSerializedGraphDef(const string & input_graph_string,int32 num_bits,const string & quant_op_type,string * result_graph_string)678 Status DoQuantizeTrainingOnSerializedGraphDef(const string& input_graph_string,
679 int32 num_bits,
680 const string& quant_op_type,
681 string* result_graph_string) {
682 // First create the graph from the GraphDef.
683 GraphDef input_graphdef;
684 if (!ParseProtoUnlimited(&input_graphdef, input_graph_string)) {
685 return errors::InvalidArgument(
686 "input_graph_string is not a serialized GraphDef protocol buffer");
687 }
688 GraphDef output_graphdef;
689 TF_RETURN_IF_ERROR(DoQuantizeTrainingOnGraphDef(
690 input_graphdef, num_bits, quant_op_type, &output_graphdef));
691
692 if (!output_graphdef.SerializeToString(result_graph_string)) {
693 return errors::Internal(
694 "quantize training transformation resulted in invalid GraphDef");
695 }
696 return Status::OK();
697 }
698
699 } // namespace tensorflow
700