• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #define EIGEN_USE_THREADS
17 
18 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
19 
#include <cmath>
#include <cstring>
21 
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/grappler/clusters/cluster.h"
37 #include "tensorflow/core/grappler/costs/graph_properties.h"
38 #include "tensorflow/core/grappler/grappler_item.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/stringpiece.h"
45 #include "tensorflow/core/lib/gtl/cleanup.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/lib/strings/strcat.h"
49 #include "tensorflow/core/platform/cpu_info.h"
50 #include "tensorflow/core/platform/denormal.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/setround.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/bcast.h"
56 #include "tensorflow/core/util/saved_tensor_slice_util.h"
57 
58 namespace tensorflow {
59 namespace grappler {
// Alias for a gtl::InlinedVector of TensorValue declared with an inline size
// parameter of 4.
60 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
61 
62 // We only fold/materialize constants smaller than 100kB.
// The limit is spelled as 100 * 1024 (= 102400).
63 const int64_t kMaxConstantSize = 100 * 1024;
64 
65 namespace {
66 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)67 bool AllValuesAre(const TensorProto& proto, const T& value) {
68   Tensor tensor;
69   if (!tensor.FromProto(proto)) {
70     return false;
71   }
72   auto values = tensor.flat<T>();
73   for (int i = 0; i < tensor.NumElements(); ++i) {
74     if (values(i) != value) {
75       return false;
76     }
77   }
78   return true;
79 }
80 
81 // Add new_input as a control input to node if it does not already depend on it.
82 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
83 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)84 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
85                           GraphDef* graph, NodeMap* node_map) {
86   bool already_exists = false;
87   for (const string& input : node->input()) {
88     if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
89       already_exists = true;
90       break;
91     }
92   }
93   if (!already_exists) {
94     const string ctrl_dep =
95         ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
96     node->add_input(ctrl_dep);
97     node_map->AddOutput(NodeName(ctrl_input), node->name());
98   }
99   return !already_exists;
100 }
101 
102 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)103 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
104                              GraphDef* graph, NodeMap* node_map) {
105   bool removed_input = false;
106   bool update_node_map = true;
107   const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
108   for (int i = 0; i < node->input_size(); ++i) {
109     const string& input = node->input(i);
110     if (old_input_ctrl_dep == input) {
111       if (IsControlInput(input)) {
112         node->mutable_input()->SwapElements(i, node->input_size() - 1);
113         node->mutable_input()->RemoveLast();
114         removed_input = true;
115       } else {
116         // There is a non-control input from the same node.
117         // Don't remove the output from the NodeMap.
118         update_node_map = false;
119       }
120     }
121   }
122   if (update_node_map) {
123     node_map->RemoveOutput(NodeName(old_input), node->name());
124   }
125   return removed_input;
126 }
127 
HasTPUAttributes(const NodeDef & node)128 bool HasTPUAttributes(const NodeDef& node) {
129   AttrSlice attrs(node);
130   for (const auto& attr : attrs) {
131     if (attr.first.find("_tpu_") != attr.first.npos) {
132       return true;
133     }
134   }
135   return false;
136 }
137 
// Returns true if `a` and `b` differ. The generic version uses operator!=.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

// Floating-point specializations compare the object representations rather
// than the numeric values, so +0.0 and -0.0 are reported as different and two
// NaNs with identical bits are reported as equal. std::memcmp is used instead
// of casting to an integer reference, which would violate strict aliasing.
template <>
bool PackedValuesNotEqual(float a, float b) {
  return std::memcmp(&a, &b, sizeof(float)) != 0;
}

