• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
19 
#include <climits>
#include <cmath>
#include <cstring>
21 
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/grappler/clusters/cluster.h"
37 #include "tensorflow/core/grappler/costs/graph_properties.h"
38 #include "tensorflow/core/grappler/grappler_item.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/stringpiece.h"
45 #include "tensorflow/core/lib/gtl/cleanup.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/lib/strings/strcat.h"
49 #include "tensorflow/core/platform/cpu_info.h"
50 #include "tensorflow/core/platform/denormal.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/setround.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/bcast.h"
56 #include "tensorflow/core/util/saved_tensor_slice_util.h"
57 
58 namespace tensorflow {
59 namespace grappler {
60 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
61 
62 // We only fold/materialize constants smaller than 100kB.
63 const int64 kMaxConstantSize = 100 * 1024;
64 
65 namespace {
66 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)67 bool AllValuesAre(const TensorProto& proto, const T& value) {
68   Tensor tensor;
69   if (!tensor.FromProto(proto)) {
70     return false;
71   }
72   auto values = tensor.flat<T>();
73   for (int i = 0; i < tensor.NumElements(); ++i) {
74     if (values(i) != value) {
75       return false;
76     }
77   }
78   return true;
79 }
80 
81 // Add new_input as a control input to node if it does not already depend on it.
82 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
83 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)84 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
85                           GraphDef* graph, NodeMap* node_map) {
86   bool already_exists = false;
87   for (const string& input : node->input()) {
88     if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
89       already_exists = true;
90       break;
91     }
92   }
93   if (!already_exists) {
94     const string ctrl_dep =
95         ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
96     node->add_input(ctrl_dep);
97     node_map->AddOutput(NodeName(ctrl_input), node->name());
98   }
99   return !already_exists;
100 }
101 
102 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)103 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
104                              GraphDef* graph, NodeMap* node_map) {
105   bool removed_input = false;
106   bool update_node_map = true;
107   const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
108   for (int i = 0; i < node->input_size(); ++i) {
109     const string& input = node->input(i);
110     if (old_input_ctrl_dep == input) {
111       if (IsControlInput(input)) {
112         node->mutable_input()->SwapElements(i, node->input_size() - 1);
113         node->mutable_input()->RemoveLast();
114         removed_input = true;
115       } else {
116         // There is a non-control input from the same node.
117         // Don't remove the output from the NodeMap.
118         update_node_map = false;
119       }
120     }
121   }
122   if (update_node_map) {
123     node_map->RemoveOutput(NodeName(old_input), node->name());
124   }
125   return removed_input;
126 }
127 
HasTPUAttributes(const NodeDef & node)128 bool HasTPUAttributes(const NodeDef& node) {
129   AttrSlice attrs(node);
130   for (const auto& attr : attrs) {
131     if (attr.first.find("_tpu_") != attr.first.npos) {
132       return true;
133     }
134   }
135   return false;
136 }
137 
// Returns true if `a` and `b` differ. The primary template uses operator!=.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

// The floating-point specializations compare bit patterns rather than values:
// 0.0 and -0.0 are reported as different, and a NaN compares equal to a NaN
// with the identical encoding. std::memcpy is used for the type pun because
// binding an int32_t&/int64_t& to a float/double object violates the
// strict-aliasing rule (undefined behavior).
template <>
bool PackedValuesNotEqual(float a, float b) {
  static_assert(sizeof(float) == sizeof(uint32_t), "unexpected float size");
  uint32_t a_bits;
  uint32_t b_bits;
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
  return a_bits != b_bits;
}