template <>
bool PackedValuesNotEqual(double a, double b) {
  return std::memcmp(&a, &b, sizeof(double)) != 0;
}
152 
QuantizedTypeMinAsFloat(DataType data_type)153 float QuantizedTypeMinAsFloat(DataType data_type) {
154   switch (data_type) {
155     case DT_QINT8:
156       return Eigen::NumTraits<qint8>::lowest();
157     case DT_QUINT8:
158       return Eigen::NumTraits<quint8>::lowest();
159     case DT_QINT16:
160       return Eigen::NumTraits<qint16>::lowest();
161     case DT_QUINT16:
162       return Eigen::NumTraits<quint16>::lowest();
163     case DT_QINT32:
164       return Eigen::NumTraits<qint32>::lowest();
165     default:
166       return 0.0f;
167   }
168 }
169 
QuantizedTypeMaxAsFloat(DataType data_type)170 float QuantizedTypeMaxAsFloat(DataType data_type) {
171   switch (data_type) {
172     case DT_QINT8:
173       return Eigen::NumTraits<qint8>::highest();
174     case DT_QUINT8:
175       return Eigen::NumTraits<quint8>::highest();
176     case DT_QINT16:
177       return Eigen::NumTraits<qint16>::highest();
178     case DT_QUINT16:
179       return Eigen::NumTraits<quint16>::highest();
180     case DT_QINT32:
181       return Eigen::NumTraits<qint32>::highest();
182     default:
183       return 0.0f;
184   }
185 }
186 
187 }  // namespace
188 
ConstantFolding(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_emulation)189 ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
190                                  DeviceBase* cpu_device,
191                                  bool disable_compressed_tensor_optimization,
192                                  bool fold_quantization_emulation)
193     : opt_level_(opt_level),
194       cpu_device_(cpu_device),
195       disable_compressed_tensor_optimization_(
196           disable_compressed_tensor_optimization),
197       fold_quantization_emulation_(fold_quantization_emulation) {
198   resource_mgr_.reset(new ResourceMgr());
199 }
200 
ConstantFolding(DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_ops)201 ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
202                                  bool disable_compressed_tensor_optimization,
203                                  bool fold_quantization_ops)
204     : ConstantFolding(RewriterConfig::ON, cpu_device,
205                       disable_compressed_tensor_optimization,
206                       fold_quantization_ops) {}
207 
208 // static
AddControlDependency(const string & input_name,GraphDef * graph,NodeMap * node_map)209 string ConstantFolding::AddControlDependency(const string& input_name,
210                                              GraphDef* graph,
211                                              NodeMap* node_map) {
212   if (IsControlInput(input_name)) {
213     return input_name;
214   }
215   const NodeDef* node = node_map->GetNode(input_name);
216   // Sanity check for missing node.
217   if (!node) {
218     return input_name;
219   }
220   if (!IsSwitch(*node)) {
221     return AsControlDependency(*node);
222   } else {
223     // We can't anchor control dependencies directly on the switch node: unlike
224     // other nodes only one of the outputs of the switch node will be generated
225     // when the switch node is executed, and we need to make sure the control
226     // dependency is only triggered when the corresponding output is triggered.
227     // We start by looking for an identity node connected to the output of the
228     // switch node, and use it to anchor the control dependency.
229     for (const NodeDef* output : node_map->GetOutputs(node->name())) {
230       if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
231         if (IsSameInput(node->input(0), input_name)) {
232           return AsControlDependency(*output);
233         }
234       }
235     }
236     // We haven't found an existing node where we can anchor the control
237     // dependency: add a new identity node.
238     int port = 0;
239     string ctrl_dep_name = ParseNodeName(input_name, &port);
240     strings::StrAppend(&ctrl_dep_name, "_", port);
241     ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
242     const DataType output_type = node->attr().at("T").type();
243 
244     NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
245     if (added_node == nullptr) {
246       added_node = graph->add_node();
247       added_node->set_name(ctrl_dep_name);
248       added_node->set_op("Identity");
249       added_node->set_device(node->device());
250 
251       (*added_node->mutable_attr())["T"].set_type(output_type);
252       *added_node->add_input() = input_name;
253       node_map->AddNode(added_node->name(), added_node);
254       node_map->AddOutput(node->name(), added_node->name());
255     }
256     return AsControlDependency(*added_node);
257   }
258 }
259 
260 // Forward inputs at the given indices to outputs and add a control dependency
261 // on node.
ForwardInputs(NodeDef * node,absl::Span<const int> inputs_to_forward)262 bool ConstantFolding::ForwardInputs(NodeDef* node,
263                                     absl::Span<const int> inputs_to_forward) {
264   for (int input_idx : inputs_to_forward) {
265     if (input_idx < 0 || input_idx >= node->input_size()) {
266       return false;
267     }
268   }
269 
270   const auto& tmp = node_map_->GetOutputs(node->name());
271   const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
272   bool updated_graph = false;
273   for (int input_idx : inputs_to_forward) {
274     const string& input = node->input(input_idx);
275     if (IsControlInput(input) && consumers.size() > 1) {
276       continue;
277     }
278     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
279     if (input_node == nullptr) {
280       LOG(ERROR) << "Bad input: " << input;
281       break;
282     }
283     // Update each consumer.
284     for (NodeDef* consumer : consumers) {
285       bool add_dep = false;
286       for (int consumer_input_idx = 0;
287            consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
288         const string& consumer_input = consumer->input(consumer_input_idx);
289         if (IsControlInput(consumer_input)) {
290           break;
291         }
292         // It is illegal to add control dependencies to _Retval nodes, so we
293         // can't bypass value producing `node` and forward inputs to `consumer`.
294         if (IsRetval(*consumer)) {
295           break;
296         }
297         int output_idx;
298         const string input_node_name =
299             ParseNodeName(consumer_input, &output_idx);
300         if (input_node_name == node->name() && output_idx == input_idx) {
301           consumer->set_input(consumer_input_idx, input);
302           // We will keep the input from the node through a control
303           // dependency, so we only need to add the consumer as an output
304           // for the input node.
305           node_map_->AddOutput(NodeName(input), consumer->name());
306           add_dep = true;
307         }
308       }
309       if (add_dep) {
310         consumer->add_input(AsControlDependency(node->name()));
311         updated_graph = true;
312       }
313     }
314   }
315 
316   if (updated_graph) {
317     for (NodeDef* consumer : consumers) {
318       DedupControlInputs(consumer);
319     }
320   }
321   return updated_graph;
322 }
323 
324 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64_t value,const DataType & type,const int index,Tensor * tensor)325 static Status PutValueIntoTensor(const int64_t value, const DataType& type,
326                                  const int index, Tensor* tensor) {
327   if (type == DT_INT32) {
328     if (value >= INT_MAX) {
329       return Status(error::INVALID_ARGUMENT, "int32 overflow");
330     }
331     tensor->flat<int32>()(index) = static_cast<int32>(value);
332   } else {
333     tensor->flat<int64>()(index) = value;
334   }
335   return Status::OK();
336 }
337 
338 // Writes the given tensor shape into the given tensor.
339 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)340 static Status ConvertShapeToConstant(const string& op, const DataType& type,
341                                      const PartialTensorShape& shp,
342                                      Tensor* tensor) {
343   if (op == "Shape" || op == "ShapeN") {
344     *tensor = Tensor(type, TensorShape({shp.dims()}));
345     for (int i = 0; i < shp.dims(); ++i) {
346       TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
347     }
348   } else if (op == "Size") {
349     int64_t size = 1;
350     for (int i = 0; i < shp.dims(); ++i) {
351       size *= shp.dim_size(i);
352     }
353     *tensor = Tensor(type, TensorShape({}));
354     TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
355   } else {
356     CHECK_EQ(op, "Rank");
357     *tensor = Tensor(type, TensorShape({}));
358     TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
359   }
360   return Status::OK();
361 }
362 
363 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const364 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
365                                           StringPiece suffix) const {
366   return node_map_->NodeExists(OptimizedNodeName(node, suffix));
367 }
368 
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const369 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
370                                           StringPiece suffix) const {
371   return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
372                              kConstantFoldingConst);
373 }
374 
IsReallyConstant(const NodeDef & node) const375 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
376   if (!IsConstant(node)) {
377     return false;
378   }
379   // If the node is fed it's not constant anymore.
380   return feed_nodes_.find(node.name()) == feed_nodes_.end();
381 }
382 
383 // TODO(rmlarsen): Refactor to shared util.
GetTensorFromConstNode(const string & node_name_or_input,Tensor * tensor)384 bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
385                                              Tensor* tensor) {
386   const NodeDef* node = node_map_->GetNode(node_name_or_input);
387   return node != nullptr && IsReallyConstant(*node) &&
388          CheckAttrExists(*node, "value").ok() &&
389          tensor->FromProto(node->attr().at("value").tensor());
390 }
391 
392 // Materialize the shapes using constants whenever possible.
MaterializeShapes(const GraphProperties & properties)393 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
394   // We may add some nodes to the graph to encode control dependencies and hold
395   // the materialized shapes: there is no need to process these added nodes, so
396   // only iterate over the nodes of the input graph.
397   const int node_count = graph_->node_size();
398   for (int node_idx = 0; node_idx < node_count; ++node_idx) {
399     NodeDef* node = graph_->mutable_node(node_idx);
400     const string op = node->op();
401     if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
402         op != "TensorArraySizeV3") {
403       continue;
404     }
405     const std::vector<OpInfo::TensorProperties>& output =
406         properties.GetOutputProperties(node->name());
407     const std::vector<OpInfo::TensorProperties>& input =
408         properties.GetInputProperties(node->name());
409     if (input.empty() || output.empty()) {
410       continue;
411     }
412 
413     if (op == "Shape" || op == "Size" || op == "Rank") {
414       CHECK_EQ(1, output.size());
415       CHECK_EQ(1, input.size());
416 
417       const DataType type = output[0].dtype();
418       CHECK(type == DT_INT32 || type == DT_INT64);
419       const PartialTensorShape shape(input[0].shape());
420 
421       if ((op != "Rank" && !shape.IsFullyDefined()) ||
422           (op == "Rank" && shape.unknown_rank())) {
423         continue;
424       }
425 
426       Tensor constant_value(type);
427       if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
428         continue;
429       }
430 
431       // TODO(rmlarsen): Remove this workaround for b/150861569
432       // The bug involves an expression of the form Shape(ExpandDims(x)
433       // with an incorrectly inferred zero-size first dimension.
434       if (op == "Shape") {
435         if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
436       }
437 
438       // Repurpose the existing node to be the constant.
439       // Device placement is preserved.
440       graph_modified_ = true;
441       node->set_op("Const");
442       EraseRegularNodeAttributes(node);
443       (*node->mutable_attr())["dtype"].set_type(type);
444       constant_value.AsProtoTensorContent(
445           (*node->mutable_attr())["value"].mutable_tensor());
446 
447       // Turn the data input into a control dependency: this is needed to
448       // ensure that the constant value will only be run in the
449       // cases where the shape/rank/size would have been run in
450       // the original graph.
451       string ctrl_dep =
452           AddControlDependency(node->input(0), graph_, node_map_.get());
453       node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
454       node->set_input(0, ctrl_dep);
455       // Done with the Shape/Size/Rank node, move to the next node.
456       continue;
457     }
458 
459     if (op == "TensorArraySizeV3") {
460       const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
461       if (array->input_size() == 0 ||
462           (array->attr().count("dynamic_size") != 0 &&
463            array->attr().at("dynamic_size").b())) {
464         continue;
465       }
466       const NodeDef* array_size =
467           CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
468       if (IsReallyConstant(*array_size)) {
469         // Don't materialize 0 sizes to avoid triggering incorrect static
470         // checks. A 0 sized array that can't grow isn't useful anyway.
471         if (array_size->attr().count("value") == 0) {
472           continue;
473         }
474         const TensorProto& raw_val = array_size->attr().at("value").tensor();
475         if (raw_val.dtype() != DT_INT32) {
476           continue;
477         }
478         Tensor value(raw_val.dtype(), raw_val.tensor_shape());
479         if (!value.FromProto(raw_val)) {
480           continue;
481         }
482         if (value.flat<int32>()(0) == 0) {
483           continue;
484         }
485 
486         graph_modified_ = true;
487         node->set_op("Const");
488         *node->mutable_attr() = array_size->attr();
489         node->set_input(0, AsControlDependency(NodeName(node->input(0))));
490         node->set_input(1, AddControlDependency(NodeName(node->input(1)),
491                                                 graph_, node_map_.get()));
492       }
493       continue;
494     }
495 
496     // Handle ShapeN materialization case.
497     // It's possible that not all input tensors have known shapes.
498     CHECK_EQ(op, "ShapeN");
499     CHECK_EQ(input.size(), output.size());
500     const NodeDef* const shape_n_node = node;
501     for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
502          ++port_idx) {
503       const DataType type = output[port_idx].dtype();
504       CHECK(type == DT_INT32 || type == DT_INT64);
505       const PartialTensorShape shape(input[port_idx].shape());
506       if (!shape.IsFullyDefined()) {
507         continue;
508       }
509       Tensor constant_value(type);
510       auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
511       if (!status.ok()) {
512         continue;
513       }
514 
515       // We make a copy because we mutate the nodes.
516       auto fanouts = node_map_->GetOutputs(shape_n_node->name());
517       // Find all nodes consuming this shape and connect them through the new
518       // constant node instead.
519       for (NodeDef* output : fanouts) {
520         // Track whether there are any direct edges left between shape_n_node
521         // and this output node after the transformation.
522         bool direct_edges_exist = false;
523         for (int k = 0; k < output->input_size(); ++k) {
524           int port;
525           const string node_name = ParseNodeName(output->input(k), &port);
526           if (node_name == shape_n_node->name() && port == port_idx) {
527             // Create a const node as ShapeN's output if not already.
528             const string const_name = OptimizedNodeName(
529                 *shape_n_node, strings::StrCat("-matshapes-", port_idx));
530             if (node_map_->GetNode(const_name) == nullptr) {
531               NodeDef* added_node = graph_->add_node();
532               added_node->set_name(const_name);
533               added_node->set_op("Const");
534               added_node->set_device(shape_n_node->device());
535               node_map_->AddNode(added_node->name(), added_node);
536               (*added_node->mutable_attr())["dtype"].set_type(type);
537               constant_value.AsProtoTensorContent(
538                   (*added_node->mutable_attr())["value"].mutable_tensor());
539               // We add a control dependency to the original ShapeN node,
540               // so that the node will only be run if all inputs of the
541               // original ShapeN node are run.
542               string ctrl_dep = AddControlDependency(shape_n_node->name(),
543                                                      graph_, node_map_.get());
544               *added_node->add_input() = ctrl_dep;
545               node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
546             }
547             *output->mutable_input(k) = const_name;
548             node_map_->AddOutput(const_name, output->name());
549             graph_modified_ = true;
550           }
551           if (node_name == shape_n_node->name() && port != port_idx) {
552             direct_edges_exist = true;
553           }
554         }
555         if (!direct_edges_exist) {
556           node_map_->RemoveOutput(node->name(), output->name());
557         }
558       }
559     }
560   }
561 
562   return Status::OK();
563 }
564 
565 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64 * min_id)566 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
567                   BCast::Vec* shape, int64* min_id) {
568   if (shape_node.op() == "Shape") {
569     const std::vector<OpInfo::TensorProperties>& prop1 =
570         properties.GetInputProperties(shape_node.name());
571     if (prop1.size() != 1) {
572       return false;
573     }
574     const TensorShapeProto& shp = prop1[0].shape();
575     if (shp.unknown_rank()) {
576       return false;
577     }
578     for (const auto& dim : shp.dim()) {
579       shape->push_back(dim.size());
580       *min_id = std::min<int64>(*min_id, dim.size());
581     }
582   } else {
583     if (shape_node.attr().count("value") == 0) {
584       return false;
585     }
586     const TensorProto& raw_val = shape_node.attr().at("value").tensor();
587     if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
588       return false;
589     }
590     Tensor value(raw_val.dtype(), raw_val.tensor_shape());
591     if (!value.FromProto(raw_val)) {
592       return false;
593     }
594     for (int j = 0; j < value.NumElements(); ++j) {
595       if (raw_val.dtype() == DT_INT64) {
596         shape->push_back(value.vec<int64>()(j));
597       } else {
598         shape->push_back(value.vec<int>()(j));
599       }
600     }
601   }
602   return true;
603 }
604 }  // namespace
605 
MaterializeBroadcastGradientArgs(const NodeDef & node,const GraphProperties & properties)606 Status ConstantFolding::MaterializeBroadcastGradientArgs(
607     const NodeDef& node, const GraphProperties& properties) {
608   const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
609   const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
610   if (shape_node1 == nullptr ||
611       (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
612       shape_node2 == nullptr ||
613       (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
614     return Status::OK();
615   }
616 
617   // Don't optimize this again if it was already optimized and folded.
618   if (OptimizedNodeExists(node, "-folded-1") ||
619       OptimizedNodeExists(node, "-folded-2")) {
620     return Status::OK();
621   }
622   int64_t min_id = 0;
623   BCast::Vec shape1;
624   if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
625     return Status::OK();
626   }
627   BCast::Vec shape2;
628   if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
629     return Status::OK();
630   }
631   // A value of -1 means we don't known anything about the dimension. Replace
632   // the -1 values with unique dimension ids since we don't want two '-1'
633   // dimensions to be considered equal.
634   for (auto& id : shape1) {
635     if (id == -1) {
636       id = --min_id;
637     }
638   }
639   for (auto& id : shape2) {
640     if (id == -1) {
641       id = --min_id;
642     }
643   }
644 
645   // Beware: the reduction dimensions computed by the BCast class are valid iff
646   // we assume that two distinct symbolic dimensions can't be equal and a
647   // symbolic dimension can't be equal to 1. This is often but not always true,
648   // so to make this optimization safe we filter out these cases.
649   const int common_dims = std::min(shape1.size(), shape2.size());
650   for (int i = 0; i < common_dims; ++i) {
651     if (shape1[i] >= 0 && shape2[i] >= 0) {
652       continue;
653     }
654     if (shape1[i] != shape2[i]) {
655       // We're either dealing with 2 different symbolic dimensions or a symbolic
656       // and a know dimensions. We can't be sure whether both are equal or not,
657       // so we can't be sure whether we'll be broadcasting or not.
658       return Status::OK();
659     }
660   }
661   // These extra dims could be equal to 1, in which case there is no
662   // broadcasting. It could also be greater than 1, in which case there would
663   // be broadcasting. Since we don't know, we'll just punt.
664   for (int i = common_dims, end = shape1.size(); i < end; ++i) {
665     if (shape1[i] < 0) {
666       return Status::OK();
667     }
668   }
669   for (int i = common_dims, end = shape2.size(); i < end; ++i) {
670     if (shape2[i] < 0) {
671       return Status::OK();
672     }
673   }
674 
675   BCast bcast(shape1, shape2);
676   if (!bcast.IsValid()) {
677     return Status::OK();
678   }
679 
680   BCast::Vec reduce_dims[2];
681   reduce_dims[0] = bcast.grad_x_reduce_idx();
682   reduce_dims[1] = bcast.grad_y_reduce_idx();
683 
684   TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
685   const DataType type = node.attr().at("T").type();
686   NodeDef* out[2];
687   for (int j = 0; j < 2; ++j) {
688     int reduction_indices = reduce_dims[j].size();
689     Tensor value(type, TensorShape({reduction_indices}));
690     for (int i = 0; i < reduction_indices; ++i) {
691       if (type == DT_INT32) {
692         value.vec<int32>()(i) = reduce_dims[j][i];
693       } else {
694         value.vec<int64>()(i) = reduce_dims[j][i];
695       }
696     }
697     string const_name =
698         OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
699     out[j] = node_map_->GetNode(const_name);
700     if (out[j] == nullptr) {
701       out[j] = graph_->add_node();
702       TF_RETURN_IF_ERROR(
703           CreateNodeDef(const_name, TensorValue(&value), out[j]));
704       out[j]->set_device(node.device());
705       node_map_->AddNode(const_name, out[j]);
706       string ctrl_dep =
707           AddControlDependency(node.name(), graph_, node_map_.get());
708       *out[j]->add_input() = ctrl_dep;
709       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
710     }
711   }
712 
713   // We make a copy here since we might mutate the set.
714   const auto outputs = node_map_->GetOutputs(node.name());
715   for (NodeDef* output : outputs) {
716     for (int k = 0; k < output->input_size(); ++k) {
717       int port;
718       string node_name = ParseNodeName(output->input(k), &port);
719       if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
720         *output->mutable_input(k) = out[port]->name();
721         node_map_->UpdateInput(output->name(), node_name, out[port]->name());
722       }
723     }
724   }
725 
726   return Status::OK();
727 }
728 
MaterializeReductionIndices(NodeDef * node,const GraphProperties & properties)729 Status ConstantFolding::MaterializeReductionIndices(
730     NodeDef* node, const GraphProperties& properties) {
731   if (node->input_size() < 2) {
732     return Status::OK();
733   }
734   const NodeDef* indices = node_map_->GetNode(node->input(1));
735   if (!indices || IsReallyConstant(*indices)) {
736     // The reduction indices are already constant, there's nothing to do.
737     return Status::OK();
738   }
739 
740   const std::vector<OpInfo::TensorProperties>& input_props =
741       properties.GetInputProperties(node->name());
742   if (input_props.size() != 2) {
743     return Status::OK();
744   }
745   const OpInfo::TensorProperties& input_prop = input_props[0];
746   if (input_prop.shape().unknown_rank()) {
747     // We can't do anything if we don't know the rank of the input.
748     return Status::OK();
749   }
750   const int input_rank = input_prop.shape().dim_size();
751   if (input_rank < 1) {
752     // Unexpected graph, don't try to change it.
753     return Status::OK();
754   }
755   const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
756   DataType dtype = reduction_indices_prop.dtype();
757   if (dtype != DT_INT32 && dtype != DT_INT64) {
758     return Status::OK();
759   }
760   PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
761   const int num_reduction_indices = reduction_indices_shape.num_elements();
762 
763   const std::vector<OpInfo::TensorProperties>& output_props =
764       properties.GetOutputProperties(node->name());
765   if (output_props.size() != 1) {
766     return Status::OK();
767   }
768   const OpInfo::TensorProperties& output_prop = output_props[0];
769   const int output_rank =
770       output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
771 
772   bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
773   if (!full_reduction) {
774     // A full reduction will generate a tensor of one of the shapes
775     // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
776     // elements in the output of the reduction, we may deduce it from reshape
777     // nodes following it.
778     for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
779       full_reduction = false;
780       if (!IsReshape(*fanout)) {
781         return Status::OK();
782       }
783       const std::vector<OpInfo::TensorProperties>& reshape_props =
784           properties.GetOutputProperties(fanout->name());
785       if (reshape_props.size() != 1) {
786         return Status::OK();
787       }
788       const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
789       PartialTensorShape shape(reshape_prop.shape());
790       if (shape.num_elements() != 1) {
791         return Status::OK();
792       } else {
793         full_reduction = true;
794       }
795     }
796     if (!full_reduction) {
797       return Status::OK();
798     }
799   }
800 
801   // We know it's a full reduction. We can generate the full set of indices to
802   // reduce as a constant node.
803   string const_name = OptimizedNodeName(*node, "-reduction_indices");
804   if (node_map_->GetNode(const_name)) {
805     return Status::OK();
806   }
807   NodeDef* reduction_indices = graph_->add_node();
808   Tensor value(dtype, TensorShape({input_rank}));
809   for (int i = 0; i < input_rank; ++i) {
810     if (dtype == DT_INT32) {
811       value.vec<int32>()(i) = i;
812     } else {
813       value.vec<int64>()(i) = i;
814     }
815   }
816   TF_RETURN_IF_ERROR(
817       CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
818 
819   reduction_indices->set_device(node->device());
820   string ctrl_dep =
821       AddControlDependency(node->input(1), graph_, node_map_.get());
822   *reduction_indices->add_input() = ctrl_dep;
823   node_map_->AddNode(const_name, reduction_indices);
824   node_map_->AddOutput(NodeName(ctrl_dep), const_name);
825 
826   node->set_input(1, reduction_indices->name());
827   node_map_->UpdateInput(node->name(), indices->name(),
828                          reduction_indices->name());
829 
830   return Status::OK();
831 }
832 
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)833 Status ConstantFolding::MaterializeConstantValuedNode(
834     NodeDef* node, const GraphProperties& properties) {
835   if (disable_compressed_tensor_optimization_) {
836     return Status::OK();
837   }
838   // Nodes that generate constant-valued outputs can be represented compactly in
839   // compressed format, regardless of their shape.
840   const std::vector<OpInfo::TensorProperties>& output_props =
841       properties.GetOutputProperties(node->name());
842   if (output_props.size() != 1) return Status::OK();
843   const auto& output_shape = output_props[0].shape();
844   if (!PartialTensorShape(output_shape).IsFullyDefined()) {
845     return Status::OK();
846   }
847   if (IsFill(*node)) {
848     const auto output_dtype = output_props[0].dtype();
849     NodeDef* input_node = nullptr;
850     for (int i = 0; i < 2; ++i) {
851       input_node = node_map_->GetNode(NodeName(node->input(i)));
852       if (input_node == nullptr || !IsReallyConstant(*input_node)) {
853         return Status::OK();
854       }
855     }
856     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
857 
858     // Copy the input tensor to the fill node, set the output shape and data
859     // type, and change the node type to Const.
860     TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
861     const TensorProto& input_tensor = input_node->attr().at("value").tensor();
862     if (!input_tensor.tensor_content().empty()) {
863       // Convert the value to repeated field format, so we can use the
864       // decompression mechanism to store only a single value in the constant
865       // node, even if the shape specified in the original Fill is large.
866       Tensor t;
867       if (!t.FromProto(input_tensor)) {
868         return errors::InvalidArgument(
869             "Could not construct Tensor form TensorProto in node: ",
870             input_node->name());
871       }
872       tensor->clear_tensor_content();
873       t.AsProtoField(tensor);
874     } else {
875       *tensor = input_tensor;
876     }
877     *(tensor->mutable_tensor_shape()) = output_shape;
878     (*node->mutable_attr())["dtype"].set_type(output_dtype);
879     node->mutable_attr()->erase("T");
880     node->mutable_attr()->erase("index_type");
881     node->set_op("Const");
882     for (int i = 0; i < 2; i++) {
883       // Change inputs to a control inputs.
884       const string ctrl_dep = AsControlDependency(node->input(i));
885       node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
886       node->set_input(i, ctrl_dep);
887     }
888     graph_modified_ = true;
889   } else {
890     double value =
891         (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
892     if (value >= 0) {
893       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
894           value, properties, output_shape, node, graph_));
895     }
896   }
897   return Status::OK();
898 }
899 
900 // Materialize output values inferred by the shape inference.
MaterializeOutputValues(NodeDef * node,const GraphProperties & properties)901 Status ConstantFolding::MaterializeOutputValues(
902     NodeDef* node, const GraphProperties& properties) {
903   const std::vector<OpInfo::TensorProperties>& output =
904       properties.GetOutputProperties(node->name());
905   if (output.size() != 1 || !output[0].has_value() ||
906       !IsFoldable(*node, &properties)) {
907     return Status::OK();
908   }
909 
910   // If this is a trivial Identity node with a constant input, just route the
911   // input around it.
912   if (IsIdentity(*node)) {
913     NodeDef* input = node_map_->GetNode(node->input(0));
914     if (IsReallyConstant(*input)) {
915       std::vector<int> inputs_to_forward;
916       std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
917       graph_modified_ = ForwardInputs(node, inputs_to_forward);
918       return Status::OK();
919     }
920   }
921   // Repurpose the existing node to be the constant.
922   // Device placement is preserved.
923   TensorProto value_copy = output[0].value();
924   return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
925                                             node, graph_);
926 }
927 
MaterializeConstants(const GraphProperties & properties)928 Status ConstantFolding::MaterializeConstants(
929     const GraphProperties& properties) {
930   const int node_count = graph_->node_size();
931   for (int i = 0; i < node_count; ++i) {
932     NodeDef& node = *graph_->mutable_node(i);
933     const string& op = node.op();
934     if (op == "BroadcastGradientArgs") {
935       TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
936     } else if (IsReduction(node)) {
937       TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
938     } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
939       TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
940     } else {
941       TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
942     }
943   }
944   return Status::OK();
945 }
946 
IsFoldable(const NodeDef & node,const GraphProperties * properties)947 bool ConstantFolding::IsFoldable(const NodeDef& node,
948                                  const GraphProperties* properties) {
949   string key = strings::StrCat(node.name(), "/", node.op());
950   auto it = maybe_foldable_nodes_.find(key);
951   if (it == maybe_foldable_nodes_.end()) {
952     it = maybe_foldable_nodes_
953              .emplace(std::move(key), MaybeFoldable(node, properties))
954              .first;
955   }
956   if (!it->second) {
957     return false;
958   } else {
959     return IsFoldableUncached(node, properties);
960   }
961 }
962 
IsFoldableUncached(const NodeDef & node,const GraphProperties * properties) const963 bool ConstantFolding::IsFoldableUncached(
964     const NodeDef& node, const GraphProperties* properties) const {
965   // Folding not applicable to ops with no inputs.
966   if (node.input().empty()) {
967     return false;
968   }
969   // We can only fold nodes if all their inputs are known statically, except in
970   // the case of a merge node that propagate the first inputs that becomes
971   // available, and therefore only requires a single constant input to be
972   // foldable.
973   bool merge_has_constant_input = false;
974   const bool is_merge = IsMerge(node);
975   for (const auto& input : node.input()) {
976     if (IsControlInput(input)) {
977       continue;
978     }
979     const NodeDef* input_node = node_map_->GetNode(input);
980     if (!input_node) {
981       return false;
982     }
983     bool is_const = IsReallyConstant(*input_node);
984     if (is_const) {
985       // Don't fold strings constants for now since this causes problems with
986       // checkpointing.
987       if (input_node->attr().count("dtype") == 0 ||
988           input_node->attr().at("dtype").type() == DT_STRING) {
989         return false;
990       }
991       // Special case: If a Merge node has at least one constant input that
992       // does not depend on a control input, we can fold it.
993       merge_has_constant_input |= !HasControlInputs(*input_node);
994     } else if (!is_merge) {
995       return false;
996     }
997   }
998   if (is_merge && !merge_has_constant_input) return false;
999   if (disable_compressed_tensor_optimization_ &&
1000       (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
1001     return false;
1002 
1003   // If we know the output shapes, make sure that the outputs are small enough
1004   // to materialize.
1005   if (properties != nullptr && properties->HasOutputProperties(node.name())) {
1006     const std::vector<OpInfo::TensorProperties>& input_props =
1007         properties->GetInputProperties(node.name());
1008     const std::vector<OpInfo::TensorProperties>& output_props =
1009         properties->GetOutputProperties(node.name());
1010     // Compute total size of inputs.
1011     int64_t input_size_bytes = 0;
1012     for (const auto& input_prop : input_props) {
1013       const PartialTensorShape input_shape(input_prop.shape());
1014       if (input_shape.IsFullyDefined()) {
1015         input_size_bytes +=
1016             input_shape.num_elements() * DataTypeSize(input_prop.dtype());
1017       }
1018     }
1019     for (const auto& output_prop : output_props) {
1020       const PartialTensorShape output_shape(output_prop.shape());
1021       if (output_shape.IsFullyDefined()) {
1022         const int64_t num_bytes =
1023             output_shape.num_elements() * DataTypeSize(output_prop.dtype());
1024         if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
1025           // Do not fold nodes if the in-memory size of output is too large.
1026           // Notice that this is not exactly the same check used in
1027           // CreateNodeDef() where the actual encoded size is checked.
1028           return false;
1029         }
1030       }
1031     }
1032   }
1033 
1034   return true;
1035 }
1036 
MaybeFoldable(const NodeDef & node,const GraphProperties * properties) const1037 bool ConstantFolding::MaybeFoldable(const NodeDef& node,
1038                                     const GraphProperties* properties) const {
1039   // Skip constants, they're already folded
1040   if (IsConstant(node)) {
1041     return false;
1042   }
1043   // Don't fold stateful ops such as TruncatedNormal.
1044   if (!IsFreeOfSideEffect(node)) {
1045     return false;
1046   }
1047 
1048   // Skips nodes that must be preserved except allowlisted nodes.
1049   if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
1050       nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1051     return false;
1052   }
1053 
1054   // Skip control flow nodes, they can't be folded.
1055   if (ModifiesFrameInfo(node)) {
1056     return false;
1057   }
1058 
1059   // Skips ops that don't benefit from folding.
1060   if (IsPlaceholder(node)) {
1061     return false;
1062   }
1063   // `FakeParam` op is used as a placeholder in If branch function. It doesn't
1064   // have a valid output when executed.
1065   if (IsFakeParam(node)) {
1066     return false;
1067   }
1068 
1069   if (node.op() == "AccumulateNV2") {
1070     return false;
1071   }
1072   // Removing LoopCond nodes can screw up the partitioner.
1073   if (node.op() == "LoopCond") {
1074     return false;
1075   }
1076 
1077   if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
1078     return false;
1079   }
1080 
1081   const string& op = node.op();
1082   if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
1083       op.find("Reader") != string::npos) {
1084     return false;
1085   }
1086   if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
1087     return false;
1088   }
1089 
1090   // Don't fold nodes that contain TPU attributes.
1091   // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
1092   // properly forward custom attributes, b/119051778.
1093   if (HasTPUAttributes(node)) {
1094     return false;
1095   }
1096 
1097   const OpDef* op_def = nullptr;
1098   Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
1099   if (!status.ok()) {
1100     return false;
1101   }
1102   // Don't fold ops without outputs.
1103   if (op_def->output_arg_size() == 0) {
1104     return false;
1105   }
1106   // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
1107   // TODO(rmlarsen): Only do this for XLA_* devices.
1108   for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
1109     if (output_arg.type() == DT_VARIANT) {
1110       return false;
1111     }
1112   }
1113 
1114   // Don't fold nodes that have no outgoing edges except allowlisted nodes.
1115   // Such nodes could be introduced by an earlier constant folding pass and are
1116   // preserved in case users want to fetch their values; re-processing them
1117   // would lead to an error of adding a duplicated node to graph.
1118   const auto& outputs = node_map_->GetOutputs(node.name());
1119   if (outputs.empty() &&
1120       nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1121     return false;
1122   }
1123   return true;
1124 }
1125 
1126 namespace {
1127 
1128 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME)     \
1129   case DTYPE:                                      \
1130     t->add_##NAME##_val(static_cast<TYPE>(value)); \
1131     break;
1132 
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)1133 Status CreateConstantTensorAttrValue(DataType type, double value,
1134                                      const TensorShapeProto& shape,
1135                                      AttrValue* attr_tensor) {
1136   TensorProto* t = attr_tensor->mutable_tensor();
1137   t->set_dtype(type);
1138   *t->mutable_tensor_shape() = shape;
1139   switch (type) {
1140     case DT_HALF:
1141       t->add_half_val(
1142           Eigen::numext::bit_cast<uint16>(static_cast<Eigen::half>(value)));
1143       break;
1144     case DT_BFLOAT16:
1145       t->add_half_val(
1146           Eigen::numext::bit_cast<uint16>(static_cast<bfloat16>(value)));
1147       break;
1148       SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
1149       SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
1150       SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
1151       SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
1152       SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
1153       SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
1154       SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
1155       SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
1156       SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
1157       SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
1158       SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
1159       SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
1160       SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
1161       SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
1162       SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1163       SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1164     default:
1165       return errors::InvalidArgument(
1166           "Unsupported type in CreateConstantTensorAttrValue: ",
1167           DataTypeString(type));
1168   }
1169   return Status::OK();
1170 }
1171 
1172 #undef SET_TENSOR_CAL_CASE
1173 
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1174 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1175                                     const GraphProperties& properties) {
1176   DataType dtype = DT_INVALID;
1177   if (node.attr().count("T") == 1) {
1178     dtype = node.attr().at("T").type();
1179   } else if (node.attr().count("dtype") == 1) {
1180     dtype = node.attr().at("dtype").type();
1181   } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1182     dtype = DT_BOOL;
1183   } else {
1184     auto output_props = properties.GetOutputProperties(node.name());
1185     if (!output_props.empty()) {
1186       dtype = output_props[0].dtype();
1187     }
1188   }
1189   return dtype;
1190 }
1191 
1192 // Checks whether the shape of the const input of the Mul op is valid to perform
1193 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1194 bool IsValidConstShapeForMulConvPushDown(
1195     const string& data_format, const TensorShapeProto& filter_shape,
1196     const TensorShapeProto& mul_const_input_shape) {
1197   // If the const is a scalar, or it has fewer or same number of dimensions
1198   // than the filter and it only has single element, the optimization should
1199   // work.
1200   if (mul_const_input_shape.dim_size() <=
1201           static_cast<int>(data_format.size()) &&
1202       TensorShape(mul_const_input_shape).num_elements() == 1) {
1203     return true;
1204   }
1205 
1206   // Otherwise, check the eligibility according to data format.
1207   if (data_format == "NHWC" || data_format == "NDHWC") {
1208     TensorShapeProto new_filter_shape;
1209     if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1210                              &new_filter_shape)) {
1211       return false;
1212     }
1213     if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1214       return false;
1215     }
1216     // Only the last dimension could be larger than one, since broadcasting over
1217     // the last dimension (the output channel) will result in invalid filter.
1218     for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1219       if (mul_const_input_shape.dim(i).size() > 1) return false;
1220     }
1221     return true;
1222   } else if (data_format == "NCHW" || data_format == "NCDHW") {
1223     // TODO(laigd): support NCHW and NCDHW (b/111214513).
1224     return false;
1225   }
1226   return false;
1227 }
1228 
1229 }  // namespace
1230 
1231 // static
CreateNodeDef(const string & name,const TensorValue & tensor,NodeDef * node,size_t original_size)1232 Status ConstantFolding::CreateNodeDef(const string& name,
1233                                       const TensorValue& tensor, NodeDef* node,
1234                                       size_t original_size) {
1235   node->set_name(name);
1236   node->set_op("Const");
1237 
1238   AttrValue attr_type;
1239   attr_type.set_type(tensor->dtype());
1240   node->mutable_attr()->insert({"dtype", attr_type});
1241 
1242   AttrValue attr_tensor;
1243   TensorProto* t = attr_tensor.mutable_tensor();
1244   bool optimized = false;
1245   size_t encoded_size;
1246   // Use the packed representation whenever possible to avoid generating large
1247   // graphdefs. Moreover, avoid repeating the last values if they're equal.
1248   if (tensor->NumElements() > 4) {
1249 #define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE)                      \
1250   {                                                                            \
1251     const auto* val_ptr = tensor->flat<TYPE>().data();                         \
1252     auto last = *val_ptr;                                                      \
1253     int64 last_index = 0;                                                      \
1254     for (int64 i = 0; i < tensor->NumElements(); ++i) {                        \
1255       TYPE cur = *val_ptr++;                                                   \
1256       if (PackedValuesNotEqual(cur, last)) {                                   \
1257         last = cur;                                                            \
1258         last_index = i;                                                        \
1259       }                                                                        \
1260     }                                                                          \
1261     encoded_size = (last_index + 1) * sizeof(FIELDTYPE);                       \
1262     if (encoded_size < kint32max) {                                            \
1263       optimized = true;                                                        \
1264       t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1);                 \
1265       const auto* src_ptr = tensor->flat<TYPE>().data();                       \
1266       auto* dst_ptr =                                                          \
1267           t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
1268       std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr);                   \
1269     }                                                                          \
1270   }                                                                            \
1271   break
1272 
1273     switch (tensor->dtype()) {
1274       case DT_FLOAT:
1275         POPULATE_TENSOR_PROTO(tensor, t, float, float);
1276       case DT_DOUBLE:
1277         POPULATE_TENSOR_PROTO(tensor, t, double, double);
1278       case DT_INT64:
1279         POPULATE_TENSOR_PROTO(tensor, t, int64_t, int64);
1280       case DT_UINT64:
1281         POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
1282       case DT_INT32:
1283         POPULATE_TENSOR_PROTO(tensor, t, int32_t, int);
1284       case DT_UINT32:
1285         POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
1286       case DT_INT16:
1287         POPULATE_TENSOR_PROTO(tensor, t, int16_t, int);
1288       case DT_UINT16:
1289         POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
1290       case DT_INT8:
1291         POPULATE_TENSOR_PROTO(tensor, t, int8_t, int);
1292       case DT_UINT8:
1293         POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
1294       case DT_BOOL:
1295         POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
1296       default:
1297         /* Do nothing. */
1298         break;
1299     }
1300   }
1301   if (optimized) {
1302     // Also specify type and shape.
1303     t->set_dtype(tensor->dtype());
1304     tensor->shape().AsProto(t->mutable_tensor_shape());
1305   } else {
1306     // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
1307     // DT_QUINT8
1308     tensor->AsProtoTensorContent(t);
1309     encoded_size = t->tensor_content().size();
1310   }
1311   node->mutable_attr()->insert({"value", attr_tensor});
1312 
1313   if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
1314     return errors::InvalidArgument(
1315         strings::StrCat("Can't fold ", name, ", its size would be too large (",
1316                         encoded_size, " >= ", kMaxConstantSize, " bytes)"));
1317   }
1318   return Status::OK();
1319 }
1320 
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1321 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1322                                      const TensorVector& inputs,
1323                                      TensorVector* output) const {
1324   return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1325                                               resource_mgr_.get(), output);
1326 }
1327 
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1328 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1329                                             std::vector<NodeDef>* outputs,
1330                                             bool* result_too_large) {
1331   TensorVector inputs;
1332   TensorVector output_tensors;
1333   auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1334     for (const auto& input : inputs) {
1335       delete input.tensor;
1336     }
1337     for (const auto& output : output_tensors) {
1338       if (output.tensor) {
1339         delete output.tensor;
1340       }
1341     }
1342   });
1343 
1344   size_t total_inputs_size = 0;
1345   for (const auto& input : node.input()) {
1346     const TensorId input_tensor = ParseTensorName(input);
1347     if (input_tensor.index() < 0) {
1348       // Control dependency
1349       break;
1350     }
1351     const NodeDef* input_node = node_map_->GetNode(input);
1352     if (!IsReallyConstant(*input_node)) {
1353       return Status(error::INVALID_ARGUMENT,
1354                     strings::StrCat("Can't fold ", node.name(), ", its ", input,
1355                                     " isn't constant"));
1356     }
1357     TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1358     const TensorProto& raw_val = input_node->attr().at("value").tensor();
1359     if (raw_val.dtype() == DT_INVALID) {
1360       return Status(
1361           error::INVALID_ARGUMENT,
1362           strings::StrCat("A tensor in the input node, with TensorId of ",
1363                           input_tensor.ToString(),
1364                           " has a dtype of DT_INVALID."));
1365     }
1366     Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1367     if (!value->FromProto(raw_val)) {
1368       delete (value);
1369       return errors::InvalidArgument("Unable to make Tensor from proto for ",
1370                                      node.name(), " with shape ",
1371                                      raw_val.tensor_shape().DebugString());
1372     }
1373     inputs.emplace_back(value);
1374     total_inputs_size += value->TotalBytes();
1375   }
1376 
1377   TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1378   if (output_tensors.empty()) {
1379     return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1380   }
1381 
1382   outputs->resize(output_tensors.size());
1383   for (size_t i = 0; i < output_tensors.size(); i++) {
1384     string node_name = OptimizedNodeName(node, "-folded");
1385     if (output_tensors.size() > 1) {
1386       node_name = strings::StrCat(node_name, "-", i);
1387     }
1388     if (output_tensors[i].tensor) {
1389       Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1390                                total_inputs_size);
1391       if (!s.ok()) {
1392         *result_too_large = true;
1393         return s;
1394       }
1395     } else {
1396       // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1397       // switch that's not selected by the switch predicate).
1398       outputs->at(i) = NodeDef();
1399     }
1400   }
1401   return Status::OK();
1402 }
1403 
FoldMergeNode(NodeDef * node,GraphDef * output_graph)1404 Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
1405   // Merge nodes are special, in the sense that they execute as soon as one of
1406   // their input is ready. We can therefore fold a merge node iff it has at
1407   // least one constant input without control dependency.
1408   // We still need to ensure that the nodes in the fanin of the merge node are
1409   // scheduled. We'll therefore add a control dependency from the merge node
1410   // to the folded constant. We end up with:
1411   //  * the merge node and its inputs are preserved as is
1412   //  * a new constant node C1, driven by the merge node through a control
1413   //  dependency, initialized to the value of the folded input
1414   //  * a new constant node C2, driven by the merge node through a control
1415   //  dependency, initialized to the index of the folded input
1416   //  * the fanout of the merge nodes is rewired to be driven by either C1 or
1417   //  C2.
1418   for (int input_index = 0; input_index < node->input_size(); ++input_index) {
1419     const auto& input = node->input(input_index);
1420     if (IsControlInput(input)) {
1421       // Try the next input.
1422       continue;
1423     }
1424     NodeDef* input_node = node_map_->GetNode(input);
1425     if (!IsReallyConstant(*input_node)) {
1426       continue;
1427     }
1428     bool valid_input = true;
1429     for (const string& fanin_of_input : input_node->input()) {
1430       if (IsControlInput(fanin_of_input)) {
1431         valid_input = false;
1432         break;
1433       }
1434     }
1435     if (!valid_input) {
1436       // Try the next input
1437       continue;
1438     }
1439 
1440     string const_out_name = OptimizedNodeName(*node, "_const");
1441     string const_index_name = OptimizedNodeName(*node, "_index");
1442     if (node_map_->GetNode(const_out_name) ||
1443         node_map_->GetNode(const_index_name)) {
1444       // Intended name already exists.
1445       return errors::AlreadyExists(
1446           strings::StrCat(const_out_name, " or ", const_index_name,
1447                           " already present in the graph"));
1448     }
1449 
1450     NodeDef* const_out = output_graph->add_node();
1451     *const_out = *input_node;
1452     const_out->set_name(const_out_name);
1453     const_out->set_device(node->device());
1454     *const_out->add_input() = AsControlDependency(*node);
1455     node_map_->AddNode(const_out->name(), const_out);
1456     node_map_->AddOutput(node->name(), const_out->name());
1457 
1458     NodeDef* const_index = output_graph->add_node();
1459     const_index->set_op("Const");
1460     Tensor index(DT_INT32, TensorShape({}));
1461     index.flat<int32>()(0) = input_index;
1462     (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
1463     index.AsProtoTensorContent(
1464         (*const_index->mutable_attr())["value"].mutable_tensor());
1465     const_index->set_name(const_index_name);
1466     const_index->set_device(node->device());
1467     *const_index->add_input() = AsControlDependency(*node);
1468     node_map_->AddNode(const_index->name(), const_index);
1469     node_map_->AddOutput(node->name(), const_index->name());
1470 
1471     // We make a copy because we mutate the nodes.
1472     auto outputs = node_map_->GetOutputs(node->name());
1473     for (NodeDef* output : outputs) {
1474       for (int i = 0; i < output->input_size(); i++) {
1475         int port;
1476         string node_name = ParseNodeName(output->input(i), &port);
1477         if (node_name == node->name()) {
1478           if (port == 0) {
1479             *output->mutable_input(i) = const_out->name();
1480             node_map_->AddOutput(const_out->name(), output->name());
1481           } else if (port == 1) {
1482             *output->mutable_input(i) = const_index->name();
1483             node_map_->AddOutput(const_index->name(), output->name());
1484           } else {
1485             // This is a control dependency (or an invalid edge since the
1486             // merge node has only 2 outputs): preserve them.
1487           }
1488         }
1489       }
1490     }
1491     return Status::OK();
1492   }
1493   return Status::OK();
1494 }
1495 
FoldNode(NodeDef * node,GraphDef * output_graph,bool * result_too_large)1496 Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
1497                                  bool* result_too_large) {
1498   *result_too_large = false;
1499   if (IsMerge(*node)) {
1500     return FoldMergeNode(node, output_graph);
1501   }
1502 
1503   std::vector<NodeDef> const_nodes;
1504   TF_RETURN_IF_ERROR(
1505       EvaluateOneFoldable(*node, &const_nodes, result_too_large));
1506   VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);
1507 
1508   NodeDef* constant_output = nullptr;
1509   for (int i = 0, end = const_nodes.size(); i < end; i++) {
1510     NodeDef* const_node = &const_nodes[i];
1511     VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
1512     if (const_node->name().empty()) {
1513       // Dead output: we can't create a constant to encode its value, so we'll
1514       // just skip it. We'll preserve the edges that originate from that
1515       // output below to preserve the overall behavior of the graph wrt dead
1516       // edges.
1517       continue;
1518     }
1519 
1520     // Returns `true` iff `const_node` already has control input named `input`.
1521     const auto is_duplicate_control_input = [&](const string& input) -> bool {
1522       auto it = absl::c_find(const_node->input(), input);
1523       return it != const_node->input().end();
1524     };
1525 
1526     // Forward control dependencies.
1527     for (const string& input : node->input()) {
1528       // Forward control dependencies from folded node.
1529       if (IsControlInput(input)) {
1530         if (!is_duplicate_control_input(input)) {
1531           *const_node->add_input() = input;
1532         }
1533       }
1534 
1535       // Forward control dependencies from constant inputs to folded node.
1536       if (!IsControlInput(input)) {
1537         NodeDef* input_node = node_map_->GetNode(input);
1538         for (const string& fanin_of_input : input_node->input()) {
1539           if (!is_duplicate_control_input(fanin_of_input)) {
1540             *const_node->add_input() = fanin_of_input;
1541           }
1542         }
1543       }
1544     }
1545 
1546     // We rewrite the existing node if it only has a single output, and
1547     // create new nodes otherwise.
1548     if (const_nodes.size() == 1) {
1549       node->set_op("Const");
1550       // Note we need to clear the inputs in NodeMap before we clear the inputs
1551       // in the node, otherwise NodeMap would see empty inputs and effectively
1552       // does nothing.
1553       node_map_->RemoveInputs(node->name());
1554       node->clear_input();
1555       *node->mutable_input() = const_node->input();
1556       for (const auto& input : node->input()) {
1557         node_map_->AddOutput(NodeName(input), node->name());
1558       }
1559       *node->mutable_attr() = const_node->attr();
1560       break;
1561     } else {
1562       if (node_map_->GetNode(const_node->name())) {
1563         // Intended name already exists.
1564         return errors::AlreadyExists(strings::StrCat(
1565             const_node->name(), " already present in the graph"));
1566       }
1567       NodeDef* added_node = output_graph->add_node();
1568       *added_node = *const_node;
1569       added_node->set_device(node->device());
1570       node_map_->AddNode(added_node->name(), added_node);
1571       for (const auto& input : added_node->input()) {
1572         node_map_->AddOutput(NodeName(input), added_node->name());
1573       }
1574       // All the constant nodes encoding output values have the same control
1575       // dependencies (since these are the control dependencies of the node
1576       // we're trying to fold). Record one such constant node.
1577       constant_output = added_node;
1578     }
1579   }
1580 
1581   if (const_nodes.size() > 1) {
1582     // We make a copy because we mutate the nodes.
1583     auto outputs = node_map_->GetOutputs(node->name());
1584     for (NodeDef* output : outputs) {
1585       for (int i = 0; i < output->input_size(); i++) {
1586         int port;
1587         string node_name = ParseNodeName(output->input(i), &port);
1588         if (node_name == node->name()) {
1589           if (port < 0) {
1590             // Propagate control dependencies if possible. If not, we'll just
1591             // preserve the existing control dependencies.
1592             if (constant_output != nullptr) {
1593               node_map_->UpdateInput(node_name, NodeName(output->input(i)),
1594                                      constant_output->name());
1595               *output->mutable_input(i) = AsControlDependency(*constant_output);
1596             }
1597           } else if (port < static_cast<int>(const_nodes.size()) &&
1598                      !const_nodes[port].name().empty()) {
1599             // Replace alive outputs with the corresponding constant.
1600             node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
1601                                    const_nodes[port].name());
1602             *output->mutable_input(i) = const_nodes[port].name();
1603           } else {
1604             // Leave this edge alone.
1605             VLOG(3) << "Preserving edge from " << node->name() << ":" << port
1606                     << "[" << node->op() << "] to " << output->name() << ":"
1607                     << i << "[" << output->op() << "]";
1608           }
1609         }
1610       }
1611     }
1612     outputs = node_map_->GetOutputs(node->name());
1613     if (outputs.empty() && has_fetch_ &&
1614         nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
1615       node_map_->RemoveInputs(node->name());
1616       node->clear_input();
1617     }
1618   }
1619   return Status::OK();
1620 }
1621 
FoldGraph(const GraphProperties & properties,GraphDef * output,absl::flat_hash_set<string> * nodes_to_not_simplify)1622 Status ConstantFolding::FoldGraph(
1623     const GraphProperties& properties, GraphDef* output,
1624     absl::flat_hash_set<string>* nodes_to_not_simplify) {
1625   std::unordered_set<string> processed_nodes;
1626   std::deque<NodeDef*> queue;
1627   for (int i = 0; i < graph_->node_size(); i++) {
1628     bool foldable = IsFoldable(graph_->node(i), &properties);
1629     VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
1630     if (foldable) {
1631       queue.push_back(graph_->mutable_node(i));
1632     }
1633   }
1634   while (!queue.empty()) {
1635     NodeDef* node = queue.front();
1636     queue.pop_front();
1637     if (processed_nodes.count(node->name())) {
1638       continue;
1639     }
1640     // We need to record a copy of output nodes before FoldNode() modifies it.
1641     // We also need to ensure that the fanout is sorted deterministically.
1642     std::vector<NodeDef*> fanout =
1643         node_map_->GetOutputsOrderedByNodeName(node->name());
1644     bool result_too_large = false;
1645     Status s = FoldNode(node, output, &result_too_large);
1646     processed_nodes.insert(node->name());
1647     if (!s.ok()) {
1648       VLOG(1) << "Failed to fold node " << node->DebugString()
1649               << "\nError message: " << s;
1650       if (result_too_large) {
1651         nodes_to_not_simplify->emplace(node->name());
1652       }
1653     } else {
1654       for (auto& output : fanout) {
1655         if (IsFoldable(*output, &properties)) {
1656           queue.push_back(output);
1657         }
1658       }
1659     }
1660   }
1661 
1662   // Delete the newly created nodes that don't feed anything.
1663   std::vector<int> nodes_to_delete;
1664   for (int i = 0; i < output->node_size(); i++) {
1665     const auto& fanout = node_map_->GetOutputs(output->node(i).name());
1666     if (fanout.empty()) nodes_to_delete.push_back(i);
1667   }
1668   EraseNodesFromGraph(std::move(nodes_to_delete), output);
1669 
1670   for (int i = 0; i < graph_->node_size(); ++i) {
1671     NodeDef* node = graph_->mutable_node(i);
1672     // If no fetch nodes is provided, we conservatively
1673     // move all nodes in the original graph to the output, in case users need
1674     // to fetch their values.
1675     const auto& fanout = node_map_->GetOutputs(node->name());
1676     if (!fanout.empty() || !has_fetch_ ||
1677         nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
1678       *(output->add_node()) = std::move(*node);
1679     }
1680   }
1681   return Status::OK();
1682 }
1683 
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1684 bool ConstantFolding::IsSimplifiableReshape(
1685     const NodeDef& node, const GraphProperties& properties) const {
1686   if (!IsReshape(node)) {
1687     return false;
1688   }
1689   CHECK_LE(2, node.input_size());
1690   const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1691   if (!IsReallyConstant(*new_shape)) {
1692     return false;
1693   }
1694   TensorVector outputs;
1695   auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1696     for (const auto& output : outputs) {
1697       delete output.tensor;
1698     }
1699   });
1700 
1701   Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1702   if (!s.ok()) {
1703     return false;
1704   }
1705   CHECK_EQ(1, outputs.size());
1706 
1707   const std::vector<OpInfo::TensorProperties>& props =
1708       properties.GetInputProperties(node.name());
1709   if (props.empty()) {
1710     return false;
1711   }
1712   const OpInfo::TensorProperties& prop = props[0];
1713   if (prop.dtype() == DT_INVALID) {
1714     return false;
1715   }
1716   const PartialTensorShape shape(prop.shape());
1717   if (!shape.IsFullyDefined()) {
1718     return false;
1719   }
1720 
1721   PartialTensorShape new_dims;
1722   if (outputs[0]->dtype() == DT_INT32) {
1723     std::vector<int32> shp;
1724     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1725       int32_t dim = outputs[0]->flat<int32>()(i);
1726       shp.push_back(dim);
1727     }
1728     TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1729   } else {
1730     std::vector<int64> shp;
1731     for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1732       int64_t dim = outputs[0]->flat<int64>()(i);
1733       shp.push_back(dim);
1734     }
1735     TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1736   }
1737 
1738   return shape.IsCompatibleWith(new_dims);
1739 }
1740 
// Expands to a `case` label for the DataType enumerator DT_ENUM that returns
// whether all values of the "value" tensor attribute of `node` equal SCALAR,
// converted to the C++ type associated with DT_ENUM. `node` must be in scope
// at the expansion site, and the macro must be used inside a `switch`; the
// terminating semicolon is supplied by the caller.
#define IS_VALUE_CASE(DT_ENUM, SCALAR)                  \
  case DT_ENUM:                                         \
    return AllValuesAre<EnumToDataType<DT_ENUM>::Type>( \
        node.attr().at("value").tensor(),               \
        EnumToDataType<DT_ENUM>::Type(SCALAR))

// Convenience wrappers testing for all-ones and all-zeros respectively.
#define IS_ONES_CASE(DT_ENUM) IS_VALUE_CASE(DT_ENUM, 1)
#define IS_ZEROS_CASE(DT_ENUM) IS_VALUE_CASE(DT_ENUM, 0)
1748 
IsOnes(const NodeDef & node) const1749 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1750   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1751     return false;
1752   }
1753   if (IsOnesLike(node)) return true;
1754   if (IsZerosLike(node)) return false;
1755   if (node.op() == "Fill") {
1756     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1757     return values != nullptr && IsOnes(*values);
1758   }
1759   if (node.op() != "Const") return false;
1760   if (node.attr().count("dtype") == 0) return false;
1761   const auto dtype = node.attr().at("dtype").type();
1762   switch (dtype) {
1763     IS_ONES_CASE(DT_BOOL);
1764     IS_ONES_CASE(DT_HALF);
1765     IS_ONES_CASE(DT_BFLOAT16);
1766     IS_ONES_CASE(DT_FLOAT);
1767     IS_ONES_CASE(DT_DOUBLE);
1768     IS_ONES_CASE(DT_COMPLEX64);
1769     IS_ONES_CASE(DT_COMPLEX128);
1770     IS_ONES_CASE(DT_UINT8);
1771     IS_ONES_CASE(DT_INT8);
1772     IS_ONES_CASE(DT_UINT16);
1773     IS_ONES_CASE(DT_INT16);
1774     IS_ONES_CASE(DT_INT32);
1775     IS_ONES_CASE(DT_INT64);
1776     IS_ONES_CASE(DT_QINT32);
1777     IS_ONES_CASE(DT_QINT16);
1778     IS_ONES_CASE(DT_QUINT16);
1779     IS_ONES_CASE(DT_QINT8);
1780     IS_ONES_CASE(DT_QUINT8);
1781     default:
1782       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1783       return false;
1784   }
1785   return false;
1786 }
1787 
IsZeros(const NodeDef & node) const1788 bool ConstantFolding::IsZeros(const NodeDef& node) const {
1789   if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1790     return false;
1791   }
1792   if (IsOnesLike(node)) return false;
1793   if (IsZerosLike(node)) return true;
1794   if (node.op() == "Fill") {
1795     NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1796     return values != nullptr && IsZeros(*values);
1797   }
1798   if (!IsConstant(node)) return false;
1799   if (node.attr().count("dtype") == 0) return false;
1800   const auto dtype = node.attr().at("dtype").type();
1801   switch (dtype) {
1802     IS_ZEROS_CASE(DT_BOOL);
1803     IS_ZEROS_CASE(DT_HALF);
1804     IS_ZEROS_CASE(DT_BFLOAT16);
1805     IS_ZEROS_CASE(DT_FLOAT);
1806     IS_ZEROS_CASE(DT_DOUBLE);
1807     IS_ZEROS_CASE(DT_COMPLEX64);
1808     IS_ZEROS_CASE(DT_COMPLEX128);
1809     IS_ZEROS_CASE(DT_UINT8);
1810     IS_ZEROS_CASE(DT_INT8);
1811     IS_ZEROS_CASE(DT_UINT16);
1812     IS_ZEROS_CASE(DT_INT16);
1813     IS_ZEROS_CASE(DT_INT32);
1814     IS_ZEROS_CASE(DT_INT64);
1815     IS_ZEROS_CASE(DT_QINT32);
1816     IS_ZEROS_CASE(DT_QINT16);
1817     IS_ZEROS_CASE(DT_QUINT16);
1818     IS_ZEROS_CASE(DT_QINT8);
1819     IS_ZEROS_CASE(DT_QUINT8);
1820     default:
1821       VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1822       return false;
1823   }
1824   return false;
1825 }
1826 
ReplaceOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1827 bool ConstantFolding::ReplaceOperationWithBroadcastTo(
1828     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1829     GraphDef* graph) {
1830   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1831   if (dtype == DT_INVALID) {
1832     return false;
1833   }
1834   const PartialTensorShape shape(
1835       properties.GetOutputProperties(node->name())[0].shape());
1836   if (!shape.IsFullyDefined()) {
1837     return false;
1838   }
1839   // Create constant node with shape.
1840   const string const_name = OptimizedNodeName(
1841       *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
1842   if (node_map_->GetNode(const_name) != nullptr) {
1843     return false;
1844   }
1845 
1846   Tensor shape_t;
1847   if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
1848     return false;
1849   }
1850   NodeDef tmp;
1851   if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
1852     return false;
1853   }
1854   NodeDef* const_node = graph->add_node();
1855   const_node->Swap(&tmp);
1856   const_node->set_device(node->device());
1857   node_map_->AddNode(const_name, const_node);
1858   for (int i = 0; i < node->input_size(); ++i) {
1859     if (i != input_to_broadcast) {
1860       // Add a control input on the unused input.
1861       string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
1862                                              node_map_.get());
1863       *const_node->add_input() = ctrl_dep;
1864       node_map_->AddOutput(NodeName(ctrl_dep), const_name);
1865     }
1866   }
1867 
1868   // Rewrite `node` in-place to BroadcastTo.
1869   node->set_op("BroadcastTo");
1870   EraseRegularNodeAttributes(node);
1871   (*node->mutable_attr())["T"].set_type(dtype);
1872   (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
1873   // Set the designated input to BroadcastTo.
1874   node->mutable_input()->SwapElements(0, input_to_broadcast);
1875   // Keep all other inputs as control dependencies.
1876   for (int i = 1; i < node->input_size(); ++i) {
1877     if (IsControlInput(node->input(i))) {
1878       break;
1879     }
1880     const string ctrl_dep =
1881         AddControlDependency(node->input(i), graph, node_map_.get());
1882     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1883     node->set_input(i, ctrl_dep);
1884   }
1885   // Add the shape argument.
1886   *node->add_input() = const_node->name();
1887   node_map_->AddOutput(const_name, node->name());
1888   node->mutable_input()->SwapElements(1, node->input_size() - 1);
1889   return true;
1890 }
1891 
1892 // Replace an operation with Identity.
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1893 void ConstantFolding::ReplaceOperationWithIdentity(
1894     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1895     GraphDef* graph) {
1896   if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
1897   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1898   if (dtype == DT_INVALID) return;
1899 
1900   node->set_op("Identity");
1901   EraseRegularNodeAttributes(node);
1902   (*node->mutable_attr())["T"].set_type(dtype);
1903   // Propagate the designated input through the identity.
1904   node->mutable_input()->SwapElements(0, input_to_forward);
1905   // Add all other inputs as control dependencies.
1906   for (int i = 1; i < node->input_size(); ++i) {
1907     if (IsControlInput(node->input(i))) {
1908       break;
1909     }
1910     const string ctrl_dep =
1911         AddControlDependency(node->input(i), graph, node_map_.get());
1912     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1913     node->set_input(i, ctrl_dep);
1914   }
1915   graph_modified_ = true;
1916 }
1917 
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1918 void ConstantFolding::ReplaceOperationWithSnapshot(
1919     int input_to_forward, const GraphProperties& properties, NodeDef* node,
1920     GraphDef* graph) {
1921   // If the graph contains no ops that mutate their inputs, we can
1922   // use Identity instead of Snapshot.
1923   if (!graph_contains_assign_or_inplace_op_) {
1924     ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1925     return;
1926   }
1927 
1928   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1929   if (dtype == DT_INVALID) return;
1930 
1931   node->set_op("Snapshot");
1932   EraseRegularNodeAttributes(node);
1933   (*node->mutable_attr())["T"].set_type(dtype);
1934   // Propagate the designated input through the Snapshot.
1935   node->mutable_input()->SwapElements(0, input_to_forward);
1936   // Add all other inputs as control dependencies.
1937   for (int i = 1; i < node->input_size(); ++i) {
1938     if (IsControlInput(node->input(i))) {
1939       break;
1940     }
1941     const string ctrl_dep =
1942         AddControlDependency(node->input(i), graph, node_map_.get());
1943     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1944     node->set_input(i, ctrl_dep);
1945   }
1946   graph_modified_ = true;
1947 }
1948 
1949 // Replace a node with NoOp. Change all inputs to control dependencies.
1950 // If the node has non-control outputs, no change will be performed.
ReplaceOperationWithNoOp(NodeDef * node,GraphProperties * properties,GraphDef * graph)1951 void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
1952                                                GraphProperties* properties,
1953                                                GraphDef* graph) {
1954   if (HasRegularOutputs(*node, *node_map_)) return;
1955   node->set_op("NoOp");
1956   EraseRegularNodeAttributes(node);
1957   EraseNodeOutputAttributes(node);
1958   // Erase attributes that describe output properties.
1959   properties->ClearOutputProperties(node->name());
1960   // Change all inputs to control dependencies.
1961   for (int i = 0; i < node->input_size(); ++i) {
1962     if (IsControlInput(node->input(i))) {
1963       break;
1964     }
1965     const string ctrl_dep =
1966         AddControlDependency(node->input(i), graph, node_map_.get());
1967     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1968     node->set_input(i, ctrl_dep);
1969   }
1970   DedupControlInputs(node);
1971   graph_modified_ = true;
1972 }
1973 
ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1974 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
1975     int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1976     GraphDef* graph) {
1977   if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
1978                                        graph)) {
1979     return;
1980   }
1981   graph_modified_ = true;
1982 }
1983 
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)1984 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
1985                                                         GraphDef* graph) {
1986   node->set_op("Reciprocal");
1987   node->mutable_input()->SwapElements(0, 1);
1988   const string ctrl_dep =
1989       AddControlDependency(node->input(1), graph, node_map_.get());
1990   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1991   node->set_input(1, ctrl_dep);
1992   graph_modified_ = true;
1993 }
1994 
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)1995 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
1996                                                            GraphDef* graph) {
1997   node->set_op("Neg");
1998   node->mutable_input()->SwapElements(0, 1);
1999   const string ctrl_dep =
2000       AddControlDependency(node->input(1), graph, node_map_.get());
2001   node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
2002   node->set_input(1, ctrl_dep);
2003   graph_modified_ = true;
2004 }
2005 
ReplaceOperationWithConstantTensor(DataType dtype,TensorProto * value,NodeDef * node,GraphDef * graph)2006 Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
2007                                                            TensorProto* value,
2008                                                            NodeDef* node,
2009                                                            GraphDef* graph) {
2010   if (dtype == DT_VARIANT) return Status::OK();
2011   node->set_op("Const");
2012   EraseRegularNodeAttributes(node);
2013   (*node->mutable_attr())["dtype"].set_type(dtype);
2014   (*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
2015   // Convert all inputs to control dependencies.
2016   for (int i = 0; i < node->input_size(); ++i) {
2017     if (IsControlInput(node->input(i))) {
2018       break;
2019     }
2020     const string ctrl_dep =
2021         AddControlDependency(node->input(i), graph, node_map_.get());
2022     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2023     node->set_input(i, ctrl_dep);
2024   }
2025   DedupControlInputs(node);
2026   graph_modified_ = true;
2027   return Status::OK();
2028 }
2029 
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph)2030 Status ConstantFolding::ReplaceOperationWithConstant(
2031     double value, const GraphProperties& properties,
2032     const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
2033   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
2034   if (dtype == DT_VARIANT) return Status::OK();
2035   AttrValue tensor_attr;
2036   Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
2037   if (!s.ok()) {
2038     // Fail gracefully without mutating the graph.
2039     VLOG(1) << "Failed to replace node " << node->name() << " of type "
2040             << DataTypeString(dtype) << " with constant tensor of value "
2041             << value;
2042     return Status::OK();
2043   }
2044   return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
2045                                             node, graph);
2046 }
2047 
SimplifyGraph(bool use_shape_info,GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)2048 Status ConstantFolding::SimplifyGraph(
2049     bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties,
2050     absl::flat_hash_set<string>* nodes_to_not_simplify) {
2051   for (int i = 0; i < optimized_graph->node_size(); ++i) {
2052     NodeDef* node = optimized_graph->mutable_node(i);
2053     // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
2054     // generalize to only restrict certain simplifications.
2055     if (nodes_to_not_simplify->find(node->name()) ==
2056         nodes_to_not_simplify->end()) {
2057       if (HasTPUAttributes(*node)) {
2058         nodes_to_not_simplify->insert(node->name());
2059         continue;
2060       }
2061 
2062       TF_RETURN_IF_ERROR(
2063           SimplifyNode(use_shape_info, node, optimized_graph, properties));
2064     }
2065   }
2066   return Status::OK();
2067 }
2068 
// Helper macros for SimplifyNode(): each one runs a single simplification and
// returns OK from the enclosing function as soon as `graph_modified_` is set.
// They are wrapped in do { } while (0) so that each expands to exactly one
// statement and stays correct inside an unbraced if/else.