template <>
bool PackedValuesNotEqual(double a, double b) {
  static_assert(sizeof(double) == sizeof(uint64_t), "unexpected double size");
  uint64_t a_bits;
  uint64_t b_bits;
  std::memcpy(&a_bits, &a, sizeof(a_bits));
  std::memcpy(&b_bits, &b, sizeof(b_bits));
  return a_bits != b_bits;
}
152 
QuantizedTypeMinAsFloat(DataType data_type)153 float QuantizedTypeMinAsFloat(DataType data_type) {
154   switch (data_type) {
155     case DT_QINT8:
156       return Eigen::NumTraits<qint8>::lowest();
157     case DT_QUINT8:
158       return Eigen::NumTraits<quint8>::lowest();
159     case DT_QINT16:
160       return Eigen::NumTraits<qint16>::lowest();
161     case DT_QUINT16:
162       return Eigen::NumTraits<quint16>::lowest();
163     case DT_QINT32:
164       return Eigen::NumTraits<qint32>::lowest();
165     default:
166       return 0.0f;
167   }
168 }
169 
QuantizedTypeMaxAsFloat(DataType data_type)170 float QuantizedTypeMaxAsFloat(DataType data_type) {
171   switch (data_type) {
172     case DT_QINT8:
173       return Eigen::NumTraits<qint8>::highest();
174     case DT_QUINT8:
175       return Eigen::NumTraits<quint8>::highest();
176     case DT_QINT16:
177       return Eigen::NumTraits<qint16>::highest();
178     case DT_QUINT16:
179       return Eigen::NumTraits<quint16>::highest();
180     case DT_QINT32:
181       return Eigen::NumTraits<qint32>::highest();
182     default:
183       return 0.0f;
184   }
185 }
186 
187 }  // namespace
188 
ConstantFolding(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_emulation)189 ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
190                                  DeviceBase* cpu_device,
191                                  bool disable_compressed_tensor_optimization,
192                                  bool fold_quantization_emulation)
193     : opt_level_(opt_level),
194       cpu_device_(cpu_device),
195       disable_compressed_tensor_optimization_(
196           disable_compressed_tensor_optimization),
197       fold_quantization_emulation_(fold_quantization_emulation) {
198   resource_mgr_.reset(new ResourceMgr());
199 }
200 
ConstantFolding(DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_ops)201 ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
202                                  bool disable_compressed_tensor_optimization,
203                                  bool fold_quantization_ops)
204     : ConstantFolding(RewriterConfig::ON, cpu_device,
205                       disable_compressed_tensor_optimization,
206                       fold_quantization_ops) {}
207 
208 // static
AddControlDependency(const string & input_name,GraphDef * graph,NodeMap * node_map)209 string ConstantFolding::AddControlDependency(const string& input_name,
210                                              GraphDef* graph,
211                                              NodeMap* node_map) {
212   if (IsControlInput(input_name)) {
213     return input_name;
214   }
215   const NodeDef& node = *node_map->GetNode(input_name);
216   if (!IsSwitch(node)) {
217     return AsControlDependency(node);
218   } else {
219     // We can't anchor control dependencies directly on the switch node: unlike
220     // other nodes only one of the outputs of the switch node will be generated
221     // when the switch node is executed, and we need to make sure the control
222     // dependency is only triggered when the corresponding output is triggered.
223     // We start by looking for an identity node connected to the output of the
224     // switch node, and use it to anchor the control dependency.
225     for (const NodeDef* output : node_map->GetOutputs(node.name())) {
226       if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
227         if (IsSameInput(node.input(0), input_name)) {
228           return AsControlDependency(*output);
229         }
230       }
231     }
232     // We haven't found an existing node where we can anchor the control
233     // dependency: add a new identity node.
234     int port = 0;
235     string ctrl_dep_name = ParseNodeName(input_name, &port);
236     strings::StrAppend(&ctrl_dep_name, "_", port);
237     ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
238     const DataType output_type = node.attr().at("T").type();
239 
240     NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
241     if (added_node == nullptr) {
242       added_node = graph->add_node();
243       added_node->set_name(ctrl_dep_name);
244       added_node->set_op("Identity");
245       added_node->set_device(node.device());
246 
247       (*added_node->mutable_attr())["T"].set_type(output_type);
248       *added_node->add_input() = input_name;
249       node_map->AddNode(added_node->name(), added_node);
250       node_map->AddOutput(node.name(), added_node->name());
251     }
252     return AsControlDependency(*added_node);
253   }
254 }
255 
256 // Forward inputs at the given indices to outputs and add a control dependency
257 // on node.
ForwardInputs(NodeDef * node,absl::Span<const int> inputs_to_forward)258 bool ConstantFolding::ForwardInputs(NodeDef* node,
259                                     absl::Span<const int> inputs_to_forward) {
260   for (int input_idx : inputs_to_forward) {
261     if (input_idx < 0 || input_idx >= node->input_size()) {
262       return false;
263     }
264   }
265 
266   const auto& tmp = node_map_->GetOutputs(node->name());
267   const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
268   bool updated_graph = false;
269   for (int input_idx : inputs_to_forward) {
270     const string& input = node->input(input_idx);
271     if (IsControlInput(input) && consumers.size() > 1) {
272       continue;
273     }
274     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
275     if (input_node == nullptr) {
276       LOG(ERROR) << "Bad input: " << input;
277       break;
278     }
279     // Update each consumer.
280     for (NodeDef* consumer : consumers) {
281       bool add_dep = false;
282       for (int consumer_input_idx = 0;
283            consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
284         const string& consumer_input = consumer->input(consumer_input_idx);
285         if (IsControlInput(consumer_input)) {
286           break;
287         }
288         // It is illegal to add control dependencies to _Retval nodes, so we
289         // can't bypass value producing `node` and forward inputs to `consumer`.
290         if (IsRetval(*consumer)) {
291           break;
292         }
293         int output_idx;
294         const string input_node_name =
295             ParseNodeName(consumer_input, &output_idx);
296         if (input_node_name == node->name() && output_idx == input_idx) {
297           consumer->set_input(consumer_input_idx, input);
298           // We will keep the input from the node through a control
299           // dependency, so we only need to add the consumer as an output
300           // for the input node.
301           node_map_->AddOutput(NodeName(input), consumer->name());
302           add_dep = true;
303         }
304       }
305       if (add_dep) {
306         consumer->add_input(AsControlDependency(node->name()));
307         updated_graph = true;
308       }
309     }
310   }
311 
312   if (updated_graph) {
313     for (NodeDef* consumer : consumers) {
314       DedupControlInputs(consumer);
315     }
316   }
317   return updated_graph;
318 }
319 
320 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64 value,const DataType & type,const int index,Tensor * tensor)321 static Status PutValueIntoTensor(const int64 value, const DataType& type,
322                                  const int index, Tensor* tensor) {
323   if (type == DT_INT32) {
324     if (value >= INT_MAX) {
325       return Status(error::INVALID_ARGUMENT, "int32 overflow");
326     }
327     tensor->flat<int32>()(index) = static_cast<int32>(value);
328   } else {
329     tensor->flat<int64>()(index) = value;
330   }
331   return Status::OK();
332 }
333 
334 // Writes the given tensor shape into the given tensor.
335 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)336 static Status ConvertShapeToConstant(const string& op, const DataType& type,
337                                      const PartialTensorShape& shp,
338                                      Tensor* tensor) {
339   if (op == "Shape" || op == "ShapeN") {
340     *tensor = Tensor(type, TensorShape({shp.dims()}));
341     for (int i = 0; i < shp.dims(); ++i) {
342       TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
343     }
344   } else if (op == "Size") {
345     int64 size = 1;
346     for (int i = 0; i < shp.dims(); ++i) {
347       size *= shp.dim_size(i);
348     }
349     *tensor = Tensor(type, TensorShape({}));
350     TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
351   } else {
352     CHECK_EQ(op, "Rank");
353     *tensor = Tensor(type, TensorShape({}));
354     TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
355   }
356   return Status::OK();
357 }
358 
359 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const360 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
361                                           StringPiece suffix) const {
362   return node_map_->NodeExists(OptimizedNodeName(node, suffix));
363 }
364 
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const365 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
366                                           StringPiece suffix) const {
367   return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
368                              kConstantFoldingConst);
369 }
370 
IsReallyConstant(const NodeDef & node) const371 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
372   if (!IsConstant(node)) {
373     return false;
374   }
375   // If the node is fed it's not constant anymore.
376   return feed_nodes_.find(node.name()) == feed_nodes_.end();
377 }
378 
379 // TODO(rmlarsen): Refactor to shared util.
GetTensorFromConstNode(const string & node_name_or_input,Tensor * tensor)380 bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
381                                              Tensor* tensor) {
382   const NodeDef* node = node_map_->GetNode(node_name_or_input);
383   return node != nullptr && IsReallyConstant(*node) &&
384          CheckAttrExists(*node, "value").ok() &&
385          tensor->FromProto(node->attr().at("value").tensor());
386 }
387 
388 // Materialize the shapes using constants whenever possible.
MaterializeShapes(const GraphProperties & properties)389 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
390   // We may add some nodes to the graph to encode control dependencies and hold
391   // the materialized shapes: there is no need to process these added nodes, so
392   // only iterate over the nodes of the input graph.
393   const int node_count = graph_->node_size();
394   for (int node_idx = 0; node_idx < node_count; ++node_idx) {
395     NodeDef* node = graph_->mutable_node(node_idx);
396     const string op = node->op();
397     if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
398         op != "TensorArraySizeV3") {
399       continue;
400     }
401     const std::vector<OpInfo::TensorProperties>& output =
402         properties.GetOutputProperties(node->name());
403     const std::vector<OpInfo::TensorProperties>& input =
404         properties.GetInputProperties(node->name());
405     if (input.empty() || output.empty()) {
406       continue;
407     }
408 
409     if (op == "Shape" || op == "Size" || op == "Rank") {
410       CHECK_EQ(1, output.size());
411       CHECK_EQ(1, input.size());
412 
413       const DataType type = output[0].dtype();
414       CHECK(type == DT_INT32 || type == DT_INT64);
415       const PartialTensorShape shape(input[0].shape());
416 
417       if ((op != "Rank" && !shape.IsFullyDefined()) ||
418           (op == "Rank" && shape.unknown_rank())) {
419         continue;
420       }
421 
422       Tensor constant_value(type);
423       if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
424         continue;
425       }
426 
427       // TODO(rmlarsen): Remove this workaround for b/150861569
428       // The bug involves an expression of the form Shape(ExpandDims(x)
429       // with an incorrectly inferred zero-size first dimension.
430       if (op == "Shape") {
431         if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
432       }
433 
434       // Repurpose the existing node to be the constant.
435       // Device placement is preserved.
436       graph_modified_ = true;
437       node->set_op("Const");
438       EraseRegularNodeAttributes(node);
439       (*node->mutable_attr())["dtype"].set_type(type);
440       constant_value.AsProtoTensorContent(
441           (*node->mutable_attr())["value"].mutable_tensor());
442 
443       // Turn the data input into a control dependency: this is needed to
444       // ensure that the constant value will only be run in the
445       // cases where the shape/rank/size would have been run in
446       // the original graph.
447       string ctrl_dep =
448           AddControlDependency(node->input(0), graph_, node_map_.get());
449       node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
450       node->set_input(0, ctrl_dep);
451       // Done with the Shape/Size/Rank node, move to the next node.
452       continue;
453     }
454 
455     if (op == "TensorArraySizeV3") {
456       const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
457       if (array->input_size() == 0 ||
458           (array->attr().count("dynamic_size") != 0 &&
459            array->attr().at("dynamic_size").b())) {
460         continue;
461       }
462       const NodeDef* array_size =
463           CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
464       if (IsReallyConstant(*array_size)) {
465         // Don't materialize 0 sizes to avoid triggering incorrect static
466         // checks. A 0 sized array that can't grow isn't useful anyway.
467         if (array_size->attr().count("value") == 0) {
468           continue;
469         }
470         const TensorProto& raw_val = array_size->attr().at("value").tensor();
471         if (raw_val.dtype() != DT_INT32) {
472           continue;
473         }
474         Tensor value(raw_val.dtype(), raw_val.tensor_shape());
475         if (!value.FromProto(raw_val)) {
476           continue;
477         }
478         if (value.flat<int32>()(0) == 0) {
479           continue;
480         }
481 
482         graph_modified_ = true;
483         node->set_op("Const");
484         *node->mutable_attr() = array_size->attr();
485         node->set_input(0, AsControlDependency(NodeName(node->input(0))));
486         node->set_input(1, AddControlDependency(NodeName(node->input(1)),
487                                                 graph_, node_map_.get()));
488       }
489       continue;
490     }
491 
492     // Handle ShapeN materialization case.
493     // It's possible that not all input tensors have known shapes.
494     CHECK_EQ(op, "ShapeN");
495     CHECK_EQ(input.size(), output.size());
496     const NodeDef* const shape_n_node = node;
497     for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
498          ++port_idx) {
499       const DataType type = output[port_idx].dtype();
500       CHECK(type == DT_INT32 || type == DT_INT64);
501       const PartialTensorShape shape(input[port_idx].shape());
502       if (!shape.IsFullyDefined()) {
503         continue;
504       }
505       Tensor constant_value(type);
506       auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
507       if (!status.ok()) {
508         continue;
509       }
510 
511       // We make a copy because we mutate the nodes.
512       auto fanouts = node_map_->GetOutputs(shape_n_node->name());
513       // Find all nodes consuming this shape and connect them through the new
514       // constant node instead.
515       for (NodeDef* output : fanouts) {
516         // Track whether there are any direct edges left between shape_n_node
517         // and this output node after the transformation.
518         bool direct_edges_exist = false;
519         for (int k = 0; k < output->input_size(); ++k) {
520           int port;
521           const string node_name = ParseNodeName(output->input(k), &port);
522           if (node_name == shape_n_node->name() && port == port_idx) {
523             // Create a const node as ShapeN's output if not already.
524             const string const_name = OptimizedNodeName(
525                 *shape_n_node, strings::StrCat("-matshapes-", port_idx));
526             if (node_map_->GetNode(const_name) == nullptr) {
527               NodeDef* added_node = graph_->add_node();
528               added_node->set_name(const_name);
529               added_node->set_op("Const");
530               added_node->set_device(shape_n_node->device());
531               node_map_->AddNode(added_node->name(), added_node);
532               (*added_node->mutable_attr())["dtype"].set_type(type);
533               constant_value.AsProtoTensorContent(
534                   (*added_node->mutable_attr())["value"].mutable_tensor());
535               // We add a control dependency to the original ShapeN node,
536               // so that the node will only be run if all inputs of the
537               // original ShapeN node are run.
538               string ctrl_dep = AddControlDependency(shape_n_node->name(),
539                                                      graph_, node_map_.get());
540               *added_node->add_input() = ctrl_dep;
541               node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
542             }
543             *output->mutable_input(k) = const_name;
544             node_map_->AddOutput(const_name, output->name());
545             graph_modified_ = true;
546           }
547           if (node_name == shape_n_node->name() && port != port_idx) {
548             direct_edges_exist = true;
549           }
550         }
551         if (!direct_edges_exist) {
552           node_map_->RemoveOutput(node->name(), output->name());
553         }
554       }
555     }
556   }
557 
558   return Status::OK();
559 }
560 
561 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64 * min_id)562 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
563                   BCast::Vec* shape, int64* min_id) {
564   if (shape_node.op() == "Shape") {
565     const std::vector<OpInfo::TensorProperties>& prop1 =
566         properties.GetInputProperties(shape_node.name());
567     if (prop1.size() != 1) {
568       return false;
569     }
570     const TensorShapeProto& shp = prop1[0].shape();
571     if (shp.unknown_rank()) {
572       return false;
573     }
574     for (const auto& dim : shp.dim()) {
575       shape->push_back(dim.size());
576       *min_id = std::min<int64>(*min_id, dim.size());
577     }
578   } else {
579     if (shape_node.attr().count("value") == 0) {
580       return false;
581     }
582     const TensorProto& raw_val = shape_node.attr().at("value").tensor();
583     if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
584       return false;
585     }
586     Tensor value(raw_val.dtype(), raw_val.tensor_shape());
587     if (!value.FromProto(raw_val)) {
588       return false;
589     }
590     for (int j = 0; j < value.NumElements(); ++j) {
591       if (raw_val.dtype() == DT_INT64) {
592         shape->push_back(value.vec<int64>()(j));
593       } else {
594         shape->push_back(value.vec<int>()(j));
595       }
596     }
597   }
598   return true;
599 }
600 }  // namespace
601 
MaterializeBroadcastGradientArgs(const NodeDef & node,const GraphProperties & properties)602 Status ConstantFolding::MaterializeBroadcastGradientArgs(
603     const NodeDef& node, const GraphProperties& properties) {
604   const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
605   const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
606   if (shape_node1 == nullptr ||
607       (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
608       shape_node2 == nullptr ||
609       (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
610     return Status::OK();
611   }
612 
613   // Don't optimize this again if it was already optimized and folded.
614   if (OptimizedNodeExists(node, "-folded-1") ||
615       OptimizedNodeExists(node, "-folded-2")) {
616     return Status::OK();
617   }
618   int64 min_id = 0;
619   BCast::Vec shape1;
620   if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
621     return Status::OK();
622   }
623   BCast::Vec shape2;
624   if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
625     return Status::OK();
626   }
627   // A value of -1 means we don't known anything about the dimension. Replace
628   // the -1 values with unique dimension ids since we don't want two '-1'
629   // dimensions to be considered equal.
630   for (auto& id : shape1) {
631     if (id == -1) {
632       id = --min_id;
633     }
634   }
635   for (auto& id : shape2) {
636     if (id == -1) {
637       id = --min_id;
638     }
639   }
640 
641   // Beware: the reduction dimensions computed by the BCast class are valid iff
642   // we assume that two distinct symbolic dimensions can't be equal and a
643   // symbolic dimension can't be equal to 1. This is often but not always true,
644   // so to make this optimization safe we filter out these cases.
645   const int common_dims = std::min(shape1.size(), shape2.size());
646   for (int i = 0; i < common_dims; ++i) {
647     if (shape1[i] >= 0 && shape2[i] >= 0) {
648       continue;
649     }
650     if (shape1[i] != shape2[i]) {
651       // We're either dealing with 2 different symbolic dimensions or a symbolic
652       // and a know dimensions. We can't be sure whether both are equal or not,
653       // so we can't be sure whether we'll be broadcasting or not.
654       return Status::OK();
655     }
656   }
657   // These extra dims could be equal to 1, in which case there is no
658   // broadcasting. It could also be greater than 1, in which case there would
659   // be broadcasting. Since we don't know, we'll just punt.
660   for (int i = common_dims, end = shape1.size(); i < end; ++i) {
661     if (shape1[i] < 0) {
662       return Status::OK();
663     }
664   }
665   for (int i = common_dims, end = shape2.size(); i < end; ++i) {
666     if (shape2[i] < 0) {
667       return Status::OK();
668     }
669   }
670 
671   BCast bcast(shape1, shape2);
672   if (!bcast.IsValid()) {
673     return Status::OK();
674   }
675 
676   BCast::Vec reduce_dims[2];
677   reduce_dims[0] = bcast.grad_x_reduce_idx();
678   reduce_dims[1] = bcast.grad_y_reduce_idx();
679 
680   TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
681   const DataType type = node.attr().at("T").type();
682   NodeDef* out[2];
683   for (int j = 0; j < 2; ++j) {
684     int reduction_indices = reduce_dims[j].size();
685     Tensor value(type, TensorShape({reduction_indices}));
686     for (int i = 0; i < reduction_indices; ++i) {
687       if (type == DT_INT32) {
688         value.vec<int32>()(i) = reduce_dims[j][i];
689       } else {
690         value.vec<int64>()(i) = reduce_dims[j][i];
691       }
692     }
693     string const_name =
694         OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
695     out[j] = node_map_->GetNode(const_name);
696     if (out[j] == nullptr) {
697       out[j] = graph_->add_node();
698       TF_RETURN_IF_ERROR(
699           CreateNodeDef(const_name, TensorValue(&value), out[j]));
700       out[j]->set_device(node.device());
701       node_map_->AddNode(const_name, out[j]);
702       string ctrl_dep =
703           AddControlDependency(node.name(), graph_, node_map_.get());
704       *out[j]->add_input() = ctrl_dep;
705       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
706     }
707   }
708 
709   // We make a copy here since we might mutate the set.
710   const auto outputs = node_map_->GetOutputs(node.name());
711   for (NodeDef* output : outputs) {
712     for (int k = 0; k < output->input_size(); ++k) {
713       int port;
714       string node_name = ParseNodeName(output->input(k), &port);
715       if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
716         *output->mutable_input(k) = out[port]->name();
717         node_map_->UpdateInput(output->name(), node_name, out[port]->name());
718       }
719     }
720   }
721 
722   return Status::OK();
723 }
724 
MaterializeReductionIndices(NodeDef * node,const GraphProperties & properties)725 Status ConstantFolding::MaterializeReductionIndices(
726     NodeDef* node, const GraphProperties& properties) {
727   if (node->input_size() < 2) {
728     return Status::OK();
729   }
730   const NodeDef* indices = node_map_->GetNode(node->input(1));
731   if (!indices || IsReallyConstant(*indices)) {
732     // The reduction indices are already constant, there's nothing to do.
733     return Status::OK();
734   }
735 
736   const std::vector<OpInfo::TensorProperties>& input_props =
737       properties.GetInputProperties(node->name());
738   if (input_props.size() != 2) {
739     return Status::OK();
740   }
741   const OpInfo::TensorProperties& input_prop = input_props[0];
742   if (input_prop.shape().unknown_rank()) {
743     // We can't do anything if we don't know the rank of the input.
744     return Status::OK();
745   }
746   const int input_rank = input_prop.shape().dim_size();
747   if (input_rank < 1) {
748     // Unexpected graph, don't try to change it.
749     return Status::OK();
750   }
751   const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
752   DataType dtype = reduction_indices_prop.dtype();
753   if (dtype != DT_INT32 && dtype != DT_INT64) {
754     return Status::OK();
755   }
756   PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
757   const int num_reduction_indices = reduction_indices_shape.num_elements();
758 
759   const std::vector<OpInfo::TensorProperties>& output_props =
760       properties.GetOutputProperties(node->name());
761   if (output_props.size() != 1) {
762     return Status::OK();
763   }
764   const OpInfo::TensorProperties& output_prop = output_props[0];
765   const int output_rank =
766       output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
767 
768   bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
769   if (!full_reduction) {
770     // A full reduction will generate a tensor of one of the shapes
771     // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
772     // elements in the output of the reduction, we may deduce it from reshape
773     // nodes following it.
774     for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
775       full_reduction = false;
776       if (!IsReshape(*fanout)) {
777         return Status::OK();
778       }
779       const std::vector<OpInfo::TensorProperties>& reshape_props =
780           properties.GetOutputProperties(fanout->name());
781       if (reshape_props.size() != 1) {
782         return Status::OK();
783       }
784       const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
785       PartialTensorShape shape(reshape_prop.shape());
786       if (shape.num_elements() != 1) {
787         return Status::OK();
788       } else {
789         full_reduction = true;
790       }
791     }
792     if (!full_reduction) {
793       return Status::OK();
794     }
795   }
796 
797   // We know it's a full reduction. We can generate the full set of indices to
798   // reduce as a constant node.
799   string const_name = OptimizedNodeName(*node, "-reduction_indices");
800   if (node_map_->GetNode(const_name)) {
801     return Status::OK();
802   }
803   NodeDef* reduction_indices = graph_->add_node();
804   Tensor value(dtype, TensorShape({input_rank}));
805   for (int i = 0; i < input_rank; ++i) {
806     if (dtype == DT_INT32) {
807       value.vec<int32>()(i) = i;
808     } else {
809       value.vec<int64>()(i) = i;
810     }
811   }
812   TF_RETURN_IF_ERROR(
813       CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
814 
815   reduction_indices->set_device(node->device());
816   string ctrl_dep =
817       AddControlDependency(node->input(1), graph_, node_map_.get());
818   *reduction_indices->add_input() = ctrl_dep;
819   node_map_->AddNode(const_name, reduction_indices);
820   node_map_->AddOutput(NodeName(ctrl_dep), const_name);
821 
822   node->set_input(1, reduction_indices->name());
823   node_map_->UpdateInput(node->name(), indices->name(),
824                          reduction_indices->name());
825 
826   return Status::OK();
827 }
828 
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)829 Status ConstantFolding::MaterializeConstantValuedNode(
830     NodeDef* node, const GraphProperties& properties) {
831   if (disable_compressed_tensor_optimization_) {
832     return Status::OK();
833   }
834   // Nodes that generate constant-valued outputs can be represented compactly in
835   // compressed format, regardless of their shape.
836   const std::vector<OpInfo::TensorProperties>& output_props =
837       properties.GetOutputProperties(node->name());
838   if (output_props.size() != 1) return Status::OK();
839   const auto& output_shape = output_props[0].shape();
840   if (!PartialTensorShape(output_shape).IsFullyDefined()) {
841     return Status::OK();
842   }
843   if (IsFill(*node)) {
844     const auto output_dtype = output_props[0].dtype();
845     NodeDef* input_node = nullptr;
846     for (int i = 0; i < 2; ++i) {
847       input_node = node_map_->GetNode(NodeName(node->input(i)));
848       if (input_node == nullptr || !IsReallyConstant(*input_node)) {
849         return Status::OK();
850       }
851     }
852     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
853 
854     // Copy the input tensor to the fill node, set the output shape and data
855     // type, and change the node type to Const.
856     TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
857     const TensorProto& input_tensor = input_node->attr().at("value").tensor();
858     if (!input_tensor.tensor_content().empty()) {
859       // Convert the value to repeated field format, so we can use the
860       // decompression mechanism to store only a single value in the constant
861       // node, even if the shape specified in the original Fill is large.
862       Tensor t;
863       if (!t.FromProto(input_tensor)) {
864         return errors::InvalidArgument(
865             "Could not construct Tensor form TensorProto in node: ",
866             input_node->name());
867       }
868       tensor->clear_tensor_content();
869       t.AsProtoField(tensor);
870     } else {
871       *tensor = input_tensor;
872     }
873     *(tensor->mutable_tensor_shape()) = output_shape;
874     (*node->mutable_attr())["dtype"].set_type(output_dtype);
875     node->mutable_attr()->erase("T");
876     node->mutable_attr()->erase("index_type");
877     node->set_op("Const");
878     for (int i = 0; i < 2; i++) {
879       // Change inputs to a control inputs.
880       const string ctrl_dep = AsControlDependency(node->input(i));
881       node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
882       node->set_input(i, ctrl_dep);
883     }
884     graph_modified_ = true;
885   } else {
886     double value =
887         (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
888     if (value >= 0) {
889       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
890           value, properties, output_shape, node, graph_));
891     }
892   }
893   return Status::OK();
894 }
895 
896 // Materialize output values inferred by the shape inference.
MaterializeOutputValues(NodeDef * node,const GraphProperties & properties)897 Status ConstantFolding::MaterializeOutputValues(
898     NodeDef* node, const GraphProperties& properties) {
899   const std::vector<OpInfo::TensorProperties>& output =
900       properties.GetOutputProperties(node->name());
901   if (output.size() != 1 || !output[0].has_value() ||
902       !IsFoldable(*node, &properties)) {
903     return Status::OK();
904   }
905 
906   // If this is a trivial Identity node with a constant input, just route the
907   // input around it.
908   if (IsIdentity(*node)) {
909     NodeDef* input = node_map_->GetNode(node->input(0));
910     if (IsReallyConstant(*input)) {
911       std::vector<int> inputs_to_forward;
912       std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
913       graph_modified_ = ForwardInputs(node, inputs_to_forward);
914       return Status::OK();
915     }
916   }
917   // Repurpose the existing node to be the constant.
918   // Device placement is preserved.
919   TensorProto value_copy = output[0].value();
920   return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
921                                             node, graph_);
922 }
923 
MaterializeConstants(const GraphProperties & properties)924 Status ConstantFolding::MaterializeConstants(
925     const GraphProperties& properties) {
926   const int node_count = graph_->node_size();
927   for (int i = 0; i < node_count; ++i) {
928     NodeDef& node = *graph_->mutable_node(i);
929     const string& op = node.op();
930     if (op == "BroadcastGradientArgs") {
931       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
932     } else if (IsReduction(node)) {
933       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
934     } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
935       TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
936     } else {
937       TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
938     }
939   }
940   return Status::OK();
941 }
942 
IsFoldable(const NodeDef & node,const GraphProperties * properties)943 bool ConstantFolding::IsFoldable(const NodeDef& node,
944                                  const GraphProperties* properties) {
945   string key = strings::StrCat(node.name(), "/", node.op());
946   auto it = maybe_foldable_nodes_.find(key);
947   if (it == maybe_foldable_nodes_.end()) {
948     it = maybe_foldable_nodes_
949              .emplace(std::move(key), MaybeFoldable(node, properties))
950              .first;
951   }
952   if (!it->second) {
953     return false;
954   } else {
955     return IsFoldableUncached(node, properties);
956   }
957 }
958 
IsFoldableUncached(const NodeDef & node,const GraphProperties * properties) const959 bool ConstantFolding::IsFoldableUncached(
960     const NodeDef& node, const GraphProperties* properties) const {
961   // Folding not applicable to ops with no inputs.
962   if (node.input().empty()) {
963     return false;
964   }
965   // We can only fold nodes if all their inputs are known statically, except in
966   // the case of a merge node that propagate the first inputs that becomes
967   // available, and therefore only requires a single constant input to be
968   // foldable.
969   bool merge_has_constant_input = false;
970   const bool is_merge = IsMerge(node);
971   for (const auto& input : node.input()) {
972     if (IsControlInput(input)) {
973       continue;
974     }
975     const NodeDef* input_node = node_map_->GetNode(input);
976     if (!input_node) {
977       return false;
978     }
979     bool is_const = IsReallyConstant(*input_node);
980     if (is_const) {
981       // Don't fold strings constants for now since this causes problems with
982       // checkpointing.
983       if (input_node->attr().count("dtype") == 0 ||
984           input_node->attr().at("dtype").type() == DT_STRING) {
985         return false;
986       }
987       // Special case: If a Merge node has at least one constant input that
988       // does not depend on a control input, we can fold it.
989       merge_has_constant_input |= !HasControlInputs(*input_node);
990     } else if (!is_merge) {
991       return false;
992     }
993   }
994   if (is_merge && !merge_has_constant_input) return false;
995   if (disable_compressed_tensor_optimization_ &&
996       (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
997     return false;
998 
999   // If we know the output shapes, make sure that the outputs are small enough
1000   // to materialize.
1001   if (properties != nullptr && properties->HasOutputProperties(node.name())) {
1002     const std::vector<OpInfo::TensorProperties>& input_props =
1003         properties->GetInputProperties(node.name());
1004     const std::vector<OpInfo::TensorProperties>& output_props =
1005         properties->GetOutputProperties(node.name());
1006     // Compute total size of inputs.
1007     int64 input_size_bytes = 0;
1008     for (const auto& input_prop : input_props) {
1009       const PartialTensorShape input_shape(input_prop.shape());
1010       if (input_shape.IsFullyDefined()) {
1011         input_size_bytes +=
1012             input_shape.num_elements() * DataTypeSize(input_prop.dtype());
1013       }
1014     }
1015     for (const auto& output_prop : output_props) {
1016       const PartialTensorShape output_shape(output_prop.shape());
1017       if (output_shape.IsFullyDefined()) {
1018         const int64 num_bytes =
1019             output_shape.num_elements() * DataTypeSize(output_prop.dtype());
1020         if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
1021           // Do not fold nodes if the in-memory size of output is too large.
1022           // Notice that this is not exactly the same check used in
1023           // CreateNodeDef() where the actual encoded size is checked.
1024           return false;
1025         }
1026       }
1027     }
1028   }
1029 
1030   return true;
1031 }
1032 
MaybeFoldable(const NodeDef & node,const GraphProperties * properties) const1033 bool ConstantFolding::MaybeFoldable(const NodeDef& node,
1034                                     const GraphProperties* properties) const {
1035   // Skip constants, they're already folded
1036   if (IsConstant(node)) {
1037     return false;
1038   }
1039   // Don't fold stateful ops such as TruncatedNormal.
1040   if (!IsFreeOfSideEffect(node)) {
1041     return false;
1042   }
1043 
1044   // Skips nodes that must be preserved except allowlisted nodes.
1045   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
1046       nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1047     return false;
1048   }
1049 
1050   // Skip control flow nodes, they can't be folded.
1051   if (ModifiesFrameInfo(node)) {
1052     return false;
1053   }
1054 
1055   // Skips ops that don't benefit from folding.
1056   if (IsPlaceholder(node)) {
1057     return false;
1058   }
1059   // `FakeParam` op is used as a placeholder in If branch function. It doesn't
1060   // have a valid output when executed.
1061   if (IsFakeParam(node)) {
1062     return false;
1063   }
1064 
1065   if (node.op() == "AccumulateNV2") {
1066     return false;
1067   }
1068   // Removing LoopCond nodes can screw up the partitioner.
1069   if (node.op() == "LoopCond") {
1070     return false;
1071   }
1072 
1073   if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
1074     return false;
1075   }
1076 
1077   const string& op = node.op();
1078   if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
1079       op.find("Reader") != string::npos) {
1080     return false;
1081   }
1082   if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
1083     return false;
1084   }
1085 
1086   // Don't fold nodes that contain TPU attributes.
1087   // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
1088   // properly forward custom attributes, b/119051778.
1089   if (HasTPUAttributes(node)) {
1090     return false;
1091   }
1092 
1093   const OpDef* op_def = nullptr;
1094   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
1095   if (!status.ok()) {
1096     return false;
1097   }
1098   // Don't fold ops without outputs.
1099   if (op_def->output_arg_size() == 0) {
1100     return false;
1101   }
1102   // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
1103   // TODO(rmlarsen): Only do this for XLA_* devices.
1104   for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
1105     if (output_arg.type() == DT_VARIANT) {
1106       return false;
1107     }
1108   }
1109 
1110   // Don't fold nodes that have no outgoing edges except allowlisted nodes.
1111   // Such nodes could be introduced by an earlier constant folding pass and are
1112   // preserved in case users want to fetch their values; re-processing them
1113   // would lead to an error of adding a duplicated node to graph.
1114   const auto& outputs = node_map_->GetOutputs(node.name());
1115   if (outputs.empty() &&
1116       nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1117     return false;
1118   }
1119   return true;
1120 }
1121 
1122 namespace {
1123 
1124 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME)     \
1125   case DTYPE:                                      \
1126     t->add_##NAME##_val(static_cast<TYPE>(value)); \
1127     break;
1128 
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)1129 Status CreateConstantTensorAttrValue(DataType type, double value,
1130                                      const TensorShapeProto& shape,
1131                                      AttrValue* attr_tensor) {
1132   TensorProto* t = attr_tensor->mutable_tensor();
1133   t->set_dtype(type);
1134   *t->mutable_tensor_shape() = shape;
1135   switch (type) {
1136     case DT_HALF:
1137       t->add_half_val(static_cast<Eigen::half>(value).x);
1138       break;
1139     case DT_BFLOAT16:
1140       t->add_half_val(static_cast<bfloat16>(value).value);
1141       break;
1142       SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
1143       SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
1144       SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
1145       SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
1146       SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
1147       SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
1148       SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
1149       SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
1150       SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
1151       SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
1152       SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
1153       SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
1154       SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
1155       SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
1156       SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1157       SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1158     default:
1159       return errors::InvalidArgument(
1160           "Unsupported type in CreateConstantTensorAttrValue: ",
1161           DataTypeString(type));
1162   }
1163   return Status::OK();
1164 }
1165 
1166 #undef SET_TENSOR_CAL_CASE
1167 
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1168 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1169                                     const GraphProperties& properties) {
1170   DataType dtype = DT_INVALID;
1171   if (node.attr().count("T") == 1) {
1172     dtype = node.attr().at("T").type();
1173   } else if (node.attr().count("dtype") == 1) {
1174     dtype = node.attr().at("dtype").type();
1175   } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1176     dtype = DT_BOOL;
1177   } else {
1178     auto output_props = properties.GetOutputProperties(node.name());
1179     if (!output_props.empty()) {
1180       dtype = output_props[0].dtype();
1181     }
1182   }
1183   return dtype;
1184 }
1185 
1186 // Checks whether the shape of the const input of the Mul op is valid to perform
1187 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1188 bool IsValidConstShapeForMulConvPushDown(
1189     const string& data_format, const TensorShapeProto& filter_shape,
1190     const TensorShapeProto& mul_const_input_shape) {
1191   // If the const is a scalar, or it has fewer or same number of dimensions
1192   // than the filter and it only has single element, the optimization should
1193   // work.
1194   if (mul_const_input_shape.dim_size() <=
1195           static_cast<int>(data_format.size()) &&
1196       TensorShape(mul_const_input_shape).num_elements() == 1) {
1197     return true;
1198   }
1199 
1200   // Otherwise, check the eligibility according to data format.
1201   if (data_format == "NHWC" || data_format == "NDHWC") {
1202     TensorShapeProto new_filter_shape;
1203     if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1204                              &new_filter_shape)) {
1205       return false;
1206     }
1207     if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1208       return false;
1209     }
1210     // Only the last dimension could be larger than one, since broadcasting over
1211     // the last dimension (the output channel) will result in invalid filter.
1212     for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1213       if (mul_const_input_shape.dim(i).size() > 1) return false;
1214     }
1215     return true;
1216   } else if (data_format == "NCHW" || data_format == "NCDHW") {
1217     // TODO(laigd): support NCHW and NCDHW (b/111214513).
1218     return false;
1219   }
1220   return false;
1221 }
1222 
1223 }  // namespace
1224 
1225 // static
CreateNodeDef(const string & name,const TensorValue & tensor,NodeDef * node,size_t original_size)1226 Status ConstantFolding::CreateNodeDef(const string& name,
1227                                       const TensorValue& tensor, NodeDef* node,
1228                                       size_t original_size) {
1229   node->set_name(name);
1230   node->set_op("Const");
1231 
1232   AttrValue attr_type;
1233   attr_type.set_type(tensor->dtype());
1234   node->mutable_attr()->insert({"dtype", attr_type});
1235 
1236   AttrValue attr_tensor;
1237   TensorProto* t = attr_tensor.mutable_tensor();
1238   bool optimized = false;
1239   size_t encoded_size;
1240   // Use the packed representation whenever possible to avoid generating large
1241   // graphdefs. Moreover, avoid repeating the last values if they're equal.
1242   if (tensor->NumElements() > 4) {
1243 #define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE)                      \
1244   {                                                                            \
1245     const auto* val_ptr = tensor->flat<TYPE>().data();                         \
1246     auto last = *val_ptr;                                                      \
1247     int64 last_index = 0;                                                      \
1248     for (int64 i = 0; i < tensor->NumElements(); ++i) {                        \
1249       TYPE cur = *val_ptr++;                                                   \
1250       if (PackedValuesNotEqual(cur, last)) {                                   \
1251         last = cur;                                                            \
1252         last_index = i;                                                        \
1253       }                                                                        \
1254     }                                                                          \
1255     encoded_size = (last_index + 1) * sizeof(FIELDTYPE);                       \
1256     if (encoded_size < kint32max) {                                            \
1257       optimized = true;                                                        \
1258       t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1);                 \
1259       const auto* src_ptr = tensor->flat<TYPE>().data();                       \
1260       auto* dst_ptr =                                                          \
1261           t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
1262       std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr);                   \
1263     }                                                                          \
1264   }                                                                            \
1265   break
1266 
1267     switch (tensor->dtype()) {
1268       case DT_FLOAT:
1269         POPULATE_TENSOR_PROTO(tensor, t, float, float);
1270       case DT_DOUBLE:
1271         POPULATE_TENSOR_PROTO(tensor, t, double, double);
1272       case DT_INT64:
1273         POPULATE_TENSOR_PROTO(tensor, t, int64, int64);
1274       case DT_UINT64:
1275         POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
1276       case DT_INT32:
1277         POPULATE_TENSOR_PROTO(tensor, t, int32, int);
1278       case DT_UINT32:
1279         POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
1280       case DT_INT16:
1281         POPULATE_TENSOR_PROTO(tensor, t, int16, int);
1282       case DT_UINT16:
1283         POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
1284       case DT_INT8:
1285         POPULATE_TENSOR_PROTO(tensor, t, int8, int);
1286       case DT_UINT8:
1287         POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
1288       case DT_BOOL:
1289         POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
1290       default:
1291         /* Do nothing. */
1292         break;
1293     }
1294   }
1295   if (optimized) {
1296     // Also specify type and shape.
1297     t->set_dtype(tensor->dtype());
1298     tensor->shape().AsProto(t->mutable_tensor_shape());
1299   } else {
1300     // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
1301     // DT_QUINT8
1302     tensor->AsProtoTensorContent(t);
1303     encoded_size = t->tensor_content().size();
1304   }
1305   node->mutable_attr()->insert({"value", attr_tensor});
1306 
1307   if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
1308     return errors::InvalidArgument(
1309         strings::StrCat("Can't fold ", name, ", its size would be too large (",
1310                         encoded_size, " >= ", kMaxConstantSize, " bytes)"));
1311   }
1312   return Status::OK();
1313 }
1314 
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1315 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1316                                      const TensorVector& inputs,
1317                                      TensorVector* output) const {
1318   return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1319                                               resource_mgr_.get(), output);
1320 }
1321 
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1322 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1323                                             std::vector<NodeDef>* outputs,
1324                                             bool* result_too_large) {
1325   TensorVector inputs;
1326   TensorVector output_tensors;
1327   auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1328     for (const auto& input : inputs) {
1329       delete input.tensor;
1330     }
1331     for (const auto& output : output_tensors) {
1332       if (output.tensor) {
1333         delete output.tensor;
1334       }
1335     }
1336   });
1337 
1338   size_t total_inputs_size = 0;
1339   for (const auto& input : node.input()) {
1340     const TensorId input_tensor = ParseTensorName(input);
1341     if (input_tensor.index() < 0) {
1342       // Control dependency
1343       break;
1344     }
1345     const NodeDef* input_node = node_map_->GetNode(input);
1346     if (!IsReallyConstant(*input_node)) {
1347       return Status(error::INVALID_ARGUMENT,
1348                     strings::StrCat("Can't fold ", node.name(), ", its ", input,
1349                                     " isn't constant"));
1350     }
1351     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1352     const TensorProto& raw_val = input_node->attr().at("value").tensor();
1353     Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1354     CHECK(value->FromProto(raw_val))
1355         << "Unable to make Tensor from proto for " << node.name()
1356         << " with shape " << raw_val.tensor_shape().DebugString();
1357     inputs.emplace_back(value);
1358     total_inputs_size += value->TotalBytes();
1359   }
1360 
1361   TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1362   if (output_tensors.empty()) {
1363     return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1364   }
1365 
1366   outputs->resize(output_tensors.size());
1367   for (size_t i = 0; i < output_tensors.size(); i++) {
1368     string node_name = OptimizedNodeName(node, "-folded");
1369     if (output_tensors.size() > 1) {
1370       node_name = strings::StrCat(node_name, "-", i);
1371     }
1372     if (output_tensors[i].tensor) {
1373       Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1374                                total_inputs_size);
1375       if (!s.ok()) {
1376         *result_too_large = true;
1377         return s;
1378       }
1379     } else {
1380       // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1381       // switch that's not selected by the switch predicate).
1382       outputs->at(i) = NodeDef();
1383     }
1384   }
1385   return Status::OK();
1386 }
1387 
FoldMergeNode(NodeDef * node,GraphDef * output_graph)1388 Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
1389   // Merge nodes are special, in the sense that they execute as soon as one of
1390   // their input is ready. We can therefore fold a merge node iff it has at
1391   // least one constant input without control dependency.
1392   // We still need to ensure that the nodes in the fanin of the merge node are
1393   // scheduled. We'll therefore add a control dependency from the merge node
1394   // to the folded constant. We end up with:
1395   //  * the merge node and its inputs are preserved as is
1396   //  * a new constant node C1, driven by the merge node through a control
1397   //  dependency, initialized to the value of the folded input
1398   //  * a new constant node C2, driven by the merge node through a control
1399   //  dependency, initialized to the index of the folded input
1400   //  * the fanout of the merge nodes is rewired to be driven by either C1 or
1401   //  C2.
1402   for (int input_index = 0; input_index < node->input_size(); ++input_index) {
1403     const auto& input = node->input(input_index);
1404     if (IsControlInput(input)) {
1405       // Try the next input.
1406       continue;
1407     }
1408     NodeDef* input_node = node_map_->GetNode(input);
1409     if (!IsReallyConstant(*input_node)) {
1410       continue;
1411     }
1412     bool valid_input = true;
1413     for (const string& fanin_of_input : input_node->input()) {
1414       if (IsControlInput(fanin_of_input)) {
1415         valid_input = false;
1416         break;
1417       }
1418     }
1419     if (!valid_input) {
1420       // Try the next input
1421       continue;
1422     }
1423 
1424     string const_out_name = OptimizedNodeName(*node, "_const");
1425     string const_index_name = OptimizedNodeName(*node, "_index");
1426     if (node_map_->GetNode(const_out_name) ||
1427         node_map_->GetNode(const_index_name)) {
1428       // Intended name already exists.
1429       return errors::AlreadyExists(
1430           strings::StrCat(const_out_name, " or ", const_index_name,
1431                           " already present in the graph"));
1432     }
1433 
1434     NodeDef* const_out = output_graph->add_node();
1435     *const_out = *input_node;
1436     const_out->set_name(const_out_name);
1437     const_out->set_device(node->device());
1438     *const_out->add_input() = AsControlDependency(*node);
1439     node_map_->AddNode(const_out->name(), const_out);
1440     node_map_->AddOutput(node->name(), const_out->name());
1441 
1442     NodeDef* const_index = output_graph->add_node();
1443     const_index->set_op("Const");
1444     Tensor index(DT_INT32, TensorShape({}));
1445     index.flat<int32>()(0) = input_index;
1446     (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
1447     index.AsProtoTensorContent(
1448         (*const_index->mutable_attr())["value"].mutable_tensor());
1449     const_index->set_name(const_index_name);
1450     const_index->set_device(node->device());
1451     *const_index->add_input() = AsControlDependency(*node);
1452     node_map_->AddNode(const_index->name(), const_index);
1453     node_map_->AddOutput(node->name(), const_index->name());
1454 
1455     // We make a copy because we mutate the nodes.
1456     auto outputs = node_map_->GetOutputs(node->name());
1457     for (NodeDef* output : outputs) {
1458       for (int i = 0; i < output->input_size(); i++) {
1459         int port;
1460         string node_name = ParseNodeName(output->input(i), &port);
1461         if (node_name == node->name()) {
1462           if (port == 0) {
1463             *output->mutable_input(i) = const_out->name();
1464             node_map_->AddOutput(const_out->name(), output->name());
1465           } else if (port == 1) {
1466             *output->mutable_input(i) = const_index->name();
1467             node_map_->AddOutput(const_index->name(), output->name());
1468           } else {
1469             // This is a control dependency (or an invalid edge since the
1470             // merge node has only 2 outputs): preserve them.
1471           }
1472         }
1473       }
1474     }
1475     return Status::OK();
1476   }
1477   return Status::OK();
1478 }
1479 
FoldNode(NodeDef * node,GraphDef * output_graph,bool * result_too_large)1480 Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
1481                                  bool* result_too_large) {
1482   *result_too_large = false;
1483   if (IsMerge(*node)) {
1484     return FoldMergeNode(node, output_graph);
1485   }
1486 
1487   std::vector<NodeDef> const_nodes;
1488   TF_RETURN_IF_ERROR(
1489       EvaluateOneFoldable(*node, &const_nodes, result_too_large));
1490   VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);
1491 
1492   NodeDef* constant_output = nullptr;
1493   for (int i = 0, end = const_nodes.size(); i < end; i++) {
1494     NodeDef* const_node = &const_nodes[i];
1495     VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
1496     if (const_node->name().empty()) {
1497       // Dead output: we can't create a constant to encode its value, so we'll
1498       // just skip it. We'll preserve the edges that originate from that
1499       // output below to preserve the overall behavior of the graph wrt dead
1500       // edges.
1501       continue;
1502     }
1503 
1504     // Returns `true` iff `const_node` already has control input named `input`.
1505     const auto is_duplicate_control_input = [&](const string& input) -> bool {
1506       auto it = absl::c_find(const_node->input(), input);
1507       return it != const_node->input().end();
1508     };
1509 
1510     // Forward control dependencies.
1511     for (const string& input : node->input()) {
1512       // Forward control dependencies from folded node.
1513       if (IsControlInput(input)) {
1514         if (!is_duplicate_control_input(input)) {
1515           *const_node->add_input() = input;
1516         }
1517       }
1518 
1519       // Forward control dependencies from constant inputs to folded node.
1520       if (!IsControlInput(input)) {
1521         NodeDef* input_node = node_map_->GetNode(input);
1522         for (const string& fanin_of_input : input_node->input()) {
1523           if (!is_duplicate_control_input(fanin_of_input)) {
1524             *const_node->add_input() = fanin_of_input;
1525           }
1526         }
1527       }
1528     }
1529 
1530     // We rewrite the existing node if it only has a single output, and
1531     // create new nodes otherwise.
1532     if (const_nodes.size() == 1) {
1533       node->set_op("Const");
1534       // Note we need to clear the inputs in NodeMap before we clear the inputs
1535       // in the node, otherwise NodeMap would see empty inputs and effectively
1536       // does nothing.
1537       node_map_->RemoveInputs(node->name());
1538       node->clear_input();
1539       *node->mutable_input() = const_node->input();
1540       for (const auto& input : node->input()) {
1541         node_map_->AddOutput(NodeName(input), node->name());
1542       }
1543       *node->mutable_attr() = const_node->attr();
1544       break;
1545     } else {
1546       if (node_map_->GetNode(const_node->name())) {
1547         // Intended name already exists.
1548         return errors::AlreadyExists(strings::StrCat(
1549             const_node->name(), " already present in the graph"));
1550       }
1551       NodeDef* added_node = output_graph->add_node();
1552       *added_node = *const_node;
1553       added_node->set_device(node->device());
1554       node_map_->AddNode(added_node->name(), added_node);
1555       for (const auto& input : added_node->input()) {
1556         node_map_->AddOutput(NodeName(input), added_node->name());
1557       }
1558       // All the constant nodes encoding output values have the same control
1559       // dependencies (since these are the control dependencies of the node
1560       // we're trying to fold). Record one such constant node.
1561       constant_output = added_node;
1562     }
1563   }
1564 
1565   if (const_nodes.size() > 1) {
1566     // We make a copy because we mutate the nodes.
1567     auto outputs = node_map_->GetOutputs(node->name());
1568     for (NodeDef* output : outputs) {
1569       for (int i = 0; i < output->input_size(); i++) {
1570         int port;
1571         string node_name = ParseNodeName(output->input(i), &port);
1572         if (node_name == node->name()) {
1573           if (port < 0) {
1574             // Propagate control dependencies if possible. If not, we'll just
1575             // preserve the existing control dependencies.
1576             if (constant_output != nullptr) {
1577               node_map_->UpdateInput(node_name, NodeName(output->input(i)),
1578                                      constant_output->name());
1579               *output->mutable_input(i) = AsControlDependency(*constant_output);
1580             }
1581           } else if (port < static_cast<int>(const_nodes.size()) &&
1582                      !const_nodes[port].name().empty()) {
1583             // Replace alive outputs with the corresponding constant.
1584             node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
1585                                    const_nodes[port].name());
1586             *output->mutable_input(i) = const_nodes[port].name();
1587           } else {
1588             // Leave this edge alone.
1589             VLOG(3) << "Preserving edge from " << node->name() << ":" << port
1590                     << "[" << node->op() << "] to " << output->name() << ":"
1591                     << i << "[" << output->op() << "]";
1592           }
1593         }
1594       }
1595     }
1596     outputs = node_map_->GetOutputs(node->name());
1597     if (outputs.empty() && has_fetch_ &&
1598         nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
1599       node_map_->RemoveInputs(node->name());
1600       node->clear_input();
1601     }
1602   }
1603   return Status::OK();
1604 }
1605 
FoldGraph(const GraphProperties & properties,GraphDef * output,absl::flat_hash_set<string> * nodes_to_not_simplify)1606 Status ConstantFolding::FoldGraph(
1607     const GraphProperties& properties, GraphDef* output,
1608     absl::flat_hash_set<string>* nodes_to_not_simplify) {
1609   std::unordered_set<string> processed_nodes;
1610   std::deque<NodeDef*> queue;
1611   for (int i = 0; i < graph_->node_size(); i++) {
1612     bool foldable = IsFoldable(graph_->node(i), &properties);
1613     VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
1614     if (foldable) {
1615       queue.push_back(graph_->mutable_node(i));
1616     }
1617   }
1618   while (!queue.empty()) {
1619     NodeDef* node = queue.front();
1620     queue.pop_front();
1621     if (processed_nodes.count(node->name())) {
1622       continue;
1623     }
1624     // We need to record a copy of output nodes before FoldNode() modifies it.
1625     // We also need to ensure that the fanout is sorted deterministically.
1626     std::vector<NodeDef*> fanout =
1627         node_map_->GetOutputsOrderedByNodeName(node->name());
1628     bool result_too_large = false;
1629     Status s = FoldNode(node, output, &result_too_large);
1630     processed_nodes.insert(node->name());
1631     if (!s.ok()) {
1632       VLOG(1) << "Failed to fold node " << node->DebugString()
1633               << "\nError message: " << s;
1634       if (result_too_large) {
1635         nodes_to_not_simplify->emplace(node->name());
1636       }
1637     } else {
1638       for (auto& output : fanout) {
1639         if (IsFoldable(*output, &properties)) {
1640           queue.push_back(output);
1641         }
1642       }
1643     }
1644   }
1645 
1646   // Delete the newly created nodes that don't feed anything.
1647   std::vector<int> nodes_to_delete;
1648   for (int i = 0; i < output->node_size(); i++) {
1649     const auto& fanout = node_map_->GetOutputs(output->node(i).name());
1650     if (fanout.empty()) nodes_to_delete.push_back(i);
1651   }
1652   EraseNodesFromGraph(std::move(nodes_to_delete), output);
1653 
1654   for (int i = 0; i < graph_->node_size(); ++i) {
1655     NodeDef* node = graph_->mutable_node(i);
1656     // If no fetch nodes is provided, we conservatively
1657     // move all nodes in the original graph to the output, in case users need
1658     // to fetch their values.
1659     const auto& fanout = node_map_->GetOutputs(node->name());
1660     if (!fanout.empty() || !has_fetch_ ||
1661         nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
1662       *(output->add_node()) = std::move(*node);
1663     }
1664   }
1665   return Status::OK();
1666 }
1667 
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1668 bool ConstantFolding::IsSimplifiableReshape(
1669     const NodeDef& node, const GraphProperties& properties) const {
1670   if (!IsReshape(node)) {
1671     return false;
1672   }
1673   CHECK_LE(2, node.input_size());
1674   const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1675   if (!IsReallyConstant(*new_shape)) {
1676     return false;
1677   }
1678   TensorVector outputs;
1679   auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1680     for (const auto& output : outputs) {
1681       delete output.tensor;
1682     }
1683   });
1684 
1685   Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1686   if (!s.ok()) {
1687     return false;
1688   }
1689   CHECK_EQ(1, outputs.size());
1690 
1691   const std::vector<OpInfo::TensorProperties>& props =
1692       properties.GetInputProperties(node.name());
1693   if (props.empty()) {
1694     return false;
1695   }
1696   const OpInfo::TensorProperties& prop = props[0];
1697   if (prop.dtype() == DT_INVALID) {
1698     return false;
1699   }
1700   const PartialTensorShape shape(prop.shape());
1701   if (!shape.IsFullyDefined()) {
1702     return false;
1703   }
1704 
1705   PartialTensorShape new_dims;
1706   if (outputs[0]->dtype() == DT_INT32) {
1707     std::vector<int32> shp;
1708     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1709       int32 dim = outputs[0]->flat<int32>()(i);
1710       shp.push_back(dim);
1711     }
1712     TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1713   } else {
1714     std::vector<int64> shp;
1715     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1716       int64 dim = outputs[0]->flat<int64>()(i);
1717       shp.push_back(dim);
1718     }
1719     TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1720   }
1721 
1722   return shape.IsCompatibleWith(new_dims);
1723 }
1724 
1725 #define IS_VALUE_CASE(DTYPE, VALUE)                   \
1726   case DTYPE:                                         \
1727     return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
1728         node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))
1729 
1730 #define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
1731 #define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
1732 
IsOnes(const NodeDef & node) const1733 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1734   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1735     return false;
1736   }
1737   if (IsOnesLike(node)) return true;
1738   if (IsZerosLike(node)) return false;
1739   if (node.op() == "Fill") {
1740     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1741     return values != nullptr && IsOnes(*values);
1742   }
1743   if (node.op() != "Const") return false;
1744   if (node.attr().count("dtype") == 0) return false;
1745   const auto dtype = node.attr().at("dtype").type();
1746   switch (dtype) {
1747     IS_ONES_CASE(DT_BOOL);
1748     IS_ONES_CASE(DT_HALF);
1749     IS_ONES_CASE(DT_BFLOAT16);
1750     IS_ONES_CASE(DT_FLOAT);
1751     IS_ONES_CASE(DT_DOUBLE);
1752     IS_ONES_CASE(DT_COMPLEX64);
1753     IS_ONES_CASE(DT_COMPLEX128);
1754     IS_ONES_CASE(DT_UINT8);
1755     IS_ONES_CASE(DT_INT8);
1756     IS_ONES_CASE(DT_UINT16);
1757     IS_ONES_CASE(DT_INT16);
1758     IS_ONES_CASE(DT_INT32);
1759     IS_ONES_CASE(DT_INT64);
1760     IS_ONES_CASE(DT_QINT32);
1761     IS_ONES_CASE(DT_QINT16);
1762     IS_ONES_CASE(DT_QUINT16);
1763     IS_ONES_CASE(DT_QINT8);
1764     IS_ONES_CASE(DT_QUINT8);
1765     default:
1766       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1767       return false;
1768   }
1769   return false;
1770 }
1771 
IsZeros(const NodeDef & node) const1772 bool ConstantFolding::IsZeros(const NodeDef& node) const {
1773   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1774     return false;
1775   }
1776   if (IsOnesLike(node)) return false;
1777   if (IsZerosLike(node)) return true;
1778   if (node.op() == "Fill") {
1779     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1780     return values != nullptr && IsZeros(*values);
1781   }
1782   if (!IsConstant(node)) return false;
1783   if (node.attr().count("dtype") == 0) return false;
1784   const auto dtype = node.attr().at("dtype").type();
1785   switch (dtype) {
1786     IS_ZEROS_CASE(DT_BOOL);
1787     IS_ZEROS_CASE(DT_HALF);
1788     IS_ZEROS_CASE(DT_BFLOAT16);
1789     IS_ZEROS_CASE(DT_FLOAT);
1790     IS_ZEROS_CASE(DT_DOUBLE);
1791     IS_ZEROS_CASE(DT_COMPLEX64);
1792     IS_ZEROS_CASE(DT_COMPLEX128);
1793     IS_ZEROS_CASE(DT_UINT8);
1794     IS_ZEROS_CASE(DT_INT8);
1795     IS_ZEROS_CASE(DT_UINT16);
1796     IS_ZEROS_CASE(DT_INT16);
1797     IS_ZEROS_CASE(DT_INT32);
1798     IS_ZEROS_CASE(DT_INT64);
1799     IS_ZEROS_CASE(DT_QINT32);
1800     IS_ZEROS_CASE(DT_QINT16);
1801     IS_ZEROS_CASE(DT_QUINT16);
1802     IS_ZEROS_CASE(DT_QINT8);
1803     IS_ZEROS_CASE(DT_QUINT8);
1804     default:
1805       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1806       return false;
1807   }
1808   return false;
1809 }
1810 
ReplaceOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1811 bool ConstantFolding::ReplaceOperationWithBroadcastTo(
1812     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1813     GraphDef* graph) {
1814   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1815   if (dtype == DT_INVALID) {
1816     return false;
1817   }
1818   const PartialTensorShape shape(
1819       properties.GetOutputProperties(node->name())[0].shape());
1820   if (!shape.IsFullyDefined()) {
1821     return false;
1822   }
1823   // Create constant node with shape.
1824   const string const_name = OptimizedNodeName(
1825       *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
1826   if (node_map_->GetNode(const_name) != nullptr) {
1827     return false;
1828   }
1829 
1830   Tensor shape_t;
1831   if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
1832     return false;
1833   }
1834   NodeDef tmp;
1835   if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
1836     return false;
1837   }
1838   NodeDef* const_node = graph->add_node();
1839   const_node->Swap(&tmp);
1840   const_node->set_device(node->device());
1841   node_map_->AddNode(const_name, const_node);
1842   for (int i = 0; i < node->input_size(); ++i) {
1843     if (i != input_to_broadcast) {
1844       // Add a control input on the unused input.
1845       string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
1846                                              node_map_.get());
1847       *const_node->add_input() = ctrl_dep;
1848       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
1849     }
1850   }
1851 
1852   // Rewrite `node` in-place to BroadcastTo.
1853   node->set_op("BroadcastTo");
1854   EraseRegularNodeAttributes(node);
1855   (*node->mutable_attr())["T"].set_type(dtype);
1856   (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
1857   // Set the designated input to BroadcastTo.
1858   node->mutable_input()->SwapElements(0, input_to_broadcast);
1859   // Keep all other inputs as control dependencies.
1860   for (int i = 1; i < node->input_size(); ++i) {
1861     if (IsControlInput(node->input(i))) {
1862       break;
1863     }
1864     const string ctrl_dep =
1865         AddControlDependency(node->input(i), graph, node_map_.get());
1866     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1867     node->set_input(i, ctrl_dep);
1868   }
1869   // Add the shape argument.
1870   *node->add_input() = const_node->name();
1871   node_map_->AddOutput(const_name, node->name());
1872   node->mutable_input()->SwapElements(1, node->input_size() - 1);
1873   return true;
1874 }
1875 
1876 // Replace an operation with Identity.
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1877 void ConstantFolding::ReplaceOperationWithIdentity(
1878     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1879     GraphDef* graph) {
1880   if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
1881   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1882   if (dtype == DT_INVALID) return;
1883 
1884   node->set_op("Identity");
1885   EraseRegularNodeAttributes(node);
1886   (*node->mutable_attr())["T"].set_type(dtype);
1887   // Propagate the designated input through the identity.
1888   node->mutable_input()->SwapElements(0, input_to_forward);
1889   // Add all other inputs as control dependencies.
1890   for (int i = 1; i < node->input_size(); ++i) {
1891     if (IsControlInput(node->input(i))) {
1892       break;
1893     }
1894     const string ctrl_dep =
1895         AddControlDependency(node->input(i), graph, node_map_.get());
1896     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1897     node->set_input(i, ctrl_dep);
1898   }
1899   graph_modified_ = true;
1900 }
1901 
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1902 void ConstantFolding::ReplaceOperationWithSnapshot(
1903     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1904     GraphDef* graph) {
1905   // If the graph contains no ops that mutate their inputs, we can
1906   // use Identity instead of Snapshot.
1907   if (!graph_contains_assign_or_inplace_op_) {
1908     ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1909     return;
1910   }
1911 
1912   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1913   if (dtype == DT_INVALID) return;
1914 
1915   node->set_op("Snapshot");
1916   EraseRegularNodeAttributes(node);
1917   (*node->mutable_attr())["T"].set_type(dtype);
1918   // Propagate the designated input through the Snapshot.
1919   node->mutable_input()->SwapElements(0, input_to_forward);
1920   // Add all other inputs as control dependencies.
1921   for (int i = 1; i < node->input_size(); ++i) {
1922     if (IsControlInput(node->input(i))) {
1923       break;
1924     }
1925     const string ctrl_dep =
1926         AddControlDependency(node->input(i), graph, node_map_.get());
1927     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1928     node->set_input(i, ctrl_dep);
1929   }
1930   graph_modified_ = true;
1931 }
1932 
1933 // Replace a node with NoOp. Change all inputs to control dependencies.
1934 // If the node has non-control outputs, no change will be performed.
ReplaceOperationWithNoOp(NodeDef * node,GraphProperties * properties,GraphDef * graph)1935 void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
1936                                                GraphProperties* properties,
1937                                                GraphDef* graph) {
1938   if (HasRegularOutputs(*node, *node_map_)) return;
1939   node->set_op("NoOp");
1940   EraseRegularNodeAttributes(node);
1941   EraseNodeOutputAttributes(node);
1942   // Erase attributes that describe output properties.
1943   properties->ClearOutputProperties(node->name());
1944   // Change all inputs to control dependencies.
1945   for (int i = 0; i < node->input_size(); ++i) {
1946     if (IsControlInput(node->input(i))) {
1947       break;
1948     }
1949     const string ctrl_dep =
1950         AddControlDependency(node->input(i), graph, node_map_.get());
1951     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1952     node->set_input(i, ctrl_dep);
1953   }
1954   DedupControlInputs(node);
1955   graph_modified_ = true;
1956 }
1957 
ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1958 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
1959     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1960     GraphDef* graph) {
1961   if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
1962                                        graph)) {
1963     return;
1964   }
1965   graph_modified_ = true;
1966 }
1967 
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)1968 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
1969                                                         GraphDef* graph) {
1970   node->set_op("Reciprocal");
1971   node->mutable_input()->SwapElements(0, 1);
1972   const string ctrl_dep =
1973       AddControlDependency(node->input(1), graph, node_map_.get());
1974   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1975   node->set_input(1, ctrl_dep);
1976   graph_modified_ = true;
1977 }
1978 
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)1979 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
1980                                                            GraphDef* graph) {
1981   node->set_op("Neg");
1982   node->mutable_input()->SwapElements(0, 1);
1983   const string ctrl_dep =
1984       AddControlDependency(node->input(1), graph, node_map_.get());
1985   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1986   node->set_input(1, ctrl_dep);
1987   graph_modified_ = true;
1988 }
1989 
ReplaceOperationWithConstantTensor(DataType dtype,TensorProto * value,NodeDef * node,GraphDef * graph)1990 Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
1991                                                            TensorProto* value,
1992                                                            NodeDef* node,
1993                                                            GraphDef* graph) {
1994   if (dtype == DT_VARIANT) return Status::OK();
1995   node->set_op("Const");
1996   EraseRegularNodeAttributes(node);
1997   (*node->mutable_attr())["dtype"].set_type(dtype);
1998   (*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
1999   // Convert all inputs to control dependencies.
2000   for (int i = 0; i < node->input_size(); ++i) {
2001     if (IsControlInput(node->input(i))) {
2002       break;
2003     }
2004     const string ctrl_dep =
2005         AddControlDependency(node->input(i), graph, node_map_.get());
2006     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2007     node->set_input(i, ctrl_dep);
2008   }
2009   DedupControlInputs(node);
2010   graph_modified_ = true;
2011   return Status::OK();
2012 }
2013 
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph)2014 Status ConstantFolding::ReplaceOperationWithConstant(
2015     double value, const GraphProperties& properties,
2016     const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
2017   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
2018   if (dtype == DT_VARIANT) return Status::OK();
2019   AttrValue tensor_attr;
2020   Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
2021   if (!s.ok()) {
2022     // Fail gracefully without mutating the graph.
2023     VLOG(1) << "Failed to replace node " << node->name() << " of type "
2024             << DataTypeString(dtype) << " with constant tensor of value "
2025             << value;
2026     return Status::OK();
2027   }
2028   return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
2029                                             node, graph);
2030 }
2031 
SimplifyGraph(bool use_shape_info,GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)2032 Status ConstantFolding::SimplifyGraph(
2033     bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties,
2034     absl::flat_hash_set<string>* nodes_to_not_simplify) {
2035   for (int i = 0; i < optimized_graph->node_size(); ++i) {
2036     NodeDef* node = optimized_graph->mutable_node(i);
2037     // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
2038     // generalize to only restrict certain simplifications.
2039     if (nodes_to_not_simplify->find(node->name()) ==
2040         nodes_to_not_simplify->end()) {
2041       if (HasTPUAttributes(*node)) {
2042         nodes_to_not_simplify->insert(node->name());
2043         continue;
2044       }
2045 
2046       TF_RETURN_IF_ERROR(
2047           SimplifyNode(use_shape_info, node, optimized_graph, properties));
2048     }
2049   }
2050   return Status::OK();
2051 }
2052 
2053 #define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
2054   TF_RETURN_IF_ERROR(EXPR);               \
2055   if (graph_modified_) return Status::OK()
2056 
2057 #define SET_AND_RETURN_IF_MODIFIED(EXPR) \
2058   graph_modified_ = EXPR;                \
2059   if (graph_modified_) return Status::OK()
2060 
2061 #define RETURN_IF_MODIFIED(EXPR) \
2062   EXPR;                          \
2063   if (graph_modified_) return Status::OK()
2064 
SimplifyNode(bool use_shape_info,NodeDef * node,GraphDef * optimized_graph,GraphProperties * properties)2065 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
2066                                      GraphDef* optimized_graph,
2067                                      GraphProperties* properties) {
2068   bool graph_modified_cached = graph_modified_;
2069   graph_modified_ = false;
2070 
2071   RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
2072   RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
2073       *properties, use_shape_info, optimized_graph, node));
2074   RETURN_IF_MODIFIED(
2075       RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
2076   RETURN_IF_ERROR_OR_MODIFIED(
2077       RemoveReverse(*properties, use_shape_info, optimized_graph, node));
2078   RETURN_IF_ERROR_OR_MODIFIED(
2079       SimplifySlice(*properties, use_shape_info, optimized_graph, node));
2080   RETURN_IF_ERROR_OR_MODIFIED(
2081       SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
2082   RETURN_IF_ERROR_OR_MODIFIED(
2083       SimplifyTile(*properties, use_shape_info, optimized_graph, node));
2084   RETURN_IF_ERROR_OR_MODIFIED(
2085       SimplifyPad(*properties, use_shape_info, optimized_graph, node));
2086   RETURN_IF_MODIFIED(
2087       SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
2088   SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
2089   SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
2090   SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
2091   SET_AND_RETURN_IF_MODIFIED(
2092       SimplifyReduction(optimized_graph, *properties, node));
2093   SET_AND_RETURN_IF_MODIFIED(
2094       SimplifyReshape(*properties, use_shape_info, node));
2095   RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
2096       *properties, use_shape_info, optimized_graph, node));
2097   SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
2098   SET_AND_RETURN_IF_MODIFIED(
2099       ConstantPushDown(properties, optimized_graph, node));
2100   SET_AND_RETURN_IF_MODIFIED(
2101       MulConvPushDown(optimized_graph, node, *properties));
2102   SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
2103   SET_AND_RETURN_IF_MODIFIED(
2104       PartialAssocOpConstFolding(optimized_graph, properties, node));
2105   SET_AND_RETURN_IF_MODIFIED(
2106       MergeConcat(use_shape_info, properties, optimized_graph, node));
2107   SET_AND_RETURN_IF_MODIFIED(
2108       PartialConcatConstFolding(optimized_graph, properties, node));
2109   SET_AND_RETURN_IF_MODIFIED(
2110       ConstantPushDownBiasAdd(properties, optimized_graph, node));
2111   SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
2112   SET_AND_RETURN_IF_MODIFIED(
2113       SimplifySelect(*properties, optimized_graph, node));
2114   RETURN_IF_MODIFIED(
2115       RemoveRedundantVariableUpdates(properties, optimized_graph, node));
2116 
2117   graph_modified_ = graph_modified_cached;
2118   return Status::OK();
2119 }
2120 
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2121 void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
2122                                           GraphDef* optimized_graph,
2123                                           NodeDef* node) {
2124   if (node->attr().count("num_split") == 0) return;
2125   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
2126     ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
2127   }
2128   if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
2129     ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2130   }
2131 }
2132 
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2133 Status ConstantFolding::RemoveShuffleOrTranspose(
2134     const GraphProperties& properties, bool use_shape_info,
2135     GraphDef* optimized_graph, NodeDef* node) {
2136   if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
2137     return Status::OK();
2138   Tensor permutation_tensor;
2139   if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
2140       properties.HasInputProperties(node->name())) {
2141     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2142     std::vector<int> permutation;
2143     for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
2144       if (permutation_tensor.dtype() == DT_INT64) {
2145         permutation.push_back(permutation_tensor.vec<int64>()(j));
2146       } else {
2147         permutation.push_back(permutation_tensor.vec<int>()(j));
2148       }
2149     }
2150     int permutation_size = permutation.size();
2151     if (permutation_size != shape.dim_size()) {
2152       // Number of elements in perm should be same as dim_size. Skip if not.
2153       return Status::OK();
2154     }
2155     // The node is replaceable iff
2156     // dim_size == 0 || all dims have size 1 ||
2157     // all dims with > 1 size are not permuted.
2158     bool replaceable = true;
2159     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2160       replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
2161     }
2162     if (replaceable) {
2163       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2164     }
2165   }
2166   return Status::OK();
2167 }
2168 
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2169 void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
2170                                           bool use_shape_info,
2171                                           GraphDef* optimized_graph,
2172                                           NodeDef* node) {
2173   if (use_shape_info && IsRandomShuffle(*node) &&
2174       !properties.GetInputProperties(node->name()).empty()) {
2175     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2176     // The node is replaceable iff
2177     // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
2178     if (!shape.unknown_rank() &&
2179         (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
2180       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2181     }
2182   }
2183 }
2184 
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2185 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
2186                                       bool use_shape_info,
2187                                       GraphDef* optimized_graph,
2188                                       NodeDef* node) {
2189   if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
2190   Tensor axis;
2191   if (properties.HasInputProperties(node->name()) &&
2192       GetTensorFromConstNode(node->input(1), &axis)) {
2193     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2194     if (shape.unknown_rank()) return Status::OK();
2195     std::set<int> target_axes;
2196     for (int j = 0; j < axis.NumElements(); ++j) {
2197       // value of axis can be negative.
2198       if (axis.dtype() == DT_INT64) {
2199         target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
2200                            shape.dim_size());
2201       } else {
2202         target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2203                            shape.dim_size());
2204       }
2205     }
2206 
2207     // The node is replaceable iff
2208     // unknown_rank == false &&
2209     // (dim_size == 0 || all dims have size 1 ||
2210     //  all dims with > 1 size are not in target_axes)
2211     bool replaceable = true;
2212     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2213       replaceable &=
2214           shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
2215     }
2216     if (replaceable) {
2217       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2218     }
2219   }
2220   return Status::OK();
2221 }
2222 
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2223 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2224                                       bool use_shape_info,
2225                                       GraphDef* optimized_graph,
2226                                       NodeDef* node) {
2227   if (!use_shape_info || !IsSlice(*node)) return Status::OK();
2228   Tensor begin;
2229   Tensor size;
2230   if (properties.HasInputProperties(node->name()) &&
2231       GetTensorFromConstNode(node->input(1), &begin) &&
2232       GetTensorFromConstNode(node->input(2), &size)) {
2233     const auto& input = properties.GetInputProperties(node->name())[0];
2234     // The node is replaceable iff unknown_rank == false &&
2235     // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2236     bool replaceable = !input.shape().unknown_rank();
2237     for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2238       if (begin.dtype() == DT_INT32) {
2239         replaceable &= begin.vec<int>()(j) == 0;
2240       } else {
2241         replaceable &= begin.vec<int64>()(j) == 0;
2242       }
2243       if (size.dtype() == DT_INT32) {
2244         replaceable &= (size.vec<int>()(j) == -1 ||
2245                         size.vec<int>()(j) == input.shape().dim(j).size());
2246       } else {
2247         replaceable &= (size.vec<int64>()(j) == -1 ||
2248                         size.vec<int64>()(j) == input.shape().dim(j).size());
2249       }
2250     }
2251     if (replaceable) {
2252       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2253     }
2254   }
2255   return Status::OK();
2256 }
2257 
SimplifyStridedSlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2258 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
2259                                              bool use_shape_info,
2260                                              GraphDef* optimized_graph,
2261                                              NodeDef* node) {
2262   if (use_shape_info && IsStridedSlice(*node) &&
2263       properties.GetInputProperties(node->name()).size() == 4) {
2264     TF_RETURN_IF_ERROR(
2265         CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
2266     if (node->attr().at("new_axis_mask").i() != 0 ||
2267         node->attr().at("shrink_axis_mask").i() != 0) {
2268       // Skip nodes with new/shrink axis mask, since they involve dimension
2269       // changes.
2270       return Status::OK();
2271     }
2272     const auto& input = properties.GetInputProperties(node->name())[0];
2273     for (int j = 0; j < input.shape().dim_size(); ++j) {
2274       // Skip if input shape is not fully determined.
2275       if (input.shape().dim(j).size() < 0) {
2276         return Status::OK();
2277       }
2278     }
2279 
2280     std::vector<Tensor> input_tensors(3);
2281     for (int i = 1; i < 4; ++i) {
2282       if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
2283         return Status::OK();
2284       }
2285     }
2286 
2287     const Tensor& begin = input_tensors[0];
2288     const Tensor& end = input_tensors[1];
2289     const Tensor& strides = input_tensors[2];
2290 
2291     TF_RETURN_IF_ERROR(
2292         CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
2293     int begin_mask = node->attr().at("begin_mask").i();
2294     int end_mask = node->attr().at("end_mask").i();
2295     std::set<int> expanded_ellipsis_indices;
2296     int ellipsis_index = -1;
2297     for (int j = 0; j < input.shape().dim_size(); ++j) {
2298       // find the ellipsis_mask. If not found, insert one in the end if
2299       // necessary.
2300       if (node->attr().at("ellipsis_mask").i() & 1 << j ||
2301           (ellipsis_index == -1 && j >= strides.NumElements())) {
2302         ellipsis_index = j;
2303       }
2304       // insert the indices that are immediately after ellipsis_index if
2305       // necessary.
2306       if (ellipsis_index != -1 &&
2307           input.shape().dim_size() >
2308               strides.NumElements() + j - ellipsis_index) {
2309         expanded_ellipsis_indices.insert(j);
2310       }
2311     }
2312 
2313     // The node is replaceable iff unknown_rank == false &&
2314     // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
2315     //  && strides == 1) for all dimensions.
2316     bool replaceable = !input.shape().unknown_rank();
2317     for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2318       if (expanded_ellipsis_indices.find(j) !=
2319           expanded_ellipsis_indices.end()) {
2320         // ellipsis_mask is effective on current dimension.
2321         continue;
2322       }
2323       // when we have ellipsis_mask in between, input.shape().dim_size() will
2324       // be greater than strides.NumElements(), since we will insert
2325       // as many as expanded_ellipsis_indices.size() axes during computation.
2326       // We need to subtract this number from j.
2327       int i = j;
2328       int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
2329       if (ellipsis_index != -1 &&
2330           j >= ellipsis_index + expanded_ellipsis_indices_size) {
2331         i = j - expanded_ellipsis_indices_size;
2332       }
2333       int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
2334                                         : begin.vec<int64>()(i);
2335       int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
2336       int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
2337                                           : strides.vec<int64>()(i);
2338       replaceable &= (begin_mask & 1 << i || b == 0) &&
2339                      (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
2340                      s == 1;
2341     }
2342     if (replaceable) {
2343       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2344     }
2345   }
2346   return Status::OK();
2347 }
2348 
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2349 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2350                                      bool use_shape_info,
2351                                      GraphDef* optimized_graph, NodeDef* node) {
2352   Tensor multiplies;
2353   if (use_shape_info && IsTile(*node) &&
2354       GetTensorFromConstNode(node->input(1), &multiplies)) {
2355     // The node is replaceable iff all values in multiplies are 1.
2356     bool replaceable = true;
2357     if (multiplies.dtype() == DT_INT32) {
2358       for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2359         replaceable &= multiplies.vec<int>()(j) == 1;
2360       }
2361     } else {
2362       for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
2363         replaceable &= multiplies.vec<int64>()(j) == 1;
2364       }
2365     }
2366     if (replaceable) {
2367       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2368     }
2369   }
2370   return Status::OK();
2371 }
2372 
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2373 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2374                                     bool use_shape_info,
2375                                     GraphDef* optimized_graph, NodeDef* node) {
2376   if (!use_shape_info || !IsPad(*node)) return Status::OK();
2377 
2378   Tensor paddings;
2379   if (GetTensorFromConstNode(node->input(1), &paddings)) {
2380     // The node is replaceable iff all values in paddings are 0.
2381     bool replaceable = true;
2382     if (paddings.dtype() == DT_INT32) {
2383       const auto flatten = paddings.flat<int32>();
2384       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2385         replaceable &= flatten(j) == 0;
2386       }
2387     } else {
2388       const auto flatten = paddings.flat<int64>();
2389       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2390         replaceable &= flatten(j) == 0;
2391       }
2392     }
2393     if (replaceable) {
2394       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2395     }
2396   }
2397   return Status::OK();
2398 }
2399 
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2400 void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2401                                       bool use_shape_info,
2402                                       GraphDef* optimized_graph,
2403                                       NodeDef* node) {
2404   if (use_shape_info && IsSqueeze(*node) &&
2405       !properties.GetInputProperties(node->name()).empty()) {
2406     // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2407     // error to squeeze a dimension that is not 1, so we only need to check
2408     // whether the input has > 1 size for each dimension.
2409     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2410     // The node is replaceable iff
2411     // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2412     bool replaceable = !shape.unknown_rank();
2413     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2414       replaceable &= shape.dim(j).size() > 1;
2415     }
2416     if (replaceable) {
2417       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2418     }
2419   }
2420 }
2421 
SimplifyPack(GraphDef * optimized_graph,NodeDef * node)2422 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
2423   const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
2424   if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
2425       node_map_->NodeExists(axis_node_name)) {
2426     return false;
2427   }
2428 
2429   // It's unsafe to add a control dependency on the feed node, because it might
2430   // have been never executed otherwiwise.
2431   if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
2432     return false;
2433   }
2434 
2435   // Create constant axis node.
2436   Tensor axis_t(DT_INT32, TensorShape({}));
2437   const int axis =
2438       node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
2439   NodeDef new_node;
2440   if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
2441       !CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
2442     return false;
2443   }
2444   NodeDef* axis_node = optimized_graph->add_node();
2445   *axis_node = std::move(new_node);
2446   axis_node->set_name(axis_node_name);
2447   node_map_->AddNode(axis_node->name(), axis_node);
2448   // Add a control dependency to make sure axis_node is in the right frame.
2449   const string ctrl_dep = ConstantFolding::AddControlDependency(
2450       node->input(0), optimized_graph, node_map_.get());
2451   axis_node->add_input(ctrl_dep);
2452   axis_node->set_device(node->device());
2453   node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
2454   node->set_op("ExpandDims");
2455   if (node->attr().count("axis") != 0) {
2456     node->mutable_attr()->erase("axis");
2457   }
2458   if (node->attr().count("N") != 0) {
2459     node->mutable_attr()->erase("N");
2460   }
2461   (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
2462   node->add_input(axis_node->name());
2463   node_map_->AddOutput(axis_node->name(), node->name());
2464   if (node->input_size() > 2) {
2465     node->mutable_input()->SwapElements(1, node->input_size() - 1);
2466   }
2467   return true;
2468 }
2469 
SimplifyCase(GraphDef * optimized_graph,NodeDef * node)2470 bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
2471   if (node->op() != "Case") return false;
2472   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
2473   if (output_idx_node == nullptr ||
2474       !CheckAttrExists(*output_idx_node, "value").ok()) {
2475     return false;
2476   }
2477   Tensor output_idx_t;
2478   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
2479     return false;
2480   int output_idx = output_idx_t.scalar<int>()();
2481   const auto& func_list = node->attr().at("branches").list();
2482   if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
2483   NodeDef call_node = *node;
2484   call_node.set_op("PartitionedCall");
2485   call_node.clear_input();
2486   for (int i = 1; i < node->input_size(); ++i) {
2487     call_node.add_input(node->input(i));
2488   }
2489   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
2490   *new_func = func_list.func(output_idx);
2491 
2492   // Move the output shape of the branch to _output_shapes if it is known.
2493   const auto& output_shape_list =
2494       (*node->mutable_attr())["output_shapes"].list();
2495   if (output_shape_list.shape_size() > output_idx) {
2496     TensorShapeProto* new_output_shape =
2497         (*call_node.mutable_attr())["_output_shapes"]
2498             .mutable_list()
2499             ->add_shape();
2500     *new_output_shape =
2501         std::move(node->attr().at("output_shapes").list().shape(output_idx));
2502   }
2503 
2504   call_node.mutable_attr()->erase("output_shapes");
2505   call_node.mutable_attr()->erase("branches");
2506 
2507   *node = std::move(call_node);
2508   return true;
2509 }
2510 
SimplifySelect(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2511 bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
2512                                      GraphDef* optimized_graph, NodeDef* node) {
2513   if (!IsSelect(*node)) return false;
2514   const std::vector<OpInfo::TensorProperties>& input_props =
2515       properties.GetInputProperties(node->name());
2516   if (input_props.size() < 3) return false;
2517   const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
2518   const bool is_all_true = IsOnes(*predicate_node);
2519   const bool is_all_false = IsZeros(*predicate_node);
2520   if (!is_all_true && !is_all_false) {
2521     return false;
2522   }
2523   const int live_input_idx = is_all_true ? 1 : 2;
2524   const int ignored_input_idx = is_all_true ? 2 : 1;
2525   const TensorShapeProto& predicate_shape = input_props[0].shape();
2526   const bool predicate_is_scalar =
2527       !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
2528   if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
2529       (ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
2530        predicate_is_scalar)) {
2531     // Replace node with Identity if no broadcasting is involved.
2532     node->set_op("Identity");
2533     *node->mutable_input(0) =
2534         AddControlDependency(node->input(0), optimized_graph, node_map_.get());
2535     *node->mutable_input(ignored_input_idx) = AddControlDependency(
2536         node->input(ignored_input_idx), optimized_graph, node_map_.get());
2537     node->mutable_input()->SwapElements(0, live_input_idx);
2538   } else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
2539                                               optimized_graph)) {
2540     return false;
2541   }
2542   DedupControlInputs(node);
2543   return true;
2544 }
2545 
RemoveRedundantVariableUpdates(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)2546 void ConstantFolding::RemoveRedundantVariableUpdates(
2547     GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
2548   static const absl::flat_hash_set<string>* kVariableReadOps =
2549       new absl::flat_hash_set<string>{"AssignAddVariableOp",
2550                                       "AssignSubVariableOp",
2551                                       "AssignAdd",
2552                                       "AssignSub",
2553                                       "ScatterAdd",
2554                                       "ScatterSub",
2555                                       "ScatterMul",
2556                                       "ScatterDiv",
2557                                       "ScatterNdAdd",
2558                                       "ScatterNdSub",
2559                                       "ScatterNdMul",
2560                                       "ScatterNdDiv",
2561                                       "ResourceScatterAdd",
2562                                       "ResourceScatterSub",
2563                                       "ResourceScatterMul",
2564                                       "ResourceScatterDiv",
2565                                       "ResourceScatterNdAdd",
2566                                       "ResourceScatterNdSub",
2567                                       "ResourceScatterNdMul",
2568                                       "ResourceScatterNdDiv"};
2569   if (kVariableReadOps == nullptr ||
2570       kVariableReadOps->find(node->op()) == kVariableReadOps->end())
2571     return;
2572   const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
2573   const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
2574   if (delta_node == nullptr) return;
2575   const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
2576                              absl::StrContains(node->op(), "Sub");
2577   if ((is_add_or_sub && IsZeros(*delta_node)) ||
2578       (!is_add_or_sub && IsOnes(*delta_node))) {
2579     VLOG(1) << "Removing redundant variable update: " << node->DebugString();
2580     if (absl::StrContains(node->op(), "Variable") ||
2581         absl::StrContains(node->op(), "Resource")) {
2582       ReplaceOperationWithNoOp(node, properties, optimized_graph);
2583     } else {
2584       ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
2585                                    optimized_graph);
2586     }
2587   }
2588 }
2589 
MoveConstantsPastEnter(GraphDef * optimized_graph,NodeDef * node)2590 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
2591                                              NodeDef* node) {
2592   if (!IsEnter(*node) || node->input_size() == 0 ||
2593       node->attr().count("is_constant") == 0 ||
2594       !node->attr().at("is_constant").b()) {
2595     return false;
2596   }
2597   const string& node_name = node->name();
2598   const NodeDef* input = node_map_->GetNode(node->input(0));
2599   if (input == nullptr || !IsReallyConstant(*input) ||
2600       OptimizedNodeExists(*input, "_enter")) {
2601     return false;
2602   }
2603   // Find non-constant nodes that consume the output of *node.
2604   std::vector<NodeDef*> consumers;
2605   for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
2606     if (!IsConstant(*fanout)) {
2607       for (int i = 0; i < fanout->input_size(); ++i) {
2608         if (fanout->input(i) == node_name) {
2609           consumers.push_back(const_cast<NodeDef*>(fanout));
2610           break;
2611         }
2612       }
2613     }
2614   }
2615   if (consumers.empty()) {
2616     return false;
2617   }
2618   graph_modified_ = true;
2619   NodeDef* new_node = optimized_graph->add_node();
2620   *new_node = *input;
2621   new_node->set_name(OptimizedNodeName(*input, "_enter"));
2622   new_node->set_device(node->device());
2623   new_node->clear_input();
2624   new_node->add_input(AsControlDependency(node_name));
2625   node_map_->AddNode(new_node->name(), new_node);
2626   node_map_->AddOutput(node_name, new_node->name());
2627   for (NodeDef* consumer : consumers) {
2628     for (int i = 0; i < consumer->input_size(); ++i) {
2629       if (NodeName(consumer->input(i)) == node_name) {
2630         node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
2631         consumer->set_input(i, new_node->name());
2632       }
2633     }
2634   }
2635   return true;
2636 }
2637 
SimplifySwitch(GraphDef * optimized_graph,NodeDef * node)2638 bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
2639   if (node->op() == "Switch" && node->input(0) == node->input(1) &&
2640       !OptimizedNodeExists(*node, "_const_false") &&
2641       !OptimizedNodeExists(*node, "_const_true")) {
2642     bool already_optimized = true;
2643     // If the optimization was already applied, the switch would have exactly
2644     // one Identity node consuming each of its outputs, each without any
2645     // non-control outputs.
2646     const auto& fanouts = node_map_->GetOutputs(node->name());
2647     if (fanouts.size() == 2) {
2648       for (const NodeDef* fanout : fanouts) {
2649         if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
2650             HasRegularOutputs(*fanout, *node_map_)) {
2651           already_optimized = false;
2652           break;
2653         }
2654       }
2655     }
2656     Tensor false_t(DT_BOOL, TensorShape({}));
2657     Tensor true_t(DT_BOOL, TensorShape({}));
2658     // Make sure we don't proceed if this switch node was already optimized.
2659     if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
2660         SetTensorValue(DT_BOOL, false, &false_t).ok()) {
2661       // Copy the set of consumers of the switch as they will be manipulated
2662       // below.
2663       std::vector<NodeDef*> consumers =
2664           node_map_->GetOutputsOrderedByNodeName(node->name());
2665       // Create constant false & true nodes.
2666       NodeDef tmp_false_node;
2667       tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
2668       if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
2669                          &tmp_false_node)
2670                .ok()) {
2671         return false;
2672       }
2673       tmp_false_node.set_device(node->device());
2674       NodeDef tmp_true_node;
2675       tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
2676       if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
2677                          &tmp_true_node)
2678                .ok()) {
2679         return false;
2680       }
2681       tmp_true_node.set_device(node->device());
2682 
2683       // Add const nodes to graph.
2684       NodeDef* false_node = optimized_graph->add_node();
2685       false_node->Swap(&tmp_false_node);
2686       NodeDef* true_node = optimized_graph->add_node();
2687       true_node->Swap(&tmp_true_node);
2688 
2689       // Add controls from the switch ports to the constants, and connect the
2690       // constants to the original switch outputs.
2691       const string false_port = node->name();
2692       const string true_port = strings::StrCat(node->name(), ":1");
2693       const string false_ctrl_dep =
2694           AddControlDependency(false_port, optimized_graph, node_map_.get());
2695       false_node->add_input(false_ctrl_dep);
2696       const string true_ctrl_dep =
2697           AddControlDependency(true_port, optimized_graph, node_map_.get());
2698       true_node->add_input(true_ctrl_dep);
2699 
2700       node_map_->AddNode(false_node->name(), false_node);
2701       node_map_->AddNode(true_node->name(), true_node);
2702       node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
2703       node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
2704 
2705       for (NodeDef* consumer : consumers) {
2706         for (int i = 0; i < consumer->input_size(); ++i) {
2707           const string& input = consumer->input(i);
2708           if (input == false_port) {
2709             consumer->set_input(i, false_node->name());
2710             node_map_->UpdateInput(consumer->name(), false_port,
2711                                    false_node->name());
2712           } else if (input == true_port) {
2713             consumer->set_input(i, true_node->name());
2714             node_map_->UpdateInput(consumer->name(), true_port,
2715                                    true_node->name());
2716           }
2717         }
2718       }
2719       return true;
2720     }
2721   }
2722   return false;
2723 }
2724 
IsReductionWithConstantIndices(const NodeDef & node,bool * indices_is_empty) const2725 bool ConstantFolding::IsReductionWithConstantIndices(
2726     const NodeDef& node, bool* indices_is_empty) const {
2727   // Ensure its an appropriate Reduce node.
2728   if (!IsReduction(node) || node.input_size() < 2) {
2729     return false;
2730   }
2731   // Ensure that the axes to reduce by are constant.
2732   NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2733   if (!IsReallyConstant(*reductions_indices) ||
2734       !reductions_indices->attr().count("value")) {
2735     return false;
2736   }
2737   const TensorShapeProto& reduction_indices_shape =
2738       reductions_indices->attr().at("value").tensor().tensor_shape();
2739   *indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
2740   return true;
2741 }
2742 
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2743 bool ConstantFolding::IsReductionCandidateForSimplification(
2744     const NodeDef& node, const GraphProperties& properties,
2745     TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2746     bool* is_single_element_op) const {
2747   // Get the properties of the input & output tensors and check if they both
2748   // contain a single element.
2749   if (!properties.HasInputProperties(node.name()) ||
2750       !properties.HasOutputProperties(node.name())) {
2751     return false;
2752   }
2753   const auto& input_props = properties.GetInputProperties(node.name())[0];
2754   const auto& output_props = properties.GetOutputProperties(node.name())[0];
2755   if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2756       !output_props.has_shape() || output_props.shape().unknown_rank()) {
2757     return false;
2758   }
2759   *input_tensor_shape = input_props.shape();
2760   *output_tensor_shape = output_props.shape();
2761   for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2762     if (input_tensor_shape->dim(i).size() < 0) {
2763       return false;
2764     }
2765   }
2766   for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2767     if (output_tensor_shape->dim(i).size() < 0) {
2768       return false;
2769     }
2770   }
2771   const int input_num_elements =
2772       TensorShape(*input_tensor_shape).num_elements();
2773   const int output_num_elements =
2774       TensorShape(*output_tensor_shape).num_elements();
2775   *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2776 
2777   return true;
2778 }
2779 
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2780 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2781     const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2782     const TensorVector& reduction_indices_vector) const {
2783   int output_size = reduction_indices_vector[0]->NumElements();
2784   if (output_size == 0) {
2785     return true;
2786   }
2787 
2788   if (!keep_dims) {
2789     return false;
2790   }
2791   bool simplifiable = true;
2792   for (int i = 0; i < output_size; ++i) {
2793     int64 dim;
2794     if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2795       dim = reduction_indices_vector[0]->flat<int32>()(i);
2796     } else {
2797       dim = reduction_indices_vector[0]->flat<int64>()(i);
2798     }
2799     if (dim < 0) {
2800       dim += input_shape.dim_size();
2801     }
2802     if (dim < 0 || dim >= input_shape.dim_size() ||
2803         input_shape.dim(dim).size() != 1) {
2804       simplifiable = false;
2805       break;
2806     }
2807   }
2808   return simplifiable;
2809 }
2810 
ReplaceReductionWithIdentity(NodeDef * node) const2811 bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
2812   // Replace the reduction node with an identity node, that can be further
2813   // optimized by other passes.
2814   DataType output_type;
2815   if (node->attr().count("T") != 0) {
2816     output_type = node->attr().at("T").type();
2817   } else if (IsAny(*node) || IsAll(*node)) {
2818     output_type = DT_BOOL;
2819   } else {
2820     return false;
2821   }
2822   node->set_op("Identity");
2823   EraseRegularNodeAttributes(node);
2824   (*node->mutable_attr())["T"].set_type(output_type);
2825   *node->mutable_input(1) = AsControlDependency(node->input(1));
2826   return true;
2827 }
2828 
SimplifyReduction(GraphDef * optimized_graph,const GraphProperties & properties,NodeDef * node)2829 bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
2830                                         const GraphProperties& properties,
2831                                         NodeDef* node) {
2832   bool indices_is_empty = false;
2833   if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
2834     return false;
2835   }
2836   if (indices_is_empty) {
2837     return ReplaceReductionWithIdentity(node);
2838   }
2839   bool is_single_element_op = false;
2840   TensorShapeProto input_tensor_shape, output_tensor_shape;
2841   if (!IsReductionCandidateForSimplification(
2842           *node, properties, &input_tensor_shape, &output_tensor_shape,
2843           &is_single_element_op)) {
2844     return false;
2845   }
2846 
2847   // Get the reduction indices.
2848   string reduction_indices_input = node->input(1);
2849   NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
2850   TensorVector reduction_indices_vector;
2851   auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
2852     for (const auto& out : reduction_indices_vector) {
2853       delete out.tensor;
2854     }
2855   });
2856   if (!EvaluateNode(*reduction_indices, TensorVector(),
2857                     &reduction_indices_vector)
2858            .ok() ||
2859       reduction_indices_vector.size() != 1) {
2860     return false;
2861   }
2862 
2863   bool keep_dims =
2864       node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
2865   bool simplifiable_to_reshape =
2866       is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
2867   bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
2868       *node, input_tensor_shape, keep_dims, reduction_indices_vector);
2869 
2870   if (simplifiable_to_reshape) {
2871     // Const node to output shape.
2872     const int new_num_dimensions = output_tensor_shape.dim_size();
2873     Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
2874     for (int i = 0; i < new_num_dimensions; i++) {
2875       tensor.flat<int>()(i) = 1;
2876     }
2877     TensorValue shape_value(&tensor);
2878     NodeDef* shape_node = optimized_graph->add_node();
2879     if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
2880                        shape_node)
2881              .ok()) {
2882       return false;
2883     }
2884     shape_node->set_device(node->device());
2885     node_map_->AddNode(shape_node->name(), shape_node);
2886     // Control dependency to ensure shape_node is in the correct frame.
2887     shape_node->add_input(AsControlDependency(reduction_indices_input));
2888     node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
2889     // Optimize node to Reshape.
2890     node->set_op("Reshape");
2891     node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
2892     node->set_input(1, shape_node->name());
2893     node->mutable_attr()->erase("keep_dims");
2894     node->mutable_attr()->erase("Tidx");
2895     AttrValue attr_type_indices;
2896     attr_type_indices.set_type(DT_INT32);
2897     (*node->mutable_attr())["Tshape"] = attr_type_indices;
2898     return true;
2899   } else if (simplifiable_to_identity) {
2900     return ReplaceReductionWithIdentity(node);
2901   }
2902   return false;
2903 }
2904 
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2905 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2906                                       bool use_shape_info, NodeDef* node) {
2907   if (!use_shape_info || node->attr().count("T") == 0 ||
2908       !IsSimplifiableReshape(*node, properties)) {
2909     return false;
2910   }
2911   DataType output_type = node->attr().at("T").type();
2912   node->set_op("Identity");
2913   EraseRegularNodeAttributes(node);
2914   (*node->mutable_attr())["T"].set_type(output_type);
2915   *node->mutable_input(1) = AsControlDependency(node->input(1));
2916   return true;
2917 }
2918 
SimplifyArithmeticOperations(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2919 Status ConstantFolding::SimplifyArithmeticOperations(
2920     const GraphProperties& properties, bool use_shape_info,
2921     GraphDef* optimized_graph, NodeDef* node) {
2922   const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
2923   const bool is_matmul = IsAnyMatMul(*node);
2924   const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
2925   const bool is_sub = IsSub(*node);
2926   const bool is_any_div = IsAnyDiv(*node);
2927   // Simplify arithmetic operations with ones or zeros.
2928   if (use_shape_info &&
2929       (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
2930       properties.HasInputProperties(node->name()) &&
2931       properties.HasOutputProperties(node->name())) {
2932     const NodeDef* x = node_map_->GetNode(node->input(0));
2933     const NodeDef* y = node_map_->GetNode(node->input(1));
2934     if (x == nullptr || y == nullptr) {
2935       return errors::InvalidArgument("Invalid inputs to node: ",
2936                                      node->DebugString());
2937     }
2938     const TensorShapeProto& output_shape =
2939         properties.GetOutputProperties(node->name())[0].shape();
2940 
2941     // Simplify element-wise multiplication by ones or addition/subtraction
2942     // of zeros.
2943     const TensorShapeProto& y_shape =
2944         properties.GetInputProperties(node->name())[1].shape();
2945     const TensorShapeProto& x_shape =
2946         properties.GetInputProperties(node->name())[0].shape();
2947     const bool y_matches_output_shape =
2948         ShapesSymbolicallyEqual(output_shape, y_shape);
2949     const bool x_matches_output_shape =
2950         ShapesSymbolicallyEqual(output_shape, x_shape);
2951 
2952     const bool x_is_zero = IsZeros(*x);
2953     const bool x_is_one = x_is_zero ? false : IsOnes(*x);
2954     if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
2955       // 1 * y = y or 0 + y = y.
2956       if (y_matches_output_shape) {
2957         ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
2958       } else if (x_matches_output_shape) {
2959         ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
2960                                               optimized_graph);
2961       }
2962       return Status::OK();
2963     }
2964 
2965     if (y_matches_output_shape && (is_sub && x_is_zero)) {
2966       // Replace 0 - y with Neg(y).
2967       ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
2968       return Status::OK();
2969     }
2970 
2971     // Replace 1 / y with Reciprocal op.
2972     if (y_matches_output_shape && is_any_div && x_is_one) {
2973       TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
2974       DataType type = node->attr().at("T").type();
2975       if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
2976         ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
2977         return Status::OK();
2978       }
2979     }
2980 
2981     const bool y_is_zero = IsZeros(*y);
2982     const bool y_is_one = y_is_zero ? false : IsOnes(*y);
2983     if (((is_mul || is_any_div) && y_is_one) ||
2984         ((is_add || is_sub) && y_is_zero)) {
2985       // x * 1 = x or x / 1 = x or x +/- 0 = x
2986       if (x_matches_output_shape) {
2987         ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
2988       } else if (y_matches_output_shape) {
2989         ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
2990                                               optimized_graph);
2991       }
2992       return Status::OK();
2993     }
2994 
2995     // x OR true = true OR y = true.
2996     const PartialTensorShape shp(output_shape);
2997     if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
2998       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
2999           1, properties, output_shape, node, optimized_graph));
3000       return Status::OK();
3001     }
3002 
3003     // Simplify multiplication and matmul by zeros.
3004     // Also optimize zeros divided by a tensor, but only if we are in
3005     // aggressive mode, since we might get rid of divisions by zero.
3006     const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
3007     bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
3008     if ((x_is_zero || y_is_zero) &&
3009         (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
3010       if (shp.IsFullyDefined()) {
3011         bool is_quantized = IsQuantizedMatMul(*node);
3012         TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
3013             0, properties, output_shape, node, optimized_graph));
3014         if (is_quantized && graph_modified_) {
3015           TF_RETURN_IF_ERROR(
3016               AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
3017         }
3018         return Status::OK();
3019       }
3020       // Even if an input shape is only partially known, we may known that it
3021       // matches the output shape and thus forward or broadcast the
3022       // corresponding zero input.
3023       if ((is_mul || is_any_div) && x_is_zero) {
3024         if (x_matches_output_shape) {
3025           ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
3026         } else if (y_matches_output_shape) {
3027           ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
3028                                                 optimized_graph);
3029         }
3030         return Status::OK();
3031       } else if (is_mul && y_is_zero) {
3032         if (y_matches_output_shape) {
3033           ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
3034         } else if (x_matches_output_shape) {
3035           ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
3036                                                 optimized_graph);
3037         }
3038         return Status::OK();
3039       }
3040     }
3041   }
3042   return Status::OK();
3043 }
3044 
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)3045 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
3046                                                NodeDef* node) {
3047   // Strength reduce floating point division by a constant Div(x, const) to
3048   // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
3049   // will be constant folded to Mul(x, 1.0/const).
3050   if (node->input_size() >= 2 &&
3051       (IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
3052     const string& const_input = node->input(1);
3053     const NodeDef* denom = node_map_->GetNode(const_input);
3054     CHECK(denom != nullptr);
3055     if (!IsReallyConstant(*denom)) {
3056       return false;
3057     }
3058     if (node->attr().count("T") == 0) {
3059       return false;
3060     }
3061     DataType type = node->attr().at("T").type();
3062     // Skip integer division.
3063     if (IsDiv(*node) &&
3064         !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
3065       return false;
3066     }
3067     // Insert new reciprocal op and change node from Div to Mul.
3068     NodeDef* reciprocal_node = optimized_graph->add_node();
3069     reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
3070     reciprocal_node->set_op("Reciprocal");
3071     reciprocal_node->set_device(node->device());
3072     reciprocal_node->add_input(const_input);
3073     (*reciprocal_node->mutable_attr())["T"].set_type(type);
3074 
3075     // Re-wire inputs and outputs.
3076     if (IsXdivy(*node)) {
3077       node->set_op("MulNoNan");
3078       node->set_input(1, node->input(0));
3079       node->set_input(0, reciprocal_node->name());
3080     } else {
3081       node->set_op("Mul");
3082       node->set_input(1, reciprocal_node->name());
3083     }
3084     node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
3085     node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
3086 
3087     return true;
3088   }
3089 
3090   return false;
3091 }
3092 
PrepareConstantPushDown(const NodeDef & parent,const GraphProperties & properties,bool must_have_properties,ConstantPushDownContext * ctx) const3093 bool ConstantFolding::PrepareConstantPushDown(
3094     const NodeDef& parent, const GraphProperties& properties,
3095     bool must_have_properties, ConstantPushDownContext* ctx) const {
3096   if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
3097     return false;
3098   }
3099   NodeDef* left_child = node_map_->GetNode(parent.input(0));
3100   NodeDef* right_child = node_map_->GetNode(parent.input(1));
3101   ctx->left_child_is_const = IsReallyConstant(*left_child);
3102   ctx->right_child_is_const = IsReallyConstant(*right_child);
3103   ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
3104   ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
3105 
3106   // Nothing to do unless the parent has a constant child node.
3107   if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
3108     return false;
3109   }
3110 
3111   // Don't move nodes across devices.
3112   if (parent.device() != ctx->op_child->device() ||
3113       parent.device() != ctx->const_child->device()) {
3114     return false;
3115   }
3116 
3117   // Make sure that it is safe to change the value of the child node result.
3118   if (ctx->op_child->input_size() < 2 ||
3119       nodes_to_preserve_.find(ctx->op_child->name()) !=
3120           nodes_to_preserve_.end() ||
3121       NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
3122     return false;
3123   }
3124 
3125   // Don't apply reassociation to floating point types of low precision.
3126   // The danger of significant numerical changes is too high.
3127   if (!CheckAttrExists(parent, "T").ok()) return false;
3128   DataType dtype = parent.attr().at("T").type();
3129   if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
3130     return false;
3131   }
3132 
3133   // Don't rewrite the tree if it might create cycles.
3134   // TODO(rmlarsen): Add back handling of control dependency from op to C.
3135   const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
3136   if (child_output.find(ctx->const_child) != child_output.end()) {
3137     return false;
3138   }
3139 
3140   // Get leaf nodes.
3141   ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
3142   ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
3143   ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
3144   ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
3145 
3146   if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
3147     // Child is already foldable, leave it alone.
3148     return false;
3149   }
3150 
3151   // Don't move nodes across devices.
3152   if (parent.device() != ctx->left_leaf->device() ||
3153       parent.device() != ctx->right_leaf->device()) {
3154     return false;
3155   }
3156 
3157   // Get shape and type information.
3158   ctx->parent_input_props = &properties.GetInputProperties(parent.name());
3159   ctx->op_child_input_props =
3160       &properties.GetInputProperties(ctx->op_child->name());
3161   if (must_have_properties && (ctx->parent_input_props == nullptr ||
3162                                ctx->parent_input_props->size() < 2 ||
3163                                ctx->op_child_input_props == nullptr ||
3164                                ctx->op_child_input_props->size() < 2)) {
3165     return false;
3166   }
3167 
3168   VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
3169           << parent.op() << "(" << left_child->op() << ", " << right_child->op()
3170           << ")";
3171 
3172   return true;
3173 }
3174 
ConstantPushDownBiasAdd(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3175 bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
3176                                               GraphDef* optimized_graph,
3177                                               NodeDef* node) {
3178   // This implements constant push-down for BiasAdd. In the following "CV" is a
3179   // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
3180   // vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3181   // non-constant matrix, and "BA" is BiasAdd.
3182   // For a valid input graph, the following 4 rewrites are legal:
3183   //
3184   //  1)                  +                +
3185   //                     / \              / \
3186   //                    BA  CV    -- >   BA  V
3187   //                   / \              / \
3188   //                  M   V            M   CV
3189   //
3190   //  2)                  +                +
3191   //                     / \              / \
3192   //                    BA  CM    -- >   BA  M
3193   //                   / \              / \
3194   //                  M   V            CM  V
3195   //
3196   //  3)                  BA               BA
3197   //                     / \              / \
3198   //                    +  CV     -- >   +   V
3199   //                   / \              / \
3200   //                  M   V            M  CV
3201   //
3202   //  4)                  BA               BA      = parent
3203   //                     / \              / \
3204   //                    BA  CV    -- >   BA  V     = children
3205   //                   / \              / \
3206   //                  M   V            M  CV       = leaves
3207   //
3208   // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3209 
3210   const bool parent_is_bias_add = IsBiasAdd(*node);
3211   if (!parent_is_bias_add && !IsAdd(*node)) return false;
3212   ConstantPushDownContext ctx;
3213   if (!PrepareConstantPushDown(*node, *properties,
3214                                /*must_have_properties=*/true, &ctx)) {
3215     return false;
3216   }
3217   // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3218   // >= 2 and the leaves must be vectors, we cannot swap them.
3219   if (ctx.left_child_is_const && parent_is_bias_add) return false;
3220   const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
3221   if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
3222 
3223   // Get properties to validate rank and dtype constraints.
3224   if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
3225       (*ctx.parent_input_props)[0].shape().unknown_rank() ||
3226       (*ctx.parent_input_props)[1].shape().unknown_rank() ||
3227       (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
3228       (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
3229     return false;
3230   }
3231 
3232   // Now get the ranks and types of the 3 leaf nodes.
3233   const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
3234   const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
3235   // At least one leaf must be a vector.
3236   if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
3237   const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3238   const int matrix_idx = 1 - vector_idx;
3239 
3240   const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
3241   const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
3242   if (vector_rank != 1) return false;  // this should never happen.
3243   const DataType vector_type = vector_prop.dtype();
3244 
3245   const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
3246   const int matrix_rank = matrix_prop.shape().dim_size();
3247   const DataType matrix_type = matrix_prop.dtype();
3248 
3249   const int const_idx = ctx.left_child_is_const ? 0 : 1;
3250   const auto& const_prop = (*ctx.parent_input_props)[const_idx];
3251   const int const_rank = const_prop.shape().dim_size();
3252   const DataType const_type = const_prop.dtype();
3253 
3254   int input_to_swap = -1;
3255 
3256   if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
3257       const_type == matrix_type) {
3258     // Case 2:
3259     input_to_swap = matrix_idx;
3260   } else if (const_rank == 1 && const_type == vector_type) {
3261     // Case 1, 3, and, 4:
3262     input_to_swap = vector_idx;
3263   }
3264   if (input_to_swap == -1) return false;
3265   const NodeDef* leaf_to_swap =
3266       node_map_->GetNode(ctx.op_child->input(input_to_swap));
3267   if (IsConstant(*leaf_to_swap)) return false;
3268 
3269   node_map_->UpdateInput(node->name(), node->input(const_idx),
3270                          ctx.op_child->input(input_to_swap));
3271   node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
3272   if (ctx.op_child->input(input_to_swap) !=
3273       ctx.op_child->input(1 - input_to_swap)) {
3274     node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
3275                             ctx.op_child->name());
3276   }
3277   std::swap(*node->mutable_input(const_idx),
3278             *ctx.op_child->mutable_input(input_to_swap));
3279   properties->ClearInputProperties(node->name());
3280   properties->ClearInputProperties(ctx.op_child->name());
3281 
3282   return true;
3283 }
3284 
ConstantPushDown(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3285 bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
3286                                        GraphDef* optimized_graph,
3287                                        NodeDef* node) {
3288   // Consider the transformation
3289   //
3290   //                      +                +       = parent
3291   //                     / \              / \
3292   //                    C   +    -- >    X   +     = children
3293   //                       / \              / \
3294   //                      X   Y            C   Y   = leaves
3295   //
3296   // where C is constant, X is non-constant, Y may be constant or non-constant,
3297   // and '+' denotes an associative and commutative operator like addition or
3298   // multiplication. This optimization pushes constants down in the tree to
3299   // canonicalize it. Moreover, in cases where the child node has a second
3300   // constant input Y we will create a leaf node that can be folded, e.g.
3301   //
3302   //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
3303   //
3304   // We also handle the non-commutative cases of subtraction and division
3305   // by rotating the tree locally, e.g.
3306   //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
3307   //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
3308 
3309   // Get parent op type.
3310   const bool is_add = IsAdd(*node);
3311   const bool is_mul = IsMul(*node);
3312   const bool is_sub = IsSub(*node);
3313   const bool is_div = IsDiv(*node);
3314   if (!(is_add || is_sub || is_mul || is_div)) return false;
3315   const bool is_symmetric = is_add || is_mul;
3316 
3317   ConstantPushDownContext ctx;
3318   if (!PrepareConstantPushDown(*node, *properties,
3319                                /*must_have_properties=*/false, &ctx)) {
3320     return false;
3321   }
3322 
3323   // Get child op type.
3324   const bool is_child_add = IsAdd(*ctx.op_child);
3325   const bool is_child_mul = IsMul(*ctx.op_child);
3326   const bool is_child_sub = IsSub(*ctx.op_child);
3327   const bool is_child_div = IsDiv(*ctx.op_child);
3328   const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
3329   const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
3330   if (!is_add_sub && !is_mul_div) {
3331     return false;
3332   }
3333   const bool is_child_symmetric = is_child_add || is_child_mul;
3334 
3335   if (!CheckAttrExists(*node, "T").ok()) return false;
3336   DataType dtype = node->attr().at("T").type();
3337   if (!(is_symmetric && is_child_symmetric) &&
3338       !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
3339     return false;
3340   }
3341 
3342   const NodeDef* y_node =
3343       ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
3344   if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
3345       !ctx.op_child_input_props->empty()) {
3346     // If we know the shapes of the nodes being swapped, make sure we don't push
3347     // down a larger node and create more work by broadcasting earlier in the
3348     // expressions tree.
3349     const PartialTensorShape c_shape(
3350         (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
3351     const PartialTensorShape x_shape(
3352         (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
3353 
3354     if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
3355         c_shape.num_elements() > x_shape.num_elements()) {
3356       return false;
3357     } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
3358                c_shape.dims() > 0) {
3359       for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
3360         if (x_shape.dim_size(idx) >= 0 &&
3361             c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
3362           return false;
3363         }
3364       }
3365     }
3366   }
3367 
3368   // Get the node names corresponding to X, Y, and C.
3369   const string input_x =
3370       ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
3371   const string input_y = input_x == ctx.op_child->input(0)
3372                              ? ctx.op_child->input(1)
3373                              : ctx.op_child->input(0);
3374   const string input_c =
3375       ctx.left_child_is_const ? node->input(0) : node->input(1);
3376   const string input_op =
3377       ctx.left_child_is_const ? node->input(1) : node->input(0);
3378   VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
3379 
3380   // Now we have identified the nodes to swap, update the nodemap accordingly.
3381   node_map_->UpdateInput(node->name(), input_c, input_x);
3382   node_map_->AddOutput(input_c, ctx.op_child->name());
3383   if (input_x != input_y) {
3384     node_map_->RemoveOutput(input_x, ctx.op_child->name());
3385   }
3386   properties->ClearInputProperties(node->name());
3387   properties->ClearInputProperties(ctx.op_child->name());
3388 
3389   if (is_symmetric && is_child_symmetric) {
3390     // Easy case (only commutative ops). We always write this as one of
3391     //   +
3392     //  / \
3393     // X   +
3394     //    / \
3395     //   C   Y
3396     node->set_input(0, input_x);
3397     node->set_input(1, input_op);
3398     ctx.op_child->set_input(0, input_c);
3399     ctx.op_child->set_input(1, input_y);
3400   } else {
3401     // More complicated case: When there are non-commutative operations like
3402     // subtractions or divisions involved, we may have to rotate the tree
3403     // and/or change op types. There are 6 non-trivial cases depending on
3404     // the effective generalized "sign" of each of the three terms C, Y, and X.
3405     // Here are the final trees we want to generate for those 6 cases:
3406     //
3407     // (CYX signs):   ++-      +--      -+-    --+     +-+      -++
3408     //
3409     //                 -        -        -      -       +        +
3410     //                / \      / \      / \    / \     / \      / \
3411     //               +   X    -   X    -   X  X   +   X   -    X   -
3412     //              / \      / \      / \        / \     / \      / \
3413     //             C   Y    C   Y    Y   C      Y   C   C   Y    Y   C
3414     //
3415 
3416     // First, let's determine the effective sign of each term in the original
3417     // expression
3418     auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
3419       bool leaf_negated = !is_child_symmetric && is_right_leaf;
3420       bool child_negated = !is_symmetric && (ctx.left_child_is_const);
3421       return leaf_negated != child_negated;
3422     };
3423     const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
3424     const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
3425     bool neg_c = !is_symmetric && !ctx.left_child_is_const;
3426     bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
3427     bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
3428     // Rewrite the parent node.
3429     node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
3430     node->set_input(0, neg_x ? input_op : input_x);
3431     node->set_input(1, neg_x ? input_x : input_op);
3432     // Rewrite the child node.
3433     ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
3434     ctx.op_child->set_input(0, neg_c ? input_y : input_c);
3435     ctx.op_child->set_input(1, neg_c ? input_c : input_y);
3436   }
3437   return true;
3438 }
3439 
MulConvPushDown(GraphDef * optimized_graph,NodeDef * node,const GraphProperties & properties)3440 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
3441                                       const GraphProperties& properties) {
3442   // Push down multiplication on ConvND.
3443   //                       *                  ConvND
3444   //                     /   \                /    \
3445   //                 ConvND  C2    -- >      X      *
3446   //                  / \                          / \
3447   //                 X  C1                       C1  C2
3448   //
3449   // where C1 and C2 are constants and X is non-constant.
3450   //
3451   // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
3452 
3453   if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
3454 
3455   NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
3456   NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
3457   // One child must be constant, and the second must be Conv op.
3458   const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
3459   const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
3460   if (!left_child_is_constant && !right_child_is_constant) {
3461     return false;
3462   }
3463   NodeDef* conv_node =
3464       left_child_is_constant ? mul_right_child : mul_left_child;
3465   if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
3466     return false;
3467   }
3468   if (node->device() != mul_left_child->device() ||
3469       node->device() != mul_right_child->device()) {
3470     return false;
3471   }
3472 
3473   // Make sure that it is safe to change the value of the convolution
3474   // output.
3475   if (conv_node->input_size() < 2 ||
3476       NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
3477       nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
3478     return false;
3479   }
3480 
3481   // Identify the nodes to swap.
3482   NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
3483   NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
3484   const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
3485   const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
3486   if (!conv_left_is_constant && !conv_right_is_constant) {
3487     // At least one of the convolution inputs should be constant.
3488     return false;
3489   }
3490   if (conv_left_is_constant && conv_right_is_constant) {
3491     // Leverage regular constant folding to handle this.
3492     return false;
3493   }
3494   const auto& mul_props = properties.GetOutputProperties(node->name());
3495   const auto& conv_props = properties.GetOutputProperties(conv_node->name());
3496   if (mul_props.empty() || conv_props.empty()) {
3497     return false;
3498   }
3499   const auto& mul_shape = mul_props[0].shape();
3500   const auto& conv_shape = conv_props[0].shape();
3501   if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
3502     return false;
3503   }
3504 
3505   const auto& input_props = properties.GetInputProperties(conv_node->name());
3506   if (input_props.size() < 2) {
3507     return false;
3508   }
3509   const auto& filter_shape = input_props[1].shape();
3510 
3511   NodeDef* const_node =
3512       left_child_is_constant ? mul_left_child : mul_right_child;
3513   const auto& const_props = properties.GetOutputProperties(const_node->name());
3514   if (const_props.empty()) {
3515     return false;
3516   }
3517   const auto& const_shape = const_props[0].shape();
3518   if (!IsValidConstShapeForMulConvPushDown(
3519           conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
3520     return false;
3521   }
3522 
3523   string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
3524   if (node_map_->NodeExists(mul_new_name)) {
3525     return false;
3526   }
3527   // Make sure we don't introduce loops in the graph by removing control
3528   // dependencies from the conv2d node to c2.
3529   string conv_const_input =
3530       conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
3531   if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
3532                               node_map_.get())) {
3533     // Add a control dep from c1 to c2 to ensure c2 is in the right frame
3534     MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
3535                          node_map_.get());
3536   }
3537 
3538   conv_node->set_name(node->name());
3539   node->set_name(mul_new_name);
3540   if (conv_left_is_constant) {
3541     node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
3542     conv_node->set_input(0, mul_new_name);
3543   } else {
3544     node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
3545     conv_node->set_input(1, mul_new_name);
3546   }
3547   NodeDef* conv_const_node =
3548       conv_left_is_constant ? conv_left_child : conv_right_child;
3549   if (left_child_is_constant) {
3550     node->set_input(1, conv_const_node->name());
3551   } else {
3552     node->set_input(0, conv_const_node->name());
3553   }
3554   node_map_->AddNode(mul_new_name, node);
3555 
3556   return true;
3557 }
3558 
PartialConstPropThroughIdentityN(NodeDef * node)3559 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3560   // Partial constant propagation through IdentityN.
3561   if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
3562       !HasRegularInputs(*node))
3563     return false;
3564 
3565   std::vector<int> inputs_to_forward;
3566   for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3567     const string& input = node->input(input_idx);
3568     if (IsControlInput(input)) {
3569       return false;
3570     }
3571     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3572     if (input_node == nullptr) {
3573       LOG(ERROR) << "Bad input: " << input;
3574       return false;
3575     }
3576     // Forward constant inputs to outputs and add a control dependency on
3577     // the IdentityN node.
3578     if (IsReallyConstant(*input_node)) {
3579       inputs_to_forward.push_back(input_idx);
3580     }
3581   }
3582   return ForwardInputs(node, inputs_to_forward);
3583 }
3584 
PartialAssocOpConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3585 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
3586                                                  GraphProperties* properties,
3587                                                  NodeDef* node) {
3588   // Partial constant folding for associative operators:
3589   // Split AddN/AccumulateNV2 to enable partial
3590   // folding of ops when more than one but not all inputs are constant.
3591   // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
3592   // addition is commutative.
3593   if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
3594 
3595   const int num_non_control_inputs = NumNonControlInputs(*node);
3596   if (num_non_control_inputs <= 2) return false;
3597   const int num_control_inputs = node->input_size() - num_non_control_inputs;
3598   std::vector<int> const_inputs;
3599   std::vector<int> nonconst_inputs;
3600   for (int i = 0; i < node->input_size(); ++i) {
3601     const string& input = node->input(i);
3602     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3603     if (input_node == nullptr) return false;
3604     if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
3605       const_inputs.push_back(i);
3606     } else {
3607       // Non-const and control inputs.
3608       nonconst_inputs.push_back(i);
3609     }
3610   }
3611   // Promote AccumulateNV2 with all constant inputs to AddN, since it is
3612   // a fake node that cannot be constant folded by itself.
3613   int const_inputs_size = const_inputs.size();
3614   if (const_inputs_size == num_non_control_inputs &&
3615       node->op() == "AccumulateNV2") {
3616     node->set_op("AddN");
3617     node->mutable_attr()->erase("shape");
3618     return true;
3619   }
3620   const string new_node_name = OptimizedNodeName(
3621       *node, strings::StrCat("_partial_split_", const_inputs_size));
3622   if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
3623       !node_map_->NodeExists(new_node_name)) {
3624     NodeDef* added_node = optimized_graph->add_node();
3625     *added_node = *node;
3626     // Always use AddN for the constant node, since AccumulateNV2 is a fake
3627     // node that cannot be constant folded, since it does not have a kernel.
3628     added_node->set_op("AddN");
3629     added_node->mutable_attr()->erase("shape");
3630     added_node->set_name(new_node_name);
3631     node_map_->AddNode(added_node->name(), added_node);
3632     added_node->clear_input();
3633     for (int i : const_inputs) {
3634       added_node->add_input(node->input(i));
3635       node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3636                               added_node->name());
3637     }
3638 
3639     // Overwrite the first const input with the added node.
3640     node->set_input(const_inputs[0], added_node->name());
3641     node_map_->AddOutput(added_node->name(), node->name());
3642     nonconst_inputs.push_back(const_inputs[0]);
3643     // Compact the remaining inputs to the original node.
3644     std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
3645     int idx = 0;
3646     for (int i : nonconst_inputs) {
3647       if (idx != i) {
3648         node->set_input(idx, node->input(i));
3649       }
3650       ++idx;
3651     }
3652     node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
3653                                           const_inputs.size() - 1);
3654     (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
3655     properties->ClearInputProperties(node->name());
3656     (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
3657     return true;
3658   }
3659   return false;
3660 }
3661 
PartialConcatConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3662 bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
3663                                                 GraphProperties* properties,
3664                                                 NodeDef* node) {
3665   // Partial constant folding for Concat which is not commutative, so
3666   // we have to preserve order and can only push consecutive runs of constant
3667   // inputs into sub-nodes.
3668   if (!IsConcat(*node) ||
3669       node->name().rfind("_partial_split_") != string::npos) {
3670     return false;
3671   }
3672   const int num_non_control_inputs = NumNonControlInputs(*node);
3673   if (num_non_control_inputs <= 3) return false;
3674   int axis_arg = -1;
3675   int begin = 0;
3676   int end = num_non_control_inputs;
3677   if (node->op() == "Concat") {
3678     begin = 1;
3679     axis_arg = 0;
3680   } else if (node->op() == "ConcatV2") {
3681     end = num_non_control_inputs - 1;
3682     axis_arg = num_non_control_inputs - 1;
3683   } else {
3684     return false;
3685   }
3686 
3687   // We search for consecutive runs of constant inputs in the range
3688   // [begin:end[ and push then down into child nodes.
3689   std::vector<std::pair<int, int>> constant_input_runs;
3690   int first = begin;
3691   int last = begin;
3692   while (last < end) {
3693     while (first < end && !IsReallyConstant(*node_map_->GetNode(
3694                               NodeName(node->input(first))))) {
3695       ++first;
3696     }
3697     // Invariant: node[first] is constant || first >= end.
3698     last = first + 1;
3699     while (last < end &&
3700            IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
3701       ++last;
3702     }
3703     // Invariant: node[last] is not constant || last >= end
3704     // Discard intervals shorter than 2 elements.
3705     if (first < end && (last - first) > 1) {
3706       constant_input_runs.emplace_back(first, last);
3707     }
3708     first = last;
3709   }
3710 
3711   // Skip if all inputs are constant, and let constant folding take over.
3712   if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3713                                       constant_input_runs[0].first == begin &&
3714                                       constant_input_runs[0].second == end)) {
3715     return false;
3716   }
3717   std::set<int> inputs_to_delete;
3718   for (auto interval : constant_input_runs) {
3719     // Push the constant inputs in the interval to a child node than can be
3720     // constant folded.
3721     string new_node_name = OptimizedNodeName(*node, "_partial_split");
3722     do {
3723       new_node_name += strings::StrCat("_", interval.first);
3724     } while (node_map_->NodeExists(new_node_name));
3725 
3726     NodeDef* added_node = optimized_graph->add_node();
3727     *added_node = *node;
3728     added_node->set_op("ConcatV2");
3729     added_node->set_name(new_node_name);
3730     node_map_->AddNode(added_node->name(), added_node);
3731     added_node->clear_input();
3732     for (int i = interval.first; i < interval.second; ++i) {
3733       added_node->add_input(node->input(i));
3734       node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
3735       if (i != interval.first) {
3736         inputs_to_delete.insert(i);
3737       }
3738     }
3739     added_node->add_input(node->input(axis_arg));
3740     (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
3741     node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
3742 
3743     // Overwrite the first constant input with the result of the added
3744     // child node.
3745     node->set_input(interval.first, added_node->name());
3746   }
3747   if (!inputs_to_delete.empty()) {
3748     // Fix up the inputs to the original node.
3749     protobuf::RepeatedPtrField<string> tmp;
3750     tmp.Swap(node->mutable_input());
3751     for (int i = 0; i < tmp.size(); ++i) {
3752       if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
3753         node->add_input(tmp.Get(i));
3754       }
3755     }
3756     (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
3757     properties->ClearInputProperties(node->name());
3758   }
3759   return true;
3760 }
3761 
GetConcatAxis(const NodeDef & node,int * axis)3762 bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
3763   if (node.op() != "ConcatV2") {
3764     return false;
3765   }
3766   int axis_idx = node.input_size() - 1;
3767   while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
3768     --axis_idx;
3769   }
3770   if (axis_idx <= 0) {
3771     return false;
3772   }
3773   Tensor axis_tensor;
3774   if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
3775     return false;
3776   }
3777   *axis = axis_tensor.dtype() == DT_INT64
3778               ? static_cast<int>(axis_tensor.scalar<int64>()())
3779               : axis_tensor.scalar<int32>()();
3780   return true;
3781 }
3782 
MergeConcat(bool use_shape_info,GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3783 bool ConstantFolding::MergeConcat(bool use_shape_info,
3784                                   GraphProperties* properties,
3785                                   GraphDef* optimized_graph, NodeDef* node) {
3786   // We only optimize for ConcatV2.
3787   int axis;
3788   if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
3789       nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3790       node_map_->GetOutputs(node->name()).size() != 1) {
3791     return false;
3792   }
3793 
3794   // If all inputs are constant, don't merge and let folding take case of it.
3795   const int num_regular_inputs = NumNonControlInputs(*node);
3796   bool all_inputs_are_const = true;
3797   for (int i = 0; i < num_regular_inputs - 1; ++i) {
3798     const NodeDef* input_node = node_map_->GetNode(node->input(i));
3799     if (!IsReallyConstant(*input_node)) {
3800       all_inputs_are_const = false;
3801       break;
3802     }
3803   }
3804   if (all_inputs_are_const) return false;
3805 
3806   NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3807   int parent_axis;
3808   if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
3809     return false;
3810   }
3811 
3812   // Make a pass over the parent inputs to see if any of them have explicit
3813   // device() fields set, and if different inputs are on different tasks.  If
3814   // so, this concat of concats may have been carefully constructed to be a
3815   // two-stage concat, and we don't want to undo that here.
3816   string task, device;
3817   absl::flat_hash_set<string> unique_input_tasks;
3818   const int n_parent_inputs = NumNonControlInputs(*parent);
3819   // Iterate over the real inputs to concatenate [0..n_parent_inputs - 1).  The
3820   // input at n_parent_inputs - 1 is the concat axis argument for a ConcatV2
3821   // node, which we don't want to consider here.
3822   for (int i = 0; i < n_parent_inputs - 1; ++i) {
3823     const NodeDef* input_node = node_map_->GetNode(parent->input(i));
3824     if (!input_node->device().empty() &&
3825         tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
3826                                                      &task, &device)) {
3827       unique_input_tasks.insert(task);
3828       if (unique_input_tasks.size() >= 2) {
3829         // More than one input task represented in the device specifications
3830         // of the parent's input nodes.  Don't mess with this.
3831         return false;
3832       }
3833     }
3834   }
3835 
3836   protobuf::RepeatedPtrField<string> parent_inputs;
3837   parent_inputs.Swap(parent->mutable_input());
3838   // TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
3839   // collapse it into the parent multiple times? Probably not.
3840   for (const auto& input : parent_inputs) {
3841     if (IsSameInput(input, node->name())) {
3842       for (int j = 0; j < num_regular_inputs - 1; ++j) {
3843         // Add tensor inputs to first child concat tensors (except the final
3844         // axis input) to the parent's inputs.
3845         parent->add_input(node->input(j));
3846         node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
3847       }
3848     } else {
3849       parent->add_input(input);
3850     }
3851   }
3852   // Forward Add control inputs
3853   const int num_inputs = node->input_size();
3854   for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
3855     parent->add_input(node->input(i));
3856     node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
3857     node->mutable_input()->RemoveLast();
3858   }
3859   (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3860   DedupControlInputs(parent);
3861   ReplaceOperationWithNoOp(node, properties, optimized_graph);
3862 
3863   return true;
3864 }
3865 
AddQuantizedMatMulMinMaxOutConstNodes(NodeDef * node,GraphDef * optimized_graph)3866 Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
3867     NodeDef* node, GraphDef* optimized_graph) {
3868   auto add_quantized_out = [this, node, optimized_graph](
3869                                const string& out_const_name, int index) {
3870     NodeDef* out_node = optimized_graph->add_node();
3871     graph_modified_ = true;
3872     Tensor value(DT_FLOAT, TensorShape({}));
3873     const bool is_min = index == 1;
3874     const DataType type_attr = node->attr().at("dtype").type();
3875 
3876     value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
3877                                     : QuantizedTypeMaxAsFloat(type_attr);
3878     TF_RETURN_IF_ERROR(
3879         CreateNodeDef(out_const_name, TensorValue(&value), out_node));
3880     node_map_->AddNode(out_const_name, out_node);
3881     out_node->set_device(node->device());
3882     // Copy all inputs from node.
3883     out_node->mutable_input()->CopyFrom(node->input());
3884     for (const string& input : out_node->input()) {
3885       node_map_->AddOutput(NodeName(input), out_const_name);
3886     }
3887 
3888     // Update output nodes consuming node:index to new const node.
3889     string old_input = absl::StrCat(node->name(), ":", index);
3890     int old_node_count = 0;
3891     // We make a copy since the set might change.
3892     auto outputs = node_map_->GetOutputs(node->name());
3893     for (const auto& output : outputs) {
3894       for (int i = 0; i < output->input_size(); ++i) {
3895         if (output->input(i) == old_input) {
3896           output->set_input(i, out_const_name);
3897           node_map_->AddOutput(out_const_name, output->name());
3898         } else if (NodeName(output->input(i)) == node->name()) {
3899           ++old_node_count;
3900         }
3901       }
3902       if (old_node_count == 0) {
3903         node_map_->RemoveOutput(node->name(), output->name());
3904       }
3905     }
3906 
3907     return Status::OK();
3908   };
3909   const string min_out_const_name =
3910       OptimizedNodeName(*node, "-quantized_matmul_min_out");
3911   const string max_out_const_name =
3912       OptimizedNodeName(*node, "-quantized_matmul_max_out");
3913   if (node_map_->GetNode(min_out_const_name) == nullptr &&
3914       node_map_->GetNode(max_out_const_name) == nullptr) {
3915     TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
3916     TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
3917   } else {
3918     return errors::Internal(absl::Substitute(
3919         "Can't create Const for QuantizedMatMul min_out/max_out of "
3920         "node '$0' because of node name conflict",
3921         node->name()));
3922   }
3923   return Status::OK();
3924 }
3925 
RunOptimizationPass(Cluster * cluster,GrapplerItem * item,GraphDef * optimized_graph)3926 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3927                                             GrapplerItem* item,
3928                                             GraphDef* optimized_graph) {
3929   graph_ = &item->graph;
3930   node_map_.reset(new NodeMap(graph_));
3931   nodes_allowlist_.clear();
3932   // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3933   // has a single fanout, it would be rewritten as a constant with the same
3934   // node name, and therefore users are still able to fetch it. This is not
3935   // the case if the node has multiple fanouts, and constant folding would
3936   // replace the node with multiple constants (each for one fanout) with
3937   // new names, and as a result users would not be able to fetch the node any
3938   // more with the original node name.
3939   for (const auto& fetch : item->fetch) {
3940     const NodeDef* fetch_node = node_map_->GetNode(fetch);
3941     if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
3942       nodes_allowlist_.insert(fetch_node->name());
3943     }
3944   }
3945 
3946   GraphProperties properties(*item);
3947   // It's possible to feed a placeholder with a tensor of any shape: make sure
3948   // that the shape inference deals with this conservatively unless we're in
3949   // aggressive mode.
3950   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3951   Status s = properties.InferStatically(assume_valid_feeds,
3952                                         /*aggressive_shape_inference=*/false,
3953                                         /*include_input_tensor_values=*/false,
3954                                         /*include_output_tensor_values=*/true);
3955 
3956   const bool can_use_shape_info = s.ok();
3957   VLOG(1) << "can_use_shape_info = " << can_use_shape_info;
3958 
3959   absl::flat_hash_set<string> nodes_to_not_simplify;
3960   if (can_use_shape_info) {
3961     TF_RETURN_IF_ERROR(MaterializeShapes(properties));
3962     TF_RETURN_IF_ERROR(MaterializeConstants(properties));
3963     TF_RETURN_IF_ERROR(
3964         FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
3965   } else {
3966     *optimized_graph = *graph_;
3967   }
3968   node_map_.reset(new NodeMap(optimized_graph));
3969   TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
3970                                    &properties, &nodes_to_not_simplify));
3971 
3972   return Status::OK();
3973 }
3974 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3975 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
3976                                  GraphDef* optimized_graph) {
3977   // TensorFlow flushes denormals to zero and rounds to nearest, so we do
3978   // the same here.
3979   port::ScopedFlushDenormal flush;
3980   port::ScopedSetRound round(FE_TONEAREST);
3981   nodes_to_preserve_ = item.NodesToPreserve();
3982   for (const auto& feed : item.feed) {
3983     feed_nodes_.insert(NodeName(feed.first));
3984   }
3985 
3986   if (cpu_device_ == nullptr) {
3987     owned_device_.reset(new DeviceSimple());
3988     cpu_device_ = owned_device_.get();
3989   }
3990 
3991   graph_contains_assign_or_inplace_op_ = false;
3992   for (const NodeDef& node : item.graph.node()) {
3993     if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
3994       graph_contains_assign_or_inplace_op_ = true;
3995       break;
3996     }
3997   }
3998 
3999   has_fetch_ = !item.fetch.empty();
4000   GrapplerItem item_to_optimize = item;
4001   *optimized_graph = GraphDef();
4002   item_to_optimize.graph.Swap(optimized_graph);
4003   int64 node_count;
4004   do {
4005     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4006     graph_modified_ = false;
4007     item_to_optimize.graph.Swap(optimized_graph);
4008     optimized_graph->Clear();
4009     node_count = item_to_optimize.graph.node_size();
4010     TF_RETURN_IF_ERROR(
4011         RunOptimizationPass(cluster, &item_to_optimize, optimized_graph));
4012   } while (graph_modified_ || optimized_graph->node_size() != node_count);
4013   *optimized_graph->mutable_library() = item.graph.library();
4014   *optimized_graph->mutable_versions() = item.graph.versions();
4015 
4016   return Status::OK();
4017 }
4018 
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimize_output,double result)4019 void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
4020                                const GraphDef& optimize_output, double result) {
4021   // Nothing to do for ConstantFolding.
4022 }
4023 
4024 }  // namespace grappler
4025 }  // namespace tensorflow
4026