// Evaluates EXPR (a Status); returns it on error, or OK once the graph has
// been modified.
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR)        \
  do {                                           \
    TF_RETURN_IF_ERROR(EXPR);                    \
    if (graph_modified_) return Status::OK();    \
  } while (0)

// Stores EXPR (a bool) into `graph_modified_` and returns OK if it is true.
#define SET_AND_RETURN_IF_MODIFIED(EXPR)         \
  do {                                           \
    graph_modified_ = (EXPR);                    \
    if (graph_modified_) return Status::OK();    \
  } while (0)

// Evaluates EXPR for its side effects and returns OK if `graph_modified_`
// has been set.
#define RETURN_IF_MODIFIED(EXPR)                 \
  do {                                           \
    EXPR;                                        \
    if (graph_modified_) return Status::OK();    \
  } while (0)
2080 
SimplifyNode(bool use_shape_info,NodeDef * node,GraphDef * optimized_graph,GraphProperties * properties)2081 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
2082                                      GraphDef* optimized_graph,
2083                                      GraphProperties* properties) {
2084   bool graph_modified_cached = graph_modified_;
2085   graph_modified_ = false;
2086 
2087   RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
2088   RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
2089       *properties, use_shape_info, optimized_graph, node));
2090   RETURN_IF_MODIFIED(
2091       RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
2092   RETURN_IF_ERROR_OR_MODIFIED(
2093       RemoveReverse(*properties, use_shape_info, optimized_graph, node));
2094   RETURN_IF_ERROR_OR_MODIFIED(
2095       SimplifySlice(*properties, use_shape_info, optimized_graph, node));
2096   RETURN_IF_ERROR_OR_MODIFIED(
2097       SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
2098   RETURN_IF_ERROR_OR_MODIFIED(
2099       SimplifyTile(*properties, use_shape_info, optimized_graph, node));
2100   RETURN_IF_ERROR_OR_MODIFIED(
2101       SimplifyPad(*properties, use_shape_info, optimized_graph, node));
2102   RETURN_IF_MODIFIED(
2103       SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
2104   SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
2105   SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
2106   SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
2107   SET_AND_RETURN_IF_MODIFIED(
2108       SimplifyReduction(optimized_graph, *properties, node));
2109   SET_AND_RETURN_IF_MODIFIED(
2110       SimplifyReshape(*properties, use_shape_info, node));
2111   RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
2112       *properties, use_shape_info, optimized_graph, node));
2113   SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
2114   SET_AND_RETURN_IF_MODIFIED(
2115       ConstantPushDown(properties, optimized_graph, node));
2116   SET_AND_RETURN_IF_MODIFIED(
2117       MulConvPushDown(optimized_graph, node, *properties));
2118   SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
2119   SET_AND_RETURN_IF_MODIFIED(
2120       PartialAssocOpConstFolding(optimized_graph, properties, node));
2121   SET_AND_RETURN_IF_MODIFIED(
2122       MergeConcat(use_shape_info, properties, optimized_graph, node));
2123   SET_AND_RETURN_IF_MODIFIED(
2124       PartialConcatConstFolding(optimized_graph, properties, node));
2125   SET_AND_RETURN_IF_MODIFIED(
2126       ConstantPushDownBiasAdd(properties, optimized_graph, node));
2127   SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
2128   SET_AND_RETURN_IF_MODIFIED(
2129       SimplifySelect(*properties, optimized_graph, node));
2130   RETURN_IF_MODIFIED(
2131       RemoveRedundantVariableUpdates(properties, optimized_graph, node));
2132 
2133   graph_modified_ = graph_modified_cached;
2134   return Status::OK();
2135 }
2136 
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2137 void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
2138                                           GraphDef* optimized_graph,
2139                                           NodeDef* node) {
2140   if (node->attr().count("num_split") == 0) return;
2141   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
2142     ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
2143   }
2144   if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
2145     ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2146   }
2147 }
2148 
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2149 Status ConstantFolding::RemoveShuffleOrTranspose(
2150     const GraphProperties& properties, bool use_shape_info,
2151     GraphDef* optimized_graph, NodeDef* node) {
2152   if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
2153     return Status::OK();
2154   Tensor permutation_tensor;
2155   if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
2156       properties.HasInputProperties(node->name())) {
2157     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2158     std::vector<int> permutation;
2159     for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
2160       if (permutation_tensor.dtype() == DT_INT64) {
2161         permutation.push_back(permutation_tensor.vec<int64>()(j));
2162       } else {
2163         permutation.push_back(permutation_tensor.vec<int>()(j));
2164       }
2165     }
2166     int permutation_size = permutation.size();
2167     if (permutation_size != shape.dim_size()) {
2168       // Number of elements in perm should be same as dim_size. Skip if not.
2169       return Status::OK();
2170     }
2171     // The node is replaceable iff
2172     // dim_size == 0 || all dims have size 1 ||
2173     // all dims with > 1 size are not permuted.
2174     bool replaceable = true;
2175     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2176       replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
2177     }
2178     if (replaceable) {
2179       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2180     }
2181   }
2182   return Status::OK();
2183 }
2184 
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2185 void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
2186                                           bool use_shape_info,
2187                                           GraphDef* optimized_graph,
2188                                           NodeDef* node) {
2189   if (use_shape_info && IsRandomShuffle(*node) &&
2190       !properties.GetInputProperties(node->name()).empty()) {
2191     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2192     // The node is replaceable iff
2193     // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
2194     if (!shape.unknown_rank() &&
2195         (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
2196       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2197     }
2198   }
2199 }
2200 
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2201 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
2202                                       bool use_shape_info,
2203                                       GraphDef* optimized_graph,
2204                                       NodeDef* node) {
2205   if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
2206   Tensor axis;
2207   if (properties.HasInputProperties(node->name()) &&
2208       GetTensorFromConstNode(node->input(1), &axis)) {
2209     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2210     if (shape.unknown_rank()) return Status::OK();
2211     std::set<int> target_axes;
2212     for (int j = 0; j < axis.NumElements(); ++j) {
2213       // value of axis can be negative.
2214       if (axis.dtype() == DT_INT64) {
2215         target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
2216                            shape.dim_size());
2217       } else {
2218         target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2219                            shape.dim_size());
2220       }
2221     }
2222 
2223     // The node is replaceable iff
2224     // unknown_rank == false &&
2225     // (dim_size == 0 || all dims have size 1 ||
2226     //  all dims with > 1 size are not in target_axes)
2227     bool replaceable = true;
2228     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2229       replaceable &=
2230           shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
2231     }
2232     if (replaceable) {
2233       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2234     }
2235   }
2236   return Status::OK();
2237 }
2238 
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2239 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2240                                       bool use_shape_info,
2241                                       GraphDef* optimized_graph,
2242                                       NodeDef* node) {
2243   if (!use_shape_info || !IsSlice(*node)) return Status::OK();
2244   Tensor begin;
2245   Tensor size;
2246   if (properties.HasInputProperties(node->name()) &&
2247       GetTensorFromConstNode(node->input(1), &begin) &&
2248       GetTensorFromConstNode(node->input(2), &size)) {
2249     const auto& input = properties.GetInputProperties(node->name())[0];
2250     // The node is replaceable iff unknown_rank == false &&
2251     // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2252     bool replaceable = !input.shape().unknown_rank();
2253     for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2254       if (begin.dtype() == DT_INT32) {
2255         replaceable &= begin.vec<int>()(j) == 0;
2256       } else {
2257         replaceable &= begin.vec<int64>()(j) == 0;
2258       }
2259       if (size.dtype() == DT_INT32) {
2260         replaceable &= (size.vec<int>()(j) == -1 ||
2261                         size.vec<int>()(j) == input.shape().dim(j).size());
2262       } else {
2263         replaceable &= (size.vec<int64>()(j) == -1 ||
2264                         size.vec<int64>()(j) == input.shape().dim(j).size());
2265       }
2266     }
2267     if (replaceable) {
2268       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2269     }
2270   }
2271   return Status::OK();
2272 }
2273 
SimplifyStridedSlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2274 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
2275                                              bool use_shape_info,
2276                                              GraphDef* optimized_graph,
2277                                              NodeDef* node) {
2278   if (use_shape_info && IsStridedSlice(*node) &&
2279       properties.GetInputProperties(node->name()).size() == 4) {
2280     TF_RETURN_IF_ERROR(
2281         CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
2282     if (node->attr().at("new_axis_mask").i() != 0 ||
2283         node->attr().at("shrink_axis_mask").i() != 0) {
2284       // Skip nodes with new/shrink axis mask, since they involve dimension
2285       // changes.
2286       return Status::OK();
2287     }
2288     const auto& input = properties.GetInputProperties(node->name())[0];
2289     for (int j = 0; j < input.shape().dim_size(); ++j) {
2290       // Skip if input shape is not fully determined.
2291       if (input.shape().dim(j).size() < 0) {
2292         return Status::OK();
2293       }
2294     }
2295 
2296     std::vector<Tensor> input_tensors(3);
2297     for (int i = 1; i < 4; ++i) {
2298       if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
2299         return Status::OK();
2300       }
2301     }
2302 
2303     const Tensor& begin = input_tensors[0];
2304     const Tensor& end = input_tensors[1];
2305     const Tensor& strides = input_tensors[2];
2306 
2307     TF_RETURN_IF_ERROR(
2308         CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
2309     int begin_mask = node->attr().at("begin_mask").i();
2310     int end_mask = node->attr().at("end_mask").i();
2311     std::set<int> expanded_ellipsis_indices;
2312     int ellipsis_index = -1;
2313     for (int j = 0; j < input.shape().dim_size(); ++j) {
2314       // find the ellipsis_mask. If not found, insert one in the end if
2315       // necessary.
2316       if (node->attr().at("ellipsis_mask").i() & 1 << j ||
2317           (ellipsis_index == -1 && j >= strides.NumElements())) {
2318         ellipsis_index = j;
2319       }
2320       // insert the indices that are immediately after ellipsis_index if
2321       // necessary.
2322       if (ellipsis_index != -1 &&
2323           input.shape().dim_size() >
2324               strides.NumElements() + j - ellipsis_index) {
2325         expanded_ellipsis_indices.insert(j);
2326       }
2327     }
2328 
2329     // The node is replaceable iff unknown_rank == false &&
2330     // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
2331     //  && strides == 1) for all dimensions.
2332     bool replaceable = !input.shape().unknown_rank();
2333     for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2334       if (expanded_ellipsis_indices.find(j) !=
2335           expanded_ellipsis_indices.end()) {
2336         // ellipsis_mask is effective on current dimension.
2337         continue;
2338       }
2339       // when we have ellipsis_mask in between, input.shape().dim_size() will
2340       // be greater than strides.NumElements(), since we will insert
2341       // as many as expanded_ellipsis_indices.size() axes during computation.
2342       // We need to subtract this number from j.
2343       int i = j;
2344       int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
2345       if (ellipsis_index != -1 &&
2346           j >= ellipsis_index + expanded_ellipsis_indices_size) {
2347         i = j - expanded_ellipsis_indices_size;
2348       }
2349       int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
2350                                         : begin.vec<int64>()(i);
2351       int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
2352       int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
2353                                           : strides.vec<int64>()(i);
2354       replaceable &= (begin_mask & 1 << i || b == 0) &&
2355                      (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
2356                      s == 1;
2357     }
2358     if (replaceable) {
2359       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2360     }
2361   }
2362   return Status::OK();
2363 }
2364 
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2365 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2366                                      bool use_shape_info,
2367                                      GraphDef* optimized_graph, NodeDef* node) {
2368   Tensor multiplies;
2369   if (use_shape_info && IsTile(*node) &&
2370       GetTensorFromConstNode(node->input(1), &multiplies)) {
2371     // The node is replaceable iff all values in multiplies are 1.
2372     bool replaceable = true;
2373     if (multiplies.dtype() == DT_INT32) {
2374       for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2375         replaceable &= multiplies.vec<int>()(j) == 1;
2376       }
2377     } else {
2378       for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
2379         replaceable &= multiplies.vec<int64>()(j) == 1;
2380       }
2381     }
2382     if (replaceable) {
2383       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2384     }
2385   }
2386   return Status::OK();
2387 }
2388 
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2389 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2390                                     bool use_shape_info,
2391                                     GraphDef* optimized_graph, NodeDef* node) {
2392   if (!use_shape_info || !IsPad(*node)) return Status::OK();
2393 
2394   Tensor paddings;
2395   if (GetTensorFromConstNode(node->input(1), &paddings)) {
2396     // The node is replaceable iff all values in paddings are 0.
2397     bool replaceable = true;
2398     if (paddings.dtype() == DT_INT32) {
2399       const auto flatten = paddings.flat<int32>();
2400       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2401         replaceable &= flatten(j) == 0;
2402       }
2403     } else {
2404       const auto flatten = paddings.flat<int64>();
2405       for (int j = 0; replaceable && j < flatten.size(); ++j) {
2406         replaceable &= flatten(j) == 0;
2407       }
2408     }
2409     if (replaceable) {
2410       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2411     }
2412   }
2413   return Status::OK();
2414 }
2415 
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2416 void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2417                                       bool use_shape_info,
2418                                       GraphDef* optimized_graph,
2419                                       NodeDef* node) {
2420   if (use_shape_info && IsSqueeze(*node) &&
2421       !properties.GetInputProperties(node->name()).empty()) {
2422     // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2423     // error to squeeze a dimension that is not 1, so we only need to check
2424     // whether the input has > 1 size for each dimension.
2425     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2426     // The node is replaceable iff
2427     // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2428     bool replaceable = !shape.unknown_rank();
2429     for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2430       replaceable &= shape.dim(j).size() > 1;
2431     }
2432     if (replaceable) {
2433       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2434     }
2435   }
2436 }
2437 
SimplifyPack(GraphDef * optimized_graph,NodeDef * node)2438 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
2439   const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
2440   if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
2441       node_map_->NodeExists(axis_node_name)) {
2442     return false;
2443   }
2444 
2445   // It's unsafe to add a control dependency on the feed node, because it might
2446   // have been never executed otherwiwise.
2447   if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
2448     return false;
2449   }
2450 
2451   // Create constant axis node.
2452   Tensor axis_t(DT_INT32, TensorShape({}));
2453   const int axis =
2454       node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
2455   NodeDef new_node;
2456   if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
2457       !CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
2458     return false;
2459   }
2460   NodeDef* axis_node = optimized_graph->add_node();
2461   *axis_node = std::move(new_node);
2462   axis_node->set_name(axis_node_name);
2463   node_map_->AddNode(axis_node->name(), axis_node);
2464   // Add a control dependency to make sure axis_node is in the right frame.
2465   const string ctrl_dep = ConstantFolding::AddControlDependency(
2466       node->input(0), optimized_graph, node_map_.get());
2467   axis_node->add_input(ctrl_dep);
2468   axis_node->set_device(node->device());
2469   node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
2470   node->set_op("ExpandDims");
2471   if (node->attr().count("axis") != 0) {
2472     node->mutable_attr()->erase("axis");
2473   }
2474   if (node->attr().count("N") != 0) {
2475     node->mutable_attr()->erase("N");
2476   }
2477   (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
2478   node->add_input(axis_node->name());
2479   node_map_->AddOutput(axis_node->name(), node->name());
2480   if (node->input_size() > 2) {
2481     node->mutable_input()->SwapElements(1, node->input_size() - 1);
2482   }
2483   return true;
2484 }
2485 
SimplifyCase(GraphDef * optimized_graph,NodeDef * node)2486 bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
2487   if (node->op() != "Case") return false;
2488   const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
2489   if (output_idx_node == nullptr ||
2490       !CheckAttrExists(*output_idx_node, "value").ok()) {
2491     return false;
2492   }
2493   Tensor output_idx_t;
2494   if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
2495     return false;
2496   int output_idx = output_idx_t.scalar<int>()();
2497   const auto& func_list = node->attr().at("branches").list();
2498   if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
2499   NodeDef call_node = *node;
2500   call_node.set_op("PartitionedCall");
2501   call_node.clear_input();
2502   for (int i = 1; i < node->input_size(); ++i) {
2503     call_node.add_input(node->input(i));
2504   }
2505   auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
2506   *new_func = func_list.func(output_idx);
2507 
2508   // Move the output shape of the branch to _output_shapes if it is known.
2509   const auto& output_shape_list =
2510       (*node->mutable_attr())["output_shapes"].list();
2511   if (output_shape_list.shape_size() > output_idx) {
2512     TensorShapeProto* new_output_shape =
2513         (*call_node.mutable_attr())["_output_shapes"]
2514             .mutable_list()
2515             ->add_shape();
2516     *new_output_shape =
2517         std::move(node->attr().at("output_shapes").list().shape(output_idx));
2518   }
2519 
2520   call_node.mutable_attr()->erase("output_shapes");
2521   call_node.mutable_attr()->erase("branches");
2522 
2523   *node = std::move(call_node);
2524   return true;
2525 }
2526 
SimplifySelect(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2527 bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
2528                                      GraphDef* optimized_graph, NodeDef* node) {
2529   if (!IsSelect(*node)) return false;
2530   const std::vector<OpInfo::TensorProperties>& input_props =
2531       properties.GetInputProperties(node->name());
2532   if (input_props.size() < 3) return false;
2533   const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
2534   const bool is_all_true = IsOnes(*predicate_node);
2535   const bool is_all_false = IsZeros(*predicate_node);
2536   if (!is_all_true && !is_all_false) {
2537     return false;
2538   }
2539   const int live_input_idx = is_all_true ? 1 : 2;
2540   const int ignored_input_idx = is_all_true ? 2 : 1;
2541   const TensorShapeProto& predicate_shape = input_props[0].shape();
2542   const bool predicate_is_scalar =
2543       !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
2544   if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
2545       (ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
2546        predicate_is_scalar)) {
2547     // Replace node with Identity if no broadcasting is involved.
2548     node->set_op("Identity");
2549     *node->mutable_input(0) =
2550         AddControlDependency(node->input(0), optimized_graph, node_map_.get());
2551     *node->mutable_input(ignored_input_idx) = AddControlDependency(
2552         node->input(ignored_input_idx), optimized_graph, node_map_.get());
2553     node->mutable_input()->SwapElements(0, live_input_idx);
2554   } else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
2555                                               optimized_graph)) {
2556     return false;
2557   }
2558   DedupControlInputs(node);
2559   return true;
2560 }
2561 
RemoveRedundantVariableUpdates(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)2562 void ConstantFolding::RemoveRedundantVariableUpdates(
2563     GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
2564   static const absl::flat_hash_set<string>* kVariableReadOps =
2565       new absl::flat_hash_set<string>{"AssignAddVariableOp",
2566                                       "AssignSubVariableOp",
2567                                       "AssignAdd",
2568                                       "AssignSub",
2569                                       "ScatterAdd",
2570                                       "ScatterSub",
2571                                       "ScatterMul",
2572                                       "ScatterDiv",
2573                                       "ScatterNdAdd",
2574                                       "ScatterNdSub",
2575                                       "ScatterNdMul",
2576                                       "ScatterNdDiv",
2577                                       "ResourceScatterAdd",
2578                                       "ResourceScatterSub",
2579                                       "ResourceScatterMul",
2580                                       "ResourceScatterDiv",
2581                                       "ResourceScatterNdAdd",
2582                                       "ResourceScatterNdSub",
2583                                       "ResourceScatterNdMul",
2584                                       "ResourceScatterNdDiv"};
2585   if (kVariableReadOps == nullptr ||
2586       kVariableReadOps->find(node->op()) == kVariableReadOps->end())
2587     return;
2588   const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
2589   const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
2590   if (delta_node == nullptr) return;
2591   const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
2592                              absl::StrContains(node->op(), "Sub");
2593   if ((is_add_or_sub && IsZeros(*delta_node)) ||
2594       (!is_add_or_sub && IsOnes(*delta_node))) {
2595     VLOG(1) << "Removing redundant variable update: " << node->DebugString();
2596     if (absl::StrContains(node->op(), "Variable") ||
2597         absl::StrContains(node->op(), "Resource")) {
2598       ReplaceOperationWithNoOp(node, properties, optimized_graph);
2599     } else {
2600       ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
2601                                    optimized_graph);
2602     }
2603   }
2604 }
2605 
MoveConstantsPastEnter(GraphDef * optimized_graph,NodeDef * node)2606 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
2607                                              NodeDef* node) {
2608   if (!IsEnter(*node) || node->input_size() == 0 ||
2609       node->attr().count("is_constant") == 0 ||
2610       !node->attr().at("is_constant").b()) {
2611     return false;
2612   }
2613   const string& node_name = node->name();
2614   const NodeDef* input = node_map_->GetNode(node->input(0));
2615   if (input == nullptr || !IsReallyConstant(*input) ||
2616       OptimizedNodeExists(*input, "_enter")) {
2617     return false;
2618   }
2619   // Find non-constant nodes that consume the output of *node.
2620   std::vector<NodeDef*> consumers;
2621   for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
2622     if (!IsConstant(*fanout)) {
2623       for (int i = 0; i < fanout->input_size(); ++i) {
2624         if (fanout->input(i) == node_name) {
2625           consumers.push_back(const_cast<NodeDef*>(fanout));
2626           break;
2627         }
2628       }
2629     }
2630   }
2631   if (consumers.empty()) {
2632     return false;
2633   }
2634   graph_modified_ = true;
2635   NodeDef* new_node = optimized_graph->add_node();
2636   *new_node = *input;
2637   new_node->set_name(OptimizedNodeName(*input, "_enter"));
2638   new_node->set_device(node->device());
2639   new_node->clear_input();
2640   new_node->add_input(AsControlDependency(node_name));
2641   node_map_->AddNode(new_node->name(), new_node);
2642   node_map_->AddOutput(node_name, new_node->name());
2643   for (NodeDef* consumer : consumers) {
2644     for (int i = 0; i < consumer->input_size(); ++i) {
2645       if (NodeName(consumer->input(i)) == node_name) {
2646         node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
2647         consumer->set_input(i, new_node->name());
2648       }
2649     }
2650   }
2651   return true;
2652 }
2653 
SimplifySwitch(GraphDef * optimized_graph,NodeDef * node)2654 bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
2655   if (node->op() == "Switch" && node->input(0) == node->input(1) &&
2656       !OptimizedNodeExists(*node, "_const_false") &&
2657       !OptimizedNodeExists(*node, "_const_true")) {
2658     bool already_optimized = true;
2659     // If the optimization was already applied, the switch would have exactly
2660     // one Identity node consuming each of its outputs, each without any
2661     // non-control outputs.
2662     const auto& fanouts = node_map_->GetOutputs(node->name());
2663     if (fanouts.size() == 2) {
2664       for (const NodeDef* fanout : fanouts) {
2665         if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
2666             HasRegularOutputs(*fanout, *node_map_)) {
2667           already_optimized = false;
2668           break;
2669         }
2670       }
2671     }
2672     Tensor false_t(DT_BOOL, TensorShape({}));
2673     Tensor true_t(DT_BOOL, TensorShape({}));
2674     // Make sure we don't proceed if this switch node was already optimized.
2675     if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
2676         SetTensorValue(DT_BOOL, false, &false_t).ok()) {
2677       // Copy the set of consumers of the switch as they will be manipulated
2678       // below.
2679       std::vector<NodeDef*> consumers =
2680           node_map_->GetOutputsOrderedByNodeName(node->name());
2681       // Create constant false & true nodes.
2682       NodeDef tmp_false_node;
2683       tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
2684       if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
2685                          &tmp_false_node)
2686                .ok()) {
2687         return false;
2688       }
2689       tmp_false_node.set_device(node->device());
2690       NodeDef tmp_true_node;
2691       tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
2692       if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
2693                          &tmp_true_node)
2694                .ok()) {
2695         return false;
2696       }
2697       tmp_true_node.set_device(node->device());
2698 
2699       // Add const nodes to graph.
2700       NodeDef* false_node = optimized_graph->add_node();
2701       false_node->Swap(&tmp_false_node);
2702       NodeDef* true_node = optimized_graph->add_node();
2703       true_node->Swap(&tmp_true_node);
2704 
2705       // Add controls from the switch ports to the constants, and connect the
2706       // constants to the original switch outputs.
2707       const string false_port = node->name();
2708       const string true_port = strings::StrCat(node->name(), ":1");
2709       const string false_ctrl_dep =
2710           AddControlDependency(false_port, optimized_graph, node_map_.get());
2711       false_node->add_input(false_ctrl_dep);
2712       const string true_ctrl_dep =
2713           AddControlDependency(true_port, optimized_graph, node_map_.get());
2714       true_node->add_input(true_ctrl_dep);
2715 
2716       node_map_->AddNode(false_node->name(), false_node);
2717       node_map_->AddNode(true_node->name(), true_node);
2718       node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
2719       node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
2720 
2721       for (NodeDef* consumer : consumers) {
2722         for (int i = 0; i < consumer->input_size(); ++i) {
2723           const string& input = consumer->input(i);
2724           if (input == false_port) {
2725             consumer->set_input(i, false_node->name());
2726             node_map_->UpdateInput(consumer->name(), false_port,
2727                                    false_node->name());
2728           } else if (input == true_port) {
2729             consumer->set_input(i, true_node->name());
2730             node_map_->UpdateInput(consumer->name(), true_port,
2731                                    true_node->name());
2732           }
2733         }
2734       }
2735       return true;
2736     }
2737   }
2738   return false;
2739 }
2740 
IsReductionWithConstantIndices(const NodeDef & node,bool * indices_is_empty) const2741 bool ConstantFolding::IsReductionWithConstantIndices(
2742     const NodeDef& node, bool* indices_is_empty) const {
2743   // Ensure its an appropriate Reduce node.
2744   if (!IsReduction(node) || node.input_size() < 2) {
2745     return false;
2746   }
2747   // Ensure that the axes to reduce by are constant.
2748   NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2749   if (!IsReallyConstant(*reductions_indices) ||
2750       !reductions_indices->attr().count("value")) {
2751     return false;
2752   }
2753   const TensorShapeProto& reduction_indices_shape =
2754       reductions_indices->attr().at("value").tensor().tensor_shape();
2755   *indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
2756   return true;
2757 }
2758 
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2759 bool ConstantFolding::IsReductionCandidateForSimplification(
2760     const NodeDef& node, const GraphProperties& properties,
2761     TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2762     bool* is_single_element_op) const {
2763   // Get the properties of the input & output tensors and check if they both
2764   // contain a single element.
2765   if (!properties.HasInputProperties(node.name()) ||
2766       !properties.HasOutputProperties(node.name())) {
2767     return false;
2768   }
2769   const auto& input_props = properties.GetInputProperties(node.name())[0];
2770   const auto& output_props = properties.GetOutputProperties(node.name())[0];
2771   if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2772       !output_props.has_shape() || output_props.shape().unknown_rank()) {
2773     return false;
2774   }
2775   *input_tensor_shape = input_props.shape();
2776   *output_tensor_shape = output_props.shape();
2777   for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2778     if (input_tensor_shape->dim(i).size() < 0) {
2779       return false;
2780     }
2781   }
2782   for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2783     if (output_tensor_shape->dim(i).size() < 0) {
2784       return false;
2785     }
2786   }
2787   const int input_num_elements =
2788       TensorShape(*input_tensor_shape).num_elements();
2789   const int output_num_elements =
2790       TensorShape(*output_tensor_shape).num_elements();
2791   *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2792 
2793   return true;
2794 }
2795 
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2796 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2797     const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2798     const TensorVector& reduction_indices_vector) const {
2799   int output_size = reduction_indices_vector[0]->NumElements();
2800   if (output_size == 0) {
2801     return true;
2802   }
2803 
2804   if (!keep_dims) {
2805     return false;
2806   }
2807   bool simplifiable = true;
2808   for (int i = 0; i < output_size; ++i) {
2809     int64_t dim;
2810     if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2811       dim = reduction_indices_vector[0]->flat<int32>()(i);
2812     } else {
2813       dim = reduction_indices_vector[0]->flat<int64>()(i);
2814     }
2815     if (dim < 0) {
2816       dim += input_shape.dim_size();
2817     }
2818     if (dim < 0 || dim >= input_shape.dim_size() ||
2819         input_shape.dim(dim).size() != 1) {
2820       simplifiable = false;
2821       break;
2822     }
2823   }
2824   return simplifiable;
2825 }
2826 
ReplaceReductionWithIdentity(NodeDef * node) const2827 bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
2828   // Replace the reduction node with an identity node, that can be further
2829   // optimized by other passes.
2830   DataType output_type;
2831   if (node->attr().count("T") != 0) {
2832     output_type = node->attr().at("T").type();
2833   } else if (IsAny(*node) || IsAll(*node)) {
2834     output_type = DT_BOOL;
2835   } else {
2836     return false;
2837   }
2838   node->set_op("Identity");
2839   EraseRegularNodeAttributes(node);
2840   (*node->mutable_attr())["T"].set_type(output_type);
2841   *node->mutable_input(1) = AsControlDependency(node->input(1));
2842   return true;
2843 }
2844 
SimplifyReduction(GraphDef * optimized_graph,const GraphProperties & properties,NodeDef * node)2845 bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
2846                                         const GraphProperties& properties,
2847                                         NodeDef* node) {
2848   bool indices_is_empty = false;
2849   if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
2850     return false;
2851   }
2852   if (indices_is_empty) {
2853     return ReplaceReductionWithIdentity(node);
2854   }
2855   bool is_single_element_op = false;
2856   TensorShapeProto input_tensor_shape, output_tensor_shape;
2857   if (!IsReductionCandidateForSimplification(
2858           *node, properties, &input_tensor_shape, &output_tensor_shape,
2859           &is_single_element_op)) {
2860     return false;
2861   }
2862 
2863   // Get the reduction indices.
2864   string reduction_indices_input = node->input(1);
2865   NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
2866   TensorVector reduction_indices_vector;
2867   auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
2868     for (const auto& out : reduction_indices_vector) {
2869       delete out.tensor;
2870     }
2871   });
2872   if (!EvaluateNode(*reduction_indices, TensorVector(),
2873                     &reduction_indices_vector)
2874            .ok() ||
2875       reduction_indices_vector.size() != 1) {
2876     return false;
2877   }
2878 
2879   bool keep_dims =
2880       node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
2881   bool simplifiable_to_reshape =
2882       is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
2883   bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
2884       *node, input_tensor_shape, keep_dims, reduction_indices_vector);
2885 
2886   if (simplifiable_to_reshape) {
2887     // Const node to output shape.
2888     const int new_num_dimensions = output_tensor_shape.dim_size();
2889     Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
2890     for (int i = 0; i < new_num_dimensions; i++) {
2891       tensor.flat<int>()(i) = 1;
2892     }
2893     TensorValue shape_value(&tensor);
2894     NodeDef* shape_node = optimized_graph->add_node();
2895     if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
2896                        shape_node)
2897              .ok()) {
2898       return false;
2899     }
2900     shape_node->set_device(node->device());
2901     node_map_->AddNode(shape_node->name(), shape_node);
2902     // Control dependency to ensure shape_node is in the correct frame.
2903     shape_node->add_input(AsControlDependency(reduction_indices_input));
2904     node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
2905     // Optimize node to Reshape.
2906     node->set_op("Reshape");
2907     node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
2908     node->set_input(1, shape_node->name());
2909     node->mutable_attr()->erase("keep_dims");
2910     node->mutable_attr()->erase("Tidx");
2911     AttrValue attr_type_indices;
2912     attr_type_indices.set_type(DT_INT32);
2913     (*node->mutable_attr())["Tshape"] = attr_type_indices;
2914     return true;
2915   } else if (simplifiable_to_identity) {
2916     return ReplaceReductionWithIdentity(node);
2917   }
2918   return false;
2919 }
2920 
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2921 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2922                                       bool use_shape_info, NodeDef* node) {
2923   if (!use_shape_info || node->attr().count("T") == 0 ||
2924       !IsSimplifiableReshape(*node, properties)) {
2925     return false;
2926   }
2927   DataType output_type = node->attr().at("T").type();
2928   node->set_op("Identity");
2929   EraseRegularNodeAttributes(node);
2930   (*node->mutable_attr())["T"].set_type(output_type);
2931   *node->mutable_input(1) = AsControlDependency(node->input(1));
2932   return true;
2933 }
2934 
SimplifyArithmeticOperations(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2935 Status ConstantFolding::SimplifyArithmeticOperations(
2936     const GraphProperties& properties, bool use_shape_info,
2937     GraphDef* optimized_graph, NodeDef* node) {
2938   const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
2939   const bool is_matmul = IsAnyMatMul(*node);
2940   const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
2941   const bool is_sub = IsSub(*node);
2942   const bool is_any_div = IsAnyDiv(*node) && !IsFloorDiv(*node);
2943   // Simplify arithmetic operations with ones or zeros.
2944   if (use_shape_info &&
2945       (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
2946       properties.HasInputProperties(node->name()) &&
2947       properties.HasOutputProperties(node->name())) {
2948     const NodeDef* x = node_map_->GetNode(node->input(0));
2949     const NodeDef* y = node_map_->GetNode(node->input(1));
2950     if (x == nullptr || y == nullptr) {
2951       return errors::InvalidArgument("Invalid inputs to node: ",
2952                                      node->DebugString());
2953     }
2954     const TensorShapeProto& output_shape =
2955         properties.GetOutputProperties(node->name())[0].shape();
2956 
2957     // Simplify element-wise multiplication by ones or addition/subtraction
2958     // of zeros.
2959     const TensorShapeProto& y_shape =
2960         properties.GetInputProperties(node->name())[1].shape();
2961     const TensorShapeProto& x_shape =
2962         properties.GetInputProperties(node->name())[0].shape();
2963     const bool y_matches_output_shape =
2964         ShapesSymbolicallyEqual(output_shape, y_shape);
2965     const bool x_matches_output_shape =
2966         ShapesSymbolicallyEqual(output_shape, x_shape);
2967 
2968     const bool x_is_zero = IsZeros(*x);
2969     const bool x_is_one = x_is_zero ? false : IsOnes(*x);
2970     if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
2971       // 1 * y = y or 0 + y = y.
2972       if (y_matches_output_shape) {
2973         ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
2974       } else if (x_matches_output_shape) {
2975         ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
2976                                               optimized_graph);
2977       }
2978       return Status::OK();
2979     }
2980 
2981     if (y_matches_output_shape && (is_sub && x_is_zero)) {
2982       // Replace 0 - y with Neg(y).
2983       ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
2984       return Status::OK();
2985     }
2986 
2987     // Replace 1 / y with Reciprocal op.
2988     if (y_matches_output_shape && is_any_div && x_is_one) {
2989       TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
2990       DataType type = node->attr().at("T").type();
2991       if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
2992         ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
2993         return Status::OK();
2994       }
2995     }
2996 
2997     const bool y_is_zero = IsZeros(*y);
2998     const bool y_is_one = y_is_zero ? false : IsOnes(*y);
2999     if (((is_mul || is_any_div) && y_is_one) ||
3000         ((is_add || is_sub) && y_is_zero)) {
3001       // x * 1 = x or x / 1 = x or x +/- 0 = x
3002       if (x_matches_output_shape) {
3003         ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
3004       } else if (y_matches_output_shape) {
3005         ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
3006                                               optimized_graph);
3007       }
3008       return Status::OK();
3009     }
3010 
3011     // x OR true = true OR y = true.
3012     const PartialTensorShape shp(output_shape);
3013     if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
3014       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
3015           1, properties, output_shape, node, optimized_graph));
3016       return Status::OK();
3017     }
3018 
3019     // Simplify multiplication and matmul by zeros.
3020     // Also optimize zeros divided by a tensor, but only if we are in
3021     // aggressive mode, since we might get rid of divisions by zero.
3022     const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
3023     bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
3024     if ((x_is_zero || y_is_zero) &&
3025         (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
3026       if (shp.IsFullyDefined()) {
3027         bool is_quantized = IsQuantizedMatMul(*node);
3028         TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
3029             0, properties, output_shape, node, optimized_graph));
3030         if (is_quantized && graph_modified_) {
3031           TF_RETURN_IF_ERROR(
3032               AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
3033         }
3034         return Status::OK();
3035       }
3036       // Even if an input shape is only partially known, we may known that it
3037       // matches the output shape and thus forward or broadcast the
3038       // corresponding zero input.
3039       if ((is_mul || is_any_div) && x_is_zero) {
3040         if (x_matches_output_shape) {
3041           ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
3042         } else if (y_matches_output_shape) {
3043           ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
3044                                                 optimized_graph);
3045         }
3046         return Status::OK();
3047       } else if (is_mul && y_is_zero) {
3048         if (y_matches_output_shape) {
3049           ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
3050         } else if (x_matches_output_shape) {
3051           ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
3052                                                 optimized_graph);
3053         }
3054         return Status::OK();
3055       }
3056     }
3057   }
3058   return Status::OK();
3059 }
3060 
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)3061 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
3062                                                NodeDef* node) {
3063   // Strength reduce floating point division by a constant Div(x, const) to
3064   // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
3065   // will be constant folded to Mul(x, 1.0/const).
3066   if (node->input_size() >= 2 &&
3067       (IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
3068     const string& const_input = node->input(1);
3069     const NodeDef* denom = node_map_->GetNode(const_input);
3070     CHECK(denom != nullptr);
3071     if (!IsReallyConstant(*denom)) {
3072       return false;
3073     }
3074     if (node->attr().count("T") == 0) {
3075       return false;
3076     }
3077     DataType type = node->attr().at("T").type();
3078     // Skip integer division.
3079     if (IsDiv(*node) &&
3080         !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
3081       return false;
3082     }
3083     // Insert new reciprocal op and change node from Div to Mul.
3084     NodeDef* reciprocal_node = optimized_graph->add_node();
3085     reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
3086     reciprocal_node->set_op("Reciprocal");
3087     reciprocal_node->set_device(node->device());
3088     reciprocal_node->add_input(const_input);
3089     (*reciprocal_node->mutable_attr())["T"].set_type(type);
3090 
3091     // Re-wire inputs and outputs.
3092     if (IsXdivy(*node)) {
3093       node->set_op("MulNoNan");
3094       node->set_input(1, node->input(0));
3095       node->set_input(0, reciprocal_node->name());
3096     } else {
3097       node->set_op("Mul");
3098       node->set_input(1, reciprocal_node->name());
3099     }
3100     node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
3101     node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
3102 
3103     return true;
3104   }
3105 
3106   return false;
3107 }
3108 
PrepareConstantPushDown(const NodeDef & parent,const GraphProperties & properties,bool must_have_properties,ConstantPushDownContext * ctx) const3109 bool ConstantFolding::PrepareConstantPushDown(
3110     const NodeDef& parent, const GraphProperties& properties,
3111     bool must_have_properties, ConstantPushDownContext* ctx) const {
3112   if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
3113     return false;
3114   }
3115   NodeDef* left_child = node_map_->GetNode(parent.input(0));
3116   NodeDef* right_child = node_map_->GetNode(parent.input(1));
3117 
3118   // Sanity check for missing children.
3119   if (left_child == nullptr || right_child == nullptr) {
3120     return false;
3121   }
3122 
3123   ctx->left_child_is_const = IsReallyConstant(*left_child);
3124   ctx->right_child_is_const = IsReallyConstant(*right_child);
3125   ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
3126   ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
3127 
3128   // Nothing to do unless the parent has a constant child node.
3129   if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
3130     return false;
3131   }
3132 
3133   // Don't move nodes across devices.
3134   if (parent.device() != ctx->op_child->device() ||
3135       parent.device() != ctx->const_child->device()) {
3136     return false;
3137   }
3138 
3139   // Make sure that it is safe to change the value of the child node result.
3140   if (ctx->op_child->input_size() < 2 ||
3141       nodes_to_preserve_.find(ctx->op_child->name()) !=
3142           nodes_to_preserve_.end() ||
3143       NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
3144     return false;
3145   }
3146 
3147   // Don't apply reassociation to floating point types of low precision.
3148   // The danger of significant numerical changes is too high.
3149   if (!CheckAttrExists(parent, "T").ok()) return false;
3150   DataType dtype = parent.attr().at("T").type();
3151   if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
3152     return false;
3153   }
3154 
3155   // Don't rewrite the tree if it might create cycles.
3156   // TODO(rmlarsen): Add back handling of control dependency from op to C.
3157   const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
3158   if (child_output.find(ctx->const_child) != child_output.end()) {
3159     return false;
3160   }
3161 
3162   // Get leaf nodes.
3163   ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
3164   ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
3165   ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
3166   ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
3167 
3168   if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
3169     // Child is already foldable, leave it alone.
3170     return false;
3171   }
3172 
3173   // Don't move nodes across devices.
3174   if (parent.device() != ctx->left_leaf->device() ||
3175       parent.device() != ctx->right_leaf->device()) {
3176     return false;
3177   }
3178 
3179   // Get shape and type information.
3180   ctx->parent_input_props = &properties.GetInputProperties(parent.name());
3181   ctx->op_child_input_props =
3182       &properties.GetInputProperties(ctx->op_child->name());
3183   if (must_have_properties && (ctx->parent_input_props == nullptr ||
3184                                ctx->parent_input_props->size() < 2 ||
3185                                ctx->op_child_input_props == nullptr ||
3186                                ctx->op_child_input_props->size() < 2)) {
3187     return false;
3188   }
3189 
3190   VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
3191           << parent.op() << "(" << left_child->op() << ", " << right_child->op()
3192           << ")";
3193 
3194   return true;
3195 }
3196 
ConstantPushDownBiasAdd(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3197 bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
3198                                               GraphDef* optimized_graph,
3199                                               NodeDef* node) {
3200   // This implements constant push-down for BiasAdd. In the following "CV" is a
3201   // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
3202   // vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3203   // non-constant matrix, and "BA" is BiasAdd.
3204   // For a valid input graph, the following 4 rewrites are legal:
3205   //
3206   //  1)                  +                +
3207   //                     / \              / \
3208   //                    BA  CV    -- >   BA  V
3209   //                   / \              / \
3210   //                  M   V            M   CV
3211   //
3212   //  2)                  +                +
3213   //                     / \              / \
3214   //                    BA  CM    -- >   BA  M
3215   //                   / \              / \
3216   //                  M   V            CM  V
3217   //
3218   //  3)                  BA               BA
3219   //                     / \              / \
3220   //                    +  CV     -- >   +   V
3221   //                   / \              / \
3222   //                  M   V            M  CV
3223   //
3224   //  4)                  BA               BA      = parent
3225   //                     / \              / \
3226   //                    BA  CV    -- >   BA  V     = children
3227   //                   / \              / \
3228   //                  M   V            M  CV       = leaves
3229   //
3230   // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3231 
3232   const bool parent_is_bias_add = IsBiasAdd(*node);
3233   if (!parent_is_bias_add && !IsAdd(*node)) return false;
3234   ConstantPushDownContext ctx;
3235   if (!PrepareConstantPushDown(*node, *properties,
3236                                /*must_have_properties=*/true, &ctx)) {
3237     return false;
3238   }
3239   // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3240   // >= 2 and the leaves must be vectors, we cannot swap them.
3241   if (ctx.left_child_is_const && parent_is_bias_add) return false;
3242   const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
3243   if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
3244 
3245   // Get properties to validate rank and dtype constraints.
3246   if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
3247       (*ctx.parent_input_props)[0].shape().unknown_rank() ||
3248       (*ctx.parent_input_props)[1].shape().unknown_rank() ||
3249       (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
3250       (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
3251     return false;
3252   }
3253 
3254   // Now get the ranks and types of the 3 leaf nodes.
3255   const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
3256   const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
3257   // At least one leaf must be a vector.
3258   if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
3259   const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3260   const int matrix_idx = 1 - vector_idx;
3261 
3262   const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
3263   const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
3264   if (vector_rank != 1) return false;  // this should never happen.
3265   const DataType vector_type = vector_prop.dtype();
3266 
3267   const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
3268   const int matrix_rank = matrix_prop.shape().dim_size();
3269   const DataType matrix_type = matrix_prop.dtype();
3270 
3271   const int const_idx = ctx.left_child_is_const ? 0 : 1;
3272   const auto& const_prop = (*ctx.parent_input_props)[const_idx];
3273   const int const_rank = const_prop.shape().dim_size();
3274   const DataType const_type = const_prop.dtype();
3275 
3276   int input_to_swap = -1;
3277 
3278   if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
3279       const_type == matrix_type) {
3280     // Case 2:
3281     input_to_swap = matrix_idx;
3282   } else if (const_rank == 1 && const_type == vector_type) {
3283     // Case 1, 3, and, 4:
3284     input_to_swap = vector_idx;
3285   }
3286   if (input_to_swap == -1) return false;
3287   const NodeDef* leaf_to_swap =
3288       node_map_->GetNode(ctx.op_child->input(input_to_swap));
3289   if (IsConstant(*leaf_to_swap)) return false;
3290 
3291   node_map_->UpdateInput(node->name(), node->input(const_idx),
3292                          ctx.op_child->input(input_to_swap));
3293   node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
3294   if (ctx.op_child->input(input_to_swap) !=
3295       ctx.op_child->input(1 - input_to_swap)) {
3296     node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
3297                             ctx.op_child->name());
3298   }
3299   std::swap(*node->mutable_input(const_idx),
3300             *ctx.op_child->mutable_input(input_to_swap));
3301   properties->ClearInputProperties(node->name());
3302   properties->ClearInputProperties(ctx.op_child->name());
3303 
3304   return true;
3305 }
3306 
ConstantPushDown(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3307 bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
3308                                        GraphDef* optimized_graph,
3309                                        NodeDef* node) {
3310   // Consider the transformation
3311   //
3312   //                      +                +       = parent
3313   //                     / \              / \
3314   //                    C   +    -- >    X   +     = children
3315   //                       / \              / \
3316   //                      X   Y            C   Y   = leaves
3317   //
3318   // where C is constant, X is non-constant, Y may be constant or non-constant,
3319   // and '+' denotes an associative and commutative operator like addition or
3320   // multiplication. This optimization pushes constants down in the tree to
3321   // canonicalize it. Moreover, in cases where the child node has a second
3322   // constant input Y we will create a leaf node that can be folded, e.g.
3323   //
3324   //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
3325   //
3326   // We also handle the non-commutative cases of subtraction and division
3327   // by rotating the tree locally, e.g.
3328   //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
3329   //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
3330 
3331   // Get parent op type.
3332   const bool is_add = IsAdd(*node);
3333   const bool is_mul = IsMul(*node);
3334   const bool is_sub = IsSub(*node);
3335   const bool is_div = IsDiv(*node);
3336   if (!(is_add || is_sub || is_mul || is_div)) return false;
3337   const bool is_symmetric = is_add || is_mul;
3338 
3339   ConstantPushDownContext ctx;
3340   if (!PrepareConstantPushDown(*node, *properties,
3341                                /*must_have_properties=*/false, &ctx)) {
3342     return false;
3343   }
3344 
3345   // Get child op type.
3346   const bool is_child_add = IsAdd(*ctx.op_child);
3347   const bool is_child_mul = IsMul(*ctx.op_child);
3348   const bool is_child_sub = IsSub(*ctx.op_child);
3349   const bool is_child_div = IsDiv(*ctx.op_child);
3350   const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
3351   const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
3352   if (!is_add_sub && !is_mul_div) {
3353     return false;
3354   }
3355   const bool is_child_symmetric = is_child_add || is_child_mul;
3356 
3357   if (!CheckAttrExists(*node, "T").ok()) return false;
3358   DataType dtype = node->attr().at("T").type();
3359   if (!(is_symmetric && is_child_symmetric) &&
3360       !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
3361     return false;
3362   }
3363 
3364   const NodeDef* y_node =
3365       ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
3366   if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
3367       !ctx.op_child_input_props->empty()) {
3368     // If we know the shapes of the nodes being swapped, make sure we don't push
3369     // down a larger node and create more work by broadcasting earlier in the
3370     // expressions tree.
3371     const PartialTensorShape c_shape(
3372         (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
3373     const PartialTensorShape x_shape(
3374         (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
3375 
3376     if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
3377         c_shape.num_elements() > x_shape.num_elements()) {
3378       return false;
3379     } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
3380                c_shape.dims() > 0) {
3381       for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
3382         if (x_shape.dim_size(idx) >= 0 &&
3383             c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
3384           return false;
3385         }
3386       }
3387     }
3388   }
3389 
3390   // Get the node names corresponding to X, Y, and C.
3391   const string input_x =
3392       ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
3393   const string input_y = input_x == ctx.op_child->input(0)
3394                              ? ctx.op_child->input(1)
3395                              : ctx.op_child->input(0);
3396   const string input_c =
3397       ctx.left_child_is_const ? node->input(0) : node->input(1);
3398   const string input_op =
3399       ctx.left_child_is_const ? node->input(1) : node->input(0);
3400   VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
3401 
3402   // Now we have identified the nodes to swap, update the nodemap accordingly.
3403   node_map_->UpdateInput(node->name(), input_c, input_x);
3404   node_map_->AddOutput(input_c, ctx.op_child->name());
3405   if (input_x != input_y) {
3406     node_map_->RemoveOutput(input_x, ctx.op_child->name());
3407   }
3408   properties->ClearInputProperties(node->name());
3409   properties->ClearInputProperties(ctx.op_child->name());
3410 
3411   if (is_symmetric && is_child_symmetric) {
3412     // Easy case (only commutative ops). We always write this as one of
3413     //   +
3414     //  / \
3415     // X   +
3416     //    / \
3417     //   C   Y
3418     node->set_input(0, input_x);
3419     node->set_input(1, input_op);
3420     ctx.op_child->set_input(0, input_c);
3421     ctx.op_child->set_input(1, input_y);
3422   } else {
3423     // More complicated case: When there are non-commutative operations like
3424     // subtractions or divisions involved, we may have to rotate the tree
3425     // and/or change op types. There are 6 non-trivial cases depending on
3426     // the effective generalized "sign" of each of the three terms C, Y, and X.
3427     // Here are the final trees we want to generate for those 6 cases:
3428     //
3429     // (CYX signs):   ++-      +--      -+-    --+     +-+      -++
3430     //
3431     //                 -        -        -      -       +        +
3432     //                / \      / \      / \    / \     / \      / \
3433     //               +   X    -   X    -   X  X   +   X   -    X   -
3434     //              / \      / \      / \        / \     / \      / \
3435     //             C   Y    C   Y    Y   C      Y   C   C   Y    Y   C
3436     //
3437 
3438     // First, let's determine the effective sign of each term in the original
3439     // expression
3440     auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
3441       bool leaf_negated = !is_child_symmetric && is_right_leaf;
3442       bool child_negated = !is_symmetric && (ctx.left_child_is_const);
3443       return leaf_negated != child_negated;
3444     };
3445     const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
3446     const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
3447     bool neg_c = !is_symmetric && !ctx.left_child_is_const;
3448     bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
3449     bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
3450     // Rewrite the parent node.
3451     node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
3452     node->set_input(0, neg_x ? input_op : input_x);
3453     node->set_input(1, neg_x ? input_x : input_op);
3454     // Rewrite the child node.
3455     ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
3456     ctx.op_child->set_input(0, neg_c ? input_y : input_c);
3457     ctx.op_child->set_input(1, neg_c ? input_c : input_y);
3458   }
3459   return true;
3460 }
3461 
MulConvPushDown(GraphDef * optimized_graph,NodeDef * node,const GraphProperties & properties)3462 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
3463                                       const GraphProperties& properties) {
3464   // Push down multiplication on ConvND.
3465   //                       *                  ConvND
3466   //                     /   \                /    \
3467   //                 ConvND  C2    -- >      X      *
3468   //                  / \                          / \
3469   //                 X  C1                       C1  C2
3470   //
3471   // where C1 and C2 are constants and X is non-constant.
3472   //
3473   // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
3474 
3475   if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
3476 
3477   NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
3478   NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
3479   // One child must be constant, and the second must be Conv op.
3480   const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
3481   const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
3482   if (!left_child_is_constant && !right_child_is_constant) {
3483     return false;
3484   }
3485   NodeDef* conv_node =
3486       left_child_is_constant ? mul_right_child : mul_left_child;
3487   if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
3488     return false;
3489   }
3490   if (node->device() != mul_left_child->device() ||
3491       node->device() != mul_right_child->device()) {
3492     return false;
3493   }
3494 
3495   // Make sure that it is safe to change the value of the convolution
3496   // output.
3497   if (conv_node->input_size() < 2 ||
3498       NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
3499       nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
3500     return false;
3501   }
3502 
3503   // Identify the nodes to swap.
3504   NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
3505   NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
3506   const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
3507   const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
3508   if (!conv_left_is_constant && !conv_right_is_constant) {
3509     // At least one of the convolution inputs should be constant.
3510     return false;
3511   }
3512   if (conv_left_is_constant && conv_right_is_constant) {
3513     // Leverage regular constant folding to handle this.
3514     return false;
3515   }
3516   const auto& mul_props = properties.GetOutputProperties(node->name());
3517   const auto& conv_props = properties.GetOutputProperties(conv_node->name());
3518   if (mul_props.empty() || conv_props.empty()) {
3519     return false;
3520   }
3521   const auto& mul_shape = mul_props[0].shape();
3522   const auto& conv_shape = conv_props[0].shape();
3523   if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
3524     return false;
3525   }
3526 
3527   const auto& input_props = properties.GetInputProperties(conv_node->name());
3528   if (input_props.size() < 2) {
3529     return false;
3530   }
3531   const auto& filter_shape = input_props[1].shape();
3532 
3533   NodeDef* const_node =
3534       left_child_is_constant ? mul_left_child : mul_right_child;
3535   const auto& const_props = properties.GetOutputProperties(const_node->name());
3536   if (const_props.empty()) {
3537     return false;
3538   }
3539   const auto& const_shape = const_props[0].shape();
3540   if (!IsValidConstShapeForMulConvPushDown(
3541           conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
3542     return false;
3543   }
3544 
3545   string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
3546   if (node_map_->NodeExists(mul_new_name)) {
3547     return false;
3548   }
3549   // Make sure we don't introduce loops in the graph by removing control
3550   // dependencies from the conv2d node to c2.
3551   string conv_const_input =
3552       conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
3553   if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
3554                               node_map_.get())) {
3555     // Add a control dep from c1 to c2 to ensure c2 is in the right frame
3556     MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
3557                          node_map_.get());
3558   }
3559 
3560   conv_node->set_name(node->name());
3561   node->set_name(mul_new_name);
3562   if (conv_left_is_constant) {
3563     node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
3564     conv_node->set_input(0, mul_new_name);
3565   } else {
3566     node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
3567     conv_node->set_input(1, mul_new_name);
3568   }
3569   NodeDef* conv_const_node =
3570       conv_left_is_constant ? conv_left_child : conv_right_child;
3571   if (left_child_is_constant) {
3572     node->set_input(1, conv_const_node->name());
3573   } else {
3574     node->set_input(0, conv_const_node->name());
3575   }
3576   node_map_->AddNode(mul_new_name, node);
3577 
3578   return true;
3579 }
3580 
PartialConstPropThroughIdentityN(NodeDef * node)3581 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3582   // Partial constant propagation through IdentityN.
3583   if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
3584       !HasRegularInputs(*node))
3585     return false;
3586 
3587   std::vector<int> inputs_to_forward;
3588   for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3589     const string& input = node->input(input_idx);
3590     if (IsControlInput(input)) {
3591       return false;
3592     }
3593     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3594     if (input_node == nullptr) {
3595       LOG(ERROR) << "Bad input: " << input;
3596       return false;
3597     }
3598     // Forward constant inputs to outputs and add a control dependency on
3599     // the IdentityN node.
3600     if (IsReallyConstant(*input_node)) {
3601       inputs_to_forward.push_back(input_idx);
3602     }
3603   }
3604   return ForwardInputs(node, inputs_to_forward);
3605 }
3606 
PartialAssocOpConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3607 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
3608                                                  GraphProperties* properties,
3609                                                  NodeDef* node) {
3610   // Partial constant folding for associative operators:
3611   // Split AddN/AccumulateNV2 to enable partial
3612   // folding of ops when more than one but not all inputs are constant.
3613   // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
3614   // addition is commutative.
3615   if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
3616 
3617   const int num_non_control_inputs = NumNonControlInputs(*node);
3618   if (num_non_control_inputs <= 2) return false;
3619   const int num_control_inputs = node->input_size() - num_non_control_inputs;
3620   std::vector<int> const_inputs;
3621   std::vector<int> nonconst_inputs;
3622   for (int i = 0; i < node->input_size(); ++i) {
3623     const string& input = node->input(i);
3624     const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3625     if (input_node == nullptr) return false;
3626     if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
3627       const_inputs.push_back(i);
3628     } else {
3629       // Non-const and control inputs.
3630       nonconst_inputs.push_back(i);
3631     }
3632   }
3633   // Promote AccumulateNV2 with all constant inputs to AddN, since it is
3634   // a fake node that cannot be constant folded by itself.
3635   int const_inputs_size = const_inputs.size();
3636   if (const_inputs_size == num_non_control_inputs &&
3637       node->op() == "AccumulateNV2") {
3638     node->set_op("AddN");
3639     node->mutable_attr()->erase("shape");
3640     return true;
3641   }
3642   const string new_node_name = OptimizedNodeName(
3643       *node, strings::StrCat("_partial_split_", const_inputs_size));
3644   if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
3645       !node_map_->NodeExists(new_node_name)) {
3646     NodeDef* added_node = optimized_graph->add_node();
3647     *added_node = *node;
3648     // Always use AddN for the constant node, since AccumulateNV2 is a fake
3649     // node that cannot be constant folded, since it does not have a kernel.
3650     added_node->set_op("AddN");
3651     added_node->mutable_attr()->erase("shape");
3652     added_node->set_name(new_node_name);
3653     node_map_->AddNode(added_node->name(), added_node);
3654     added_node->clear_input();
3655     for (int i : const_inputs) {
3656       added_node->add_input(node->input(i));
3657       node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3658                               added_node->name());
3659     }
3660 
3661     // Overwrite the first const input with the added node.
3662     node->set_input(const_inputs[0], added_node->name());
3663     node_map_->AddOutput(added_node->name(), node->name());
3664     nonconst_inputs.push_back(const_inputs[0]);
3665     // Compact the remaining inputs to the original node.
3666     std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
3667     int idx = 0;
3668     for (int i : nonconst_inputs) {
3669       if (idx != i) {
3670         node->set_input(idx, node->input(i));
3671       }
3672       ++idx;
3673     }
3674     node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
3675                                           const_inputs.size() - 1);
3676     (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
3677     properties->ClearInputProperties(node->name());
3678     (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
3679     return true;
3680   }
3681   return false;
3682 }
3683 
PartialConcatConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3684 bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
3685                                                 GraphProperties* properties,
3686                                                 NodeDef* node) {
3687   // Partial constant folding for Concat which is not commutative, so
3688   // we have to preserve order and can only push consecutive runs of constant
3689   // inputs into sub-nodes.
3690   if (!IsConcat(*node) ||
3691       node->name().rfind("_partial_split_") != string::npos) {
3692     return false;
3693   }
3694   const int num_non_control_inputs = NumNonControlInputs(*node);
3695   if (num_non_control_inputs <= 3) return false;
3696   int axis_arg = -1;
3697   int begin = 0;
3698   int end = num_non_control_inputs;
3699   if (node->op() == "Concat") {
3700     begin = 1;
3701     axis_arg = 0;
3702   } else if (node->op() == "ConcatV2") {
3703     end = num_non_control_inputs - 1;
3704     axis_arg = num_non_control_inputs - 1;
3705   } else {
3706     return false;
3707   }
3708 
3709   // We search for consecutive runs of constant inputs in the range
3710   // [begin:end[ and push then down into child nodes.
3711   std::vector<std::pair<int, int>> constant_input_runs;
3712   int first = begin;
3713   int last = begin;
3714   while (last < end) {
3715     while (first < end && !IsReallyConstant(*node_map_->GetNode(
3716                               NodeName(node->input(first))))) {
3717       ++first;
3718     }
3719     // Invariant: node[first] is constant || first >= end.
3720     last = first + 1;
3721     while (last < end &&
3722            IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
3723       ++last;
3724     }
3725     // Invariant: node[last] is not constant || last >= end
3726     // Discard intervals shorter than 2 elements.
3727     if (first < end && (last - first) > 1) {
3728       constant_input_runs.emplace_back(first, last);
3729     }
3730     first = last;
3731   }
3732 
3733   // Skip if all inputs are constant, and let constant folding take over.
3734   if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3735                                       constant_input_runs[0].first == begin &&
3736                                       constant_input_runs[0].second == end)) {
3737     return false;
3738   }
3739   std::set<int> inputs_to_delete;
3740   for (auto interval : constant_input_runs) {
3741     // Push the constant inputs in the interval to a child node than can be
3742     // constant folded.
3743     string new_node_name = OptimizedNodeName(*node, "_partial_split");
3744     do {
3745       new_node_name += strings::StrCat("_", interval.first);
3746     } while (node_map_->NodeExists(new_node_name));
3747 
3748     NodeDef* added_node = optimized_graph->add_node();
3749     *added_node = *node;
3750     added_node->set_op("ConcatV2");
3751     added_node->set_name(new_node_name);
3752     node_map_->AddNode(added_node->name(), added_node);
3753     added_node->clear_input();
3754     for (int i = interval.first; i < interval.second; ++i) {
3755       added_node->add_input(node->input(i));
3756       node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
3757       if (i != interval.first) {
3758         inputs_to_delete.insert(i);
3759       }
3760     }
3761     added_node->add_input(node->input(axis_arg));
3762     (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
3763     node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
3764 
3765     // Overwrite the first constant input with the result of the added
3766     // child node.
3767     node->set_input(interval.first, added_node->name());
3768   }
3769   if (!inputs_to_delete.empty()) {
3770     // Fix up the inputs to the original node.
3771     protobuf::RepeatedPtrField<string> tmp;
3772     tmp.Swap(node->mutable_input());
3773     for (int i = 0; i < tmp.size(); ++i) {
3774       if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
3775         node->add_input(tmp.Get(i));
3776       }
3777     }
3778     (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
3779     properties->ClearInputProperties(node->name());
3780   }
3781   return true;
3782 }
3783 
GetConcatAxis(const NodeDef & node,int * axis)3784 bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
3785   if (node.op() != "ConcatV2") {
3786     return false;
3787   }
3788   int axis_idx = node.input_size() - 1;
3789   while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
3790     --axis_idx;
3791   }
3792   if (axis_idx <= 0) {
3793     return false;
3794   }
3795   Tensor axis_tensor;
3796   if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
3797     return false;
3798   }
3799   *axis = axis_tensor.dtype() == DT_INT64
3800               ? static_cast<int>(axis_tensor.scalar<int64>()())
3801               : axis_tensor.scalar<int32>()();
3802   return true;
3803 }
3804 
MergeConcat(bool use_shape_info,GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3805 bool ConstantFolding::MergeConcat(bool use_shape_info,
3806                                   GraphProperties* properties,
3807                                   GraphDef* optimized_graph, NodeDef* node) {
3808   // We only optimize for ConcatV2.
3809   int axis;
3810   if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
3811       nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3812       node_map_->GetOutputs(node->name()).size() != 1) {
3813     return false;
3814   }
3815 
3816   // If all inputs are constant, don't merge and let folding take case of it.
3817   const int num_regular_inputs = NumNonControlInputs(*node);
3818   bool all_inputs_are_const = true;
3819   for (int i = 0; i < num_regular_inputs - 1; ++i) {
3820     const NodeDef* input_node = node_map_->GetNode(node->input(i));
3821     if (!IsReallyConstant(*input_node)) {
3822       all_inputs_are_const = false;
3823       break;
3824     }
3825   }
3826   if (all_inputs_are_const) return false;
3827 
3828   NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3829   int parent_axis;
3830   if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
3831     return false;
3832   }
3833 
3834   // Make a pass over the parent inputs to see if any of them have explicit
3835   // device() fields set, and if different inputs are on different tasks.  If
3836   // so, this concat of concats may have been carefully constructed to be a
3837   // two-stage concat, and we don't want to undo that here.
3838   string task, device;
3839   absl::flat_hash_set<string> unique_input_tasks;
3840   const int n_parent_inputs = NumNonControlInputs(*parent);
3841   // Iterate over the real inputs to concatenate [0..n_parent_inputs - 1).  The
3842   // input at n_parent_inputs - 1 is the concat axis argument for a ConcatV2
3843   // node, which we don't want to consider here.
3844   for (int i = 0; i < n_parent_inputs - 1; ++i) {
3845     const NodeDef* input_node = node_map_->GetNode(parent->input(i));
3846     if (!input_node->device().empty() &&
3847         tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
3848                                                      &task, &device)) {
3849       unique_input_tasks.insert(task);
3850       if (unique_input_tasks.size() >= 2) {
3851         // More than one input task represented in the device specifications
3852         // of the parent's input nodes.  Don't mess with this.
3853         return false;
3854       }
3855     }
3856   }
3857 
3858   protobuf::RepeatedPtrField<string> parent_inputs;
3859   parent_inputs.Swap(parent->mutable_input());
3860   // TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
3861   // collapse it into the parent multiple times? Probably not.
3862   for (const auto& input : parent_inputs) {
3863     if (IsSameInput(input, node->name())) {
3864       for (int j = 0; j < num_regular_inputs - 1; ++j) {
3865         // Add tensor inputs to first child concat tensors (except the final
3866         // axis input) to the parent's inputs.
3867         parent->add_input(node->input(j));
3868         node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
3869       }
3870     } else {
3871       parent->add_input(input);
3872     }
3873   }
3874   // Forward Add control inputs
3875   const int num_inputs = node->input_size();
3876   for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
3877     parent->add_input(node->input(i));
3878     node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
3879     node->mutable_input()->RemoveLast();
3880   }
3881   (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3882   DedupControlInputs(parent);
3883   ReplaceOperationWithNoOp(node, properties, optimized_graph);
3884 
3885   return true;
3886 }
3887 
AddQuantizedMatMulMinMaxOutConstNodes(NodeDef * node,GraphDef * optimized_graph)3888 Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
3889     NodeDef* node, GraphDef* optimized_graph) {
3890   auto add_quantized_out = [this, node, optimized_graph](
3891                                const string& out_const_name, int index) {
3892     NodeDef* out_node = optimized_graph->add_node();
3893     graph_modified_ = true;
3894     Tensor value(DT_FLOAT, TensorShape({}));
3895     const bool is_min = index == 1;
3896     const DataType type_attr = node->attr().at("dtype").type();
3897 
3898     value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
3899                                     : QuantizedTypeMaxAsFloat(type_attr);
3900     TF_RETURN_IF_ERROR(
3901         CreateNodeDef(out_const_name, TensorValue(&value), out_node));
3902     node_map_->AddNode(out_const_name, out_node);
3903     out_node->set_device(node->device());
3904     // Copy all inputs from node.
3905     out_node->mutable_input()->CopyFrom(node->input());
3906     for (const string& input : out_node->input()) {
3907       node_map_->AddOutput(NodeName(input), out_const_name);
3908     }
3909 
3910     // Update output nodes consuming node:index to new const node.
3911     string old_input = absl::StrCat(node->name(), ":", index);
3912     int old_node_count = 0;
3913     // We make a copy since the set might change.
3914     auto outputs = node_map_->GetOutputs(node->name());
3915     for (const auto& output : outputs) {
3916       for (int i = 0; i < output->input_size(); ++i) {
3917         if (output->input(i) == old_input) {
3918           output->set_input(i, out_const_name);
3919           node_map_->AddOutput(out_const_name, output->name());
3920         } else if (NodeName(output->input(i)) == node->name()) {
3921           ++old_node_count;
3922         }
3923       }
3924       if (old_node_count == 0) {
3925         node_map_->RemoveOutput(node->name(), output->name());
3926       }
3927     }
3928 
3929     return Status::OK();
3930   };
3931   const string min_out_const_name =
3932       OptimizedNodeName(*node, "-quantized_matmul_min_out");
3933   const string max_out_const_name =
3934       OptimizedNodeName(*node, "-quantized_matmul_max_out");
3935   if (node_map_->GetNode(min_out_const_name) == nullptr &&
3936       node_map_->GetNode(max_out_const_name) == nullptr) {
3937     TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
3938     TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
3939   } else {
3940     return errors::Internal(absl::Substitute(
3941         "Can't create Const for QuantizedMatMul min_out/max_out of "
3942         "node '$0' because of node name conflict",
3943         node->name()));
3944   }
3945   return Status::OK();
3946 }
3947 
RunOptimizationPass(Cluster * cluster,GrapplerItem * item,GraphDef * optimized_graph)3948 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3949                                             GrapplerItem* item,
3950                                             GraphDef* optimized_graph) {
3951   graph_ = &item->graph;
3952   node_map_.reset(new NodeMap(graph_));
3953   nodes_allowlist_.clear();
3954   // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3955   // has a single fanout, it would be rewritten as a constant with the same
3956   // node name, and therefore users are still able to fetch it. This is not
3957   // the case if the node has multiple fanouts, and constant folding would
3958   // replace the node with multiple constants (each for one fanout) with
3959   // new names, and as a result users would not be able to fetch the node any
3960   // more with the original node name.
3961   for (const auto& fetch : item->fetch) {
3962     const NodeDef* fetch_node = node_map_->GetNode(fetch);
3963     if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
3964       nodes_allowlist_.insert(fetch_node->name());
3965     }
3966   }
3967 
3968   GraphProperties properties(*item);
3969   // It's possible to feed a placeholder with a tensor of any shape: make sure
3970   // that the shape inference deals with this conservatively unless we're in
3971   // aggressive mode.
3972   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3973   Status s = properties.InferStatically(assume_valid_feeds,
3974                                         /*aggressive_shape_inference=*/false,
3975                                         /*include_input_tensor_values=*/false,
3976                                         /*include_output_tensor_values=*/true);
3977 
3978   const bool can_use_shape_info = s.ok();
3979   VLOG(1) << "can_use_shape_info = " << can_use_shape_info;
3980 
3981   absl::flat_hash_set<string> nodes_to_not_simplify;
3982   if (can_use_shape_info) {
3983     TF_RETURN_IF_ERROR(MaterializeShapes(properties));
3984     TF_RETURN_IF_ERROR(MaterializeConstants(properties));
3985     TF_RETURN_IF_ERROR(
3986         FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
3987   } else {
3988     *optimized_graph = *graph_;
3989   }
3990   node_map_.reset(new NodeMap(optimized_graph));
3991   TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
3992                                    &properties, &nodes_to_not_simplify));
3993 
3994   return Status::OK();
3995 }
3996 
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3997 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
3998                                  GraphDef* optimized_graph) {
3999   // TensorFlow flushes denormals to zero and rounds to nearest, so we do
4000   // the same here.
4001   port::ScopedFlushDenormal flush;
4002   port::ScopedSetRound round(FE_TONEAREST);
4003   nodes_to_preserve_ = item.NodesToPreserve();
4004   for (const auto& feed : item.feed) {
4005     feed_nodes_.insert(NodeName(feed.first));
4006   }
4007 
4008   if (cpu_device_ == nullptr) {
4009     owned_device_.reset(new DeviceSimple());
4010     cpu_device_ = owned_device_.get();
4011   }
4012 
4013   graph_contains_assign_or_inplace_op_ = false;
4014   for (const NodeDef& node : item.graph.node()) {
4015     if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
4016       graph_contains_assign_or_inplace_op_ = true;
4017       break;
4018     }
4019   }
4020 
4021   has_fetch_ = !item.fetch.empty();
4022   GrapplerItem item_to_optimize = item;
4023   *optimized_graph = GraphDef();
4024   item_to_optimize.graph.Swap(optimized_graph);
4025   int64_t node_count;
4026   do {
4027     GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4028     graph_modified_ = false;
4029     item_to_optimize.graph.Swap(optimized_graph);
4030     optimized_graph->Clear();
4031     node_count = item_to_optimize.graph.node_size();
4032     TF_RETURN_IF_ERROR(
4033         RunOptimizationPass(cluster, &item_to_optimize, optimized_graph));
4034   } while (graph_modified_ || optimized_graph->node_size() != node_count);
4035   *optimized_graph->mutable_library() = item.graph.library();
4036   *optimized_graph->mutable_versions() = item.graph.versions();
4037 
4038   return Status::OK();
4039 }
4040 
4041 }  // namespace grappler
4042 }  // namespace tensorflow
4043