1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
#include "tensorflow/core/grappler/optimizers/constant_folding.h"

#include <cmath>
#include <cstring>

#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
57
58 namespace tensorflow {
59 namespace grappler {
60 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
61
// Size threshold for folding: only constants smaller than 100kB are folded or
// materialized by this pass.
const int64_t kMaxConstantSize = 100 * 1024;
64
65 namespace {
66 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)67 bool AllValuesAre(const TensorProto& proto, const T& value) {
68 Tensor tensor;
69 if (!tensor.FromProto(proto)) {
70 return false;
71 }
72 auto values = tensor.flat<T>();
73 for (int i = 0; i < tensor.NumElements(); ++i) {
74 if (values(i) != value) {
75 return false;
76 }
77 }
78 return true;
79 }
80
81 // Add new_input as a control input to node if it does not already depend on it.
82 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
83 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)84 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
85 GraphDef* graph, NodeMap* node_map) {
86 bool already_exists = false;
87 for (const string& input : node->input()) {
88 if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
89 already_exists = true;
90 break;
91 }
92 }
93 if (!already_exists) {
94 const string ctrl_dep =
95 ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
96 node->add_input(ctrl_dep);
97 node_map->AddOutput(NodeName(ctrl_input), node->name());
98 }
99 return !already_exists;
100 }
101
102 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)103 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
104 GraphDef* graph, NodeMap* node_map) {
105 bool removed_input = false;
106 bool update_node_map = true;
107 const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
108 for (int i = 0; i < node->input_size(); ++i) {
109 const string& input = node->input(i);
110 if (old_input_ctrl_dep == input) {
111 if (IsControlInput(input)) {
112 node->mutable_input()->SwapElements(i, node->input_size() - 1);
113 node->mutable_input()->RemoveLast();
114 removed_input = true;
115 } else {
116 // There is a non-control input from the same node.
117 // Don't remove the output from the NodeMap.
118 update_node_map = false;
119 }
120 }
121 }
122 if (update_node_map) {
123 node_map->RemoveOutput(NodeName(old_input), node->name());
124 }
125 return removed_input;
126 }
127
HasTPUAttributes(const NodeDef & node)128 bool HasTPUAttributes(const NodeDef& node) {
129 AttrSlice attrs(node);
130 for (const auto& attr : attrs) {
131 if (attr.first.find("_tpu_") != attr.first.npos) {
132 return true;
133 }
134 }
135 return false;
136 }
137
// Returns true if `a` and `b` differ. The generic version uses operator!=.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

// float specialization: compares the raw bit patterns instead of the numeric
// values, so +0.0f and -0.0f are reported as different while two NaNs with
// the same bits are reported as equal. The bytes are compared with memcmp
// rather than by casting to an integer reference, which would violate the
// strict-aliasing rules.
template <>
bool PackedValuesNotEqual(float a, float b) {
  return std::memcmp(&a, &b, sizeof(float)) != 0;
}

// double specialization: bitwise comparison, same rationale as for float.
template <>
bool PackedValuesNotEqual(double a, double b) {
  return std::memcmp(&a, &b, sizeof(double)) != 0;
}
152
QuantizedTypeMinAsFloat(DataType data_type)153 float QuantizedTypeMinAsFloat(DataType data_type) {
154 switch (data_type) {
155 case DT_QINT8:
156 return Eigen::NumTraits<qint8>::lowest();
157 case DT_QUINT8:
158 return Eigen::NumTraits<quint8>::lowest();
159 case DT_QINT16:
160 return Eigen::NumTraits<qint16>::lowest();
161 case DT_QUINT16:
162 return Eigen::NumTraits<quint16>::lowest();
163 case DT_QINT32:
164 return Eigen::NumTraits<qint32>::lowest();
165 default:
166 return 0.0f;
167 }
168 }
169
QuantizedTypeMaxAsFloat(DataType data_type)170 float QuantizedTypeMaxAsFloat(DataType data_type) {
171 switch (data_type) {
172 case DT_QINT8:
173 return Eigen::NumTraits<qint8>::highest();
174 case DT_QUINT8:
175 return Eigen::NumTraits<quint8>::highest();
176 case DT_QINT16:
177 return Eigen::NumTraits<qint16>::highest();
178 case DT_QUINT16:
179 return Eigen::NumTraits<quint16>::highest();
180 case DT_QINT32:
181 return Eigen::NumTraits<qint32>::highest();
182 default:
183 return 0.0f;
184 }
185 }
186
187 } // namespace
188
ConstantFolding(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_emulation)189 ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
190 DeviceBase* cpu_device,
191 bool disable_compressed_tensor_optimization,
192 bool fold_quantization_emulation)
193 : opt_level_(opt_level),
194 cpu_device_(cpu_device),
195 disable_compressed_tensor_optimization_(
196 disable_compressed_tensor_optimization),
197 fold_quantization_emulation_(fold_quantization_emulation) {
198 resource_mgr_.reset(new ResourceMgr());
199 }
200
ConstantFolding(DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_ops)201 ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
202 bool disable_compressed_tensor_optimization,
203 bool fold_quantization_ops)
204 : ConstantFolding(RewriterConfig::ON, cpu_device,
205 disable_compressed_tensor_optimization,
206 fold_quantization_ops) {}
207
208 // static
AddControlDependency(const string & input_name,GraphDef * graph,NodeMap * node_map)209 string ConstantFolding::AddControlDependency(const string& input_name,
210 GraphDef* graph,
211 NodeMap* node_map) {
212 if (IsControlInput(input_name)) {
213 return input_name;
214 }
215 const NodeDef* node = node_map->GetNode(input_name);
216 // Sanity check for missing node.
217 if (!node) {
218 return input_name;
219 }
220 if (!IsSwitch(*node)) {
221 return AsControlDependency(*node);
222 } else {
223 // We can't anchor control dependencies directly on the switch node: unlike
224 // other nodes only one of the outputs of the switch node will be generated
225 // when the switch node is executed, and we need to make sure the control
226 // dependency is only triggered when the corresponding output is triggered.
227 // We start by looking for an identity node connected to the output of the
228 // switch node, and use it to anchor the control dependency.
229 for (const NodeDef* output : node_map->GetOutputs(node->name())) {
230 if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
231 if (IsSameInput(node->input(0), input_name)) {
232 return AsControlDependency(*output);
233 }
234 }
235 }
236 // We haven't found an existing node where we can anchor the control
237 // dependency: add a new identity node.
238 int port = 0;
239 string ctrl_dep_name = ParseNodeName(input_name, &port);
240 strings::StrAppend(&ctrl_dep_name, "_", port);
241 ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
242 const DataType output_type = node->attr().at("T").type();
243
244 NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
245 if (added_node == nullptr) {
246 added_node = graph->add_node();
247 added_node->set_name(ctrl_dep_name);
248 added_node->set_op("Identity");
249 added_node->set_device(node->device());
250
251 (*added_node->mutable_attr())["T"].set_type(output_type);
252 *added_node->add_input() = input_name;
253 node_map->AddNode(added_node->name(), added_node);
254 node_map->AddOutput(node->name(), added_node->name());
255 }
256 return AsControlDependency(*added_node);
257 }
258 }
259
260 // Forward inputs at the given indices to outputs and add a control dependency
261 // on node.
ForwardInputs(NodeDef * node,absl::Span<const int> inputs_to_forward)262 bool ConstantFolding::ForwardInputs(NodeDef* node,
263 absl::Span<const int> inputs_to_forward) {
264 for (int input_idx : inputs_to_forward) {
265 if (input_idx < 0 || input_idx >= node->input_size()) {
266 return false;
267 }
268 }
269
270 const auto& tmp = node_map_->GetOutputs(node->name());
271 const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
272 bool updated_graph = false;
273 for (int input_idx : inputs_to_forward) {
274 const string& input = node->input(input_idx);
275 if (IsControlInput(input) && consumers.size() > 1) {
276 continue;
277 }
278 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
279 if (input_node == nullptr) {
280 LOG(ERROR) << "Bad input: " << input;
281 break;
282 }
283 // Update each consumer.
284 for (NodeDef* consumer : consumers) {
285 bool add_dep = false;
286 for (int consumer_input_idx = 0;
287 consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
288 const string& consumer_input = consumer->input(consumer_input_idx);
289 if (IsControlInput(consumer_input)) {
290 break;
291 }
292 // It is illegal to add control dependencies to _Retval nodes, so we
293 // can't bypass value producing `node` and forward inputs to `consumer`.
294 if (IsRetval(*consumer)) {
295 break;
296 }
297 int output_idx;
298 const string input_node_name =
299 ParseNodeName(consumer_input, &output_idx);
300 if (input_node_name == node->name() && output_idx == input_idx) {
301 consumer->set_input(consumer_input_idx, input);
302 // We will keep the input from the node through a control
303 // dependency, so we only need to add the consumer as an output
304 // for the input node.
305 node_map_->AddOutput(NodeName(input), consumer->name());
306 add_dep = true;
307 }
308 }
309 if (add_dep) {
310 consumer->add_input(AsControlDependency(node->name()));
311 updated_graph = true;
312 }
313 }
314 }
315
316 if (updated_graph) {
317 for (NodeDef* consumer : consumers) {
318 DedupControlInputs(consumer);
319 }
320 }
321 return updated_graph;
322 }
323
324 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64_t value,const DataType & type,const int index,Tensor * tensor)325 static Status PutValueIntoTensor(const int64_t value, const DataType& type,
326 const int index, Tensor* tensor) {
327 if (type == DT_INT32) {
328 if (value >= INT_MAX) {
329 return Status(error::INVALID_ARGUMENT, "int32 overflow");
330 }
331 tensor->flat<int32>()(index) = static_cast<int32>(value);
332 } else {
333 tensor->flat<int64>()(index) = value;
334 }
335 return Status::OK();
336 }
337
338 // Writes the given tensor shape into the given tensor.
339 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)340 static Status ConvertShapeToConstant(const string& op, const DataType& type,
341 const PartialTensorShape& shp,
342 Tensor* tensor) {
343 if (op == "Shape" || op == "ShapeN") {
344 *tensor = Tensor(type, TensorShape({shp.dims()}));
345 for (int i = 0; i < shp.dims(); ++i) {
346 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
347 }
348 } else if (op == "Size") {
349 int64_t size = 1;
350 for (int i = 0; i < shp.dims(); ++i) {
351 size *= shp.dim_size(i);
352 }
353 *tensor = Tensor(type, TensorShape({}));
354 TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
355 } else {
356 CHECK_EQ(op, "Rank");
357 *tensor = Tensor(type, TensorShape({}));
358 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
359 }
360 return Status::OK();
361 }
362
363 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const364 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
365 StringPiece suffix) const {
366 return node_map_->NodeExists(OptimizedNodeName(node, suffix));
367 }
368
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const369 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
370 StringPiece suffix) const {
371 return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
372 kConstantFoldingConst);
373 }
374
IsReallyConstant(const NodeDef & node) const375 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
376 if (!IsConstant(node)) {
377 return false;
378 }
379 // If the node is fed it's not constant anymore.
380 return feed_nodes_.find(node.name()) == feed_nodes_.end();
381 }
382
383 // TODO(rmlarsen): Refactor to shared util.
GetTensorFromConstNode(const string & node_name_or_input,Tensor * tensor)384 bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
385 Tensor* tensor) {
386 const NodeDef* node = node_map_->GetNode(node_name_or_input);
387 return node != nullptr && IsReallyConstant(*node) &&
388 CheckAttrExists(*node, "value").ok() &&
389 tensor->FromProto(node->attr().at("value").tensor());
390 }
391
392 // Materialize the shapes using constants whenever possible.
MaterializeShapes(const GraphProperties & properties)393 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
394 // We may add some nodes to the graph to encode control dependencies and hold
395 // the materialized shapes: there is no need to process these added nodes, so
396 // only iterate over the nodes of the input graph.
397 const int node_count = graph_->node_size();
398 for (int node_idx = 0; node_idx < node_count; ++node_idx) {
399 NodeDef* node = graph_->mutable_node(node_idx);
400 const string op = node->op();
401 if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
402 op != "TensorArraySizeV3") {
403 continue;
404 }
405 const std::vector<OpInfo::TensorProperties>& output =
406 properties.GetOutputProperties(node->name());
407 const std::vector<OpInfo::TensorProperties>& input =
408 properties.GetInputProperties(node->name());
409 if (input.empty() || output.empty()) {
410 continue;
411 }
412
413 if (op == "Shape" || op == "Size" || op == "Rank") {
414 CHECK_EQ(1, output.size());
415 CHECK_EQ(1, input.size());
416
417 const DataType type = output[0].dtype();
418 CHECK(type == DT_INT32 || type == DT_INT64);
419 const PartialTensorShape shape(input[0].shape());
420
421 if ((op != "Rank" && !shape.IsFullyDefined()) ||
422 (op == "Rank" && shape.unknown_rank())) {
423 continue;
424 }
425
426 Tensor constant_value(type);
427 if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
428 continue;
429 }
430
431 // TODO(rmlarsen): Remove this workaround for b/150861569
432 // The bug involves an expression of the form Shape(ExpandDims(x)
433 // with an incorrectly inferred zero-size first dimension.
434 if (op == "Shape") {
435 if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
436 }
437
438 // Repurpose the existing node to be the constant.
439 // Device placement is preserved.
440 graph_modified_ = true;
441 node->set_op("Const");
442 EraseRegularNodeAttributes(node);
443 (*node->mutable_attr())["dtype"].set_type(type);
444 constant_value.AsProtoTensorContent(
445 (*node->mutable_attr())["value"].mutable_tensor());
446
447 // Turn the data input into a control dependency: this is needed to
448 // ensure that the constant value will only be run in the
449 // cases where the shape/rank/size would have been run in
450 // the original graph.
451 string ctrl_dep =
452 AddControlDependency(node->input(0), graph_, node_map_.get());
453 node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
454 node->set_input(0, ctrl_dep);
455 // Done with the Shape/Size/Rank node, move to the next node.
456 continue;
457 }
458
459 if (op == "TensorArraySizeV3") {
460 const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
461 if (array->input_size() == 0 ||
462 (array->attr().count("dynamic_size") != 0 &&
463 array->attr().at("dynamic_size").b())) {
464 continue;
465 }
466 const NodeDef* array_size =
467 CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
468 if (IsReallyConstant(*array_size)) {
469 // Don't materialize 0 sizes to avoid triggering incorrect static
470 // checks. A 0 sized array that can't grow isn't useful anyway.
471 if (array_size->attr().count("value") == 0) {
472 continue;
473 }
474 const TensorProto& raw_val = array_size->attr().at("value").tensor();
475 if (raw_val.dtype() != DT_INT32) {
476 continue;
477 }
478 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
479 if (!value.FromProto(raw_val)) {
480 continue;
481 }
482 if (value.flat<int32>()(0) == 0) {
483 continue;
484 }
485
486 graph_modified_ = true;
487 node->set_op("Const");
488 *node->mutable_attr() = array_size->attr();
489 node->set_input(0, AsControlDependency(NodeName(node->input(0))));
490 node->set_input(1, AddControlDependency(NodeName(node->input(1)),
491 graph_, node_map_.get()));
492 }
493 continue;
494 }
495
496 // Handle ShapeN materialization case.
497 // It's possible that not all input tensors have known shapes.
498 CHECK_EQ(op, "ShapeN");
499 CHECK_EQ(input.size(), output.size());
500 const NodeDef* const shape_n_node = node;
501 for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
502 ++port_idx) {
503 const DataType type = output[port_idx].dtype();
504 CHECK(type == DT_INT32 || type == DT_INT64);
505 const PartialTensorShape shape(input[port_idx].shape());
506 if (!shape.IsFullyDefined()) {
507 continue;
508 }
509 Tensor constant_value(type);
510 auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
511 if (!status.ok()) {
512 continue;
513 }
514
515 // We make a copy because we mutate the nodes.
516 auto fanouts = node_map_->GetOutputs(shape_n_node->name());
517 // Find all nodes consuming this shape and connect them through the new
518 // constant node instead.
519 for (NodeDef* output : fanouts) {
520 // Track whether there are any direct edges left between shape_n_node
521 // and this output node after the transformation.
522 bool direct_edges_exist = false;
523 for (int k = 0; k < output->input_size(); ++k) {
524 int port;
525 const string node_name = ParseNodeName(output->input(k), &port);
526 if (node_name == shape_n_node->name() && port == port_idx) {
527 // Create a const node as ShapeN's output if not already.
528 const string const_name = OptimizedNodeName(
529 *shape_n_node, strings::StrCat("-matshapes-", port_idx));
530 if (node_map_->GetNode(const_name) == nullptr) {
531 NodeDef* added_node = graph_->add_node();
532 added_node->set_name(const_name);
533 added_node->set_op("Const");
534 added_node->set_device(shape_n_node->device());
535 node_map_->AddNode(added_node->name(), added_node);
536 (*added_node->mutable_attr())["dtype"].set_type(type);
537 constant_value.AsProtoTensorContent(
538 (*added_node->mutable_attr())["value"].mutable_tensor());
539 // We add a control dependency to the original ShapeN node,
540 // so that the node will only be run if all inputs of the
541 // original ShapeN node are run.
542 string ctrl_dep = AddControlDependency(shape_n_node->name(),
543 graph_, node_map_.get());
544 *added_node->add_input() = ctrl_dep;
545 node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
546 }
547 *output->mutable_input(k) = const_name;
548 node_map_->AddOutput(const_name, output->name());
549 graph_modified_ = true;
550 }
551 if (node_name == shape_n_node->name() && port != port_idx) {
552 direct_edges_exist = true;
553 }
554 }
555 if (!direct_edges_exist) {
556 node_map_->RemoveOutput(node->name(), output->name());
557 }
558 }
559 }
560 }
561
562 return Status::OK();
563 }
564
565 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64 * min_id)566 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
567 BCast::Vec* shape, int64* min_id) {
568 if (shape_node.op() == "Shape") {
569 const std::vector<OpInfo::TensorProperties>& prop1 =
570 properties.GetInputProperties(shape_node.name());
571 if (prop1.size() != 1) {
572 return false;
573 }
574 const TensorShapeProto& shp = prop1[0].shape();
575 if (shp.unknown_rank()) {
576 return false;
577 }
578 for (const auto& dim : shp.dim()) {
579 shape->push_back(dim.size());
580 *min_id = std::min<int64>(*min_id, dim.size());
581 }
582 } else {
583 if (shape_node.attr().count("value") == 0) {
584 return false;
585 }
586 const TensorProto& raw_val = shape_node.attr().at("value").tensor();
587 if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
588 return false;
589 }
590 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
591 if (!value.FromProto(raw_val)) {
592 return false;
593 }
594 for (int j = 0; j < value.NumElements(); ++j) {
595 if (raw_val.dtype() == DT_INT64) {
596 shape->push_back(value.vec<int64>()(j));
597 } else {
598 shape->push_back(value.vec<int>()(j));
599 }
600 }
601 }
602 return true;
603 }
604 } // namespace
605
MaterializeBroadcastGradientArgs(const NodeDef & node,const GraphProperties & properties)606 Status ConstantFolding::MaterializeBroadcastGradientArgs(
607 const NodeDef& node, const GraphProperties& properties) {
608 const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
609 const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
610 if (shape_node1 == nullptr ||
611 (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
612 shape_node2 == nullptr ||
613 (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
614 return Status::OK();
615 }
616
617 // Don't optimize this again if it was already optimized and folded.
618 if (OptimizedNodeExists(node, "-folded-1") ||
619 OptimizedNodeExists(node, "-folded-2")) {
620 return Status::OK();
621 }
622 int64_t min_id = 0;
623 BCast::Vec shape1;
624 if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
625 return Status::OK();
626 }
627 BCast::Vec shape2;
628 if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
629 return Status::OK();
630 }
631 // A value of -1 means we don't known anything about the dimension. Replace
632 // the -1 values with unique dimension ids since we don't want two '-1'
633 // dimensions to be considered equal.
634 for (auto& id : shape1) {
635 if (id == -1) {
636 id = --min_id;
637 }
638 }
639 for (auto& id : shape2) {
640 if (id == -1) {
641 id = --min_id;
642 }
643 }
644
645 // Beware: the reduction dimensions computed by the BCast class are valid iff
646 // we assume that two distinct symbolic dimensions can't be equal and a
647 // symbolic dimension can't be equal to 1. This is often but not always true,
648 // so to make this optimization safe we filter out these cases.
649 const int common_dims = std::min(shape1.size(), shape2.size());
650 for (int i = 0; i < common_dims; ++i) {
651 if (shape1[i] >= 0 && shape2[i] >= 0) {
652 continue;
653 }
654 if (shape1[i] != shape2[i]) {
655 // We're either dealing with 2 different symbolic dimensions or a symbolic
656 // and a know dimensions. We can't be sure whether both are equal or not,
657 // so we can't be sure whether we'll be broadcasting or not.
658 return Status::OK();
659 }
660 }
661 // These extra dims could be equal to 1, in which case there is no
662 // broadcasting. It could also be greater than 1, in which case there would
663 // be broadcasting. Since we don't know, we'll just punt.
664 for (int i = common_dims, end = shape1.size(); i < end; ++i) {
665 if (shape1[i] < 0) {
666 return Status::OK();
667 }
668 }
669 for (int i = common_dims, end = shape2.size(); i < end; ++i) {
670 if (shape2[i] < 0) {
671 return Status::OK();
672 }
673 }
674
675 BCast bcast(shape1, shape2);
676 if (!bcast.IsValid()) {
677 return Status::OK();
678 }
679
680 BCast::Vec reduce_dims[2];
681 reduce_dims[0] = bcast.grad_x_reduce_idx();
682 reduce_dims[1] = bcast.grad_y_reduce_idx();
683
684 TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
685 const DataType type = node.attr().at("T").type();
686 NodeDef* out[2];
687 for (int j = 0; j < 2; ++j) {
688 int reduction_indices = reduce_dims[j].size();
689 Tensor value(type, TensorShape({reduction_indices}));
690 for (int i = 0; i < reduction_indices; ++i) {
691 if (type == DT_INT32) {
692 value.vec<int32>()(i) = reduce_dims[j][i];
693 } else {
694 value.vec<int64>()(i) = reduce_dims[j][i];
695 }
696 }
697 string const_name =
698 OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
699 out[j] = node_map_->GetNode(const_name);
700 if (out[j] == nullptr) {
701 out[j] = graph_->add_node();
702 TF_RETURN_IF_ERROR(
703 CreateNodeDef(const_name, TensorValue(&value), out[j]));
704 out[j]->set_device(node.device());
705 node_map_->AddNode(const_name, out[j]);
706 string ctrl_dep =
707 AddControlDependency(node.name(), graph_, node_map_.get());
708 *out[j]->add_input() = ctrl_dep;
709 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
710 }
711 }
712
713 // We make a copy here since we might mutate the set.
714 const auto outputs = node_map_->GetOutputs(node.name());
715 for (NodeDef* output : outputs) {
716 for (int k = 0; k < output->input_size(); ++k) {
717 int port;
718 string node_name = ParseNodeName(output->input(k), &port);
719 if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
720 *output->mutable_input(k) = out[port]->name();
721 node_map_->UpdateInput(output->name(), node_name, out[port]->name());
722 }
723 }
724 }
725
726 return Status::OK();
727 }
728
MaterializeReductionIndices(NodeDef * node,const GraphProperties & properties)729 Status ConstantFolding::MaterializeReductionIndices(
730 NodeDef* node, const GraphProperties& properties) {
731 if (node->input_size() < 2) {
732 return Status::OK();
733 }
734 const NodeDef* indices = node_map_->GetNode(node->input(1));
735 if (!indices || IsReallyConstant(*indices)) {
736 // The reduction indices are already constant, there's nothing to do.
737 return Status::OK();
738 }
739
740 const std::vector<OpInfo::TensorProperties>& input_props =
741 properties.GetInputProperties(node->name());
742 if (input_props.size() != 2) {
743 return Status::OK();
744 }
745 const OpInfo::TensorProperties& input_prop = input_props[0];
746 if (input_prop.shape().unknown_rank()) {
747 // We can't do anything if we don't know the rank of the input.
748 return Status::OK();
749 }
750 const int input_rank = input_prop.shape().dim_size();
751 if (input_rank < 1) {
752 // Unexpected graph, don't try to change it.
753 return Status::OK();
754 }
755 const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
756 DataType dtype = reduction_indices_prop.dtype();
757 if (dtype != DT_INT32 && dtype != DT_INT64) {
758 return Status::OK();
759 }
760 PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
761 const int num_reduction_indices = reduction_indices_shape.num_elements();
762
763 const std::vector<OpInfo::TensorProperties>& output_props =
764 properties.GetOutputProperties(node->name());
765 if (output_props.size() != 1) {
766 return Status::OK();
767 }
768 const OpInfo::TensorProperties& output_prop = output_props[0];
769 const int output_rank =
770 output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
771
772 bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
773 if (!full_reduction) {
774 // A full reduction will generate a tensor of one of the shapes
775 // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
776 // elements in the output of the reduction, we may deduce it from reshape
777 // nodes following it.
778 for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
779 full_reduction = false;
780 if (!IsReshape(*fanout)) {
781 return Status::OK();
782 }
783 const std::vector<OpInfo::TensorProperties>& reshape_props =
784 properties.GetOutputProperties(fanout->name());
785 if (reshape_props.size() != 1) {
786 return Status::OK();
787 }
788 const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
789 PartialTensorShape shape(reshape_prop.shape());
790 if (shape.num_elements() != 1) {
791 return Status::OK();
792 } else {
793 full_reduction = true;
794 }
795 }
796 if (!full_reduction) {
797 return Status::OK();
798 }
799 }
800
801 // We know it's a full reduction. We can generate the full set of indices to
802 // reduce as a constant node.
803 string const_name = OptimizedNodeName(*node, "-reduction_indices");
804 if (node_map_->GetNode(const_name)) {
805 return Status::OK();
806 }
807 NodeDef* reduction_indices = graph_->add_node();
808 Tensor value(dtype, TensorShape({input_rank}));
809 for (int i = 0; i < input_rank; ++i) {
810 if (dtype == DT_INT32) {
811 value.vec<int32>()(i) = i;
812 } else {
813 value.vec<int64>()(i) = i;
814 }
815 }
816 TF_RETURN_IF_ERROR(
817 CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
818
819 reduction_indices->set_device(node->device());
820 string ctrl_dep =
821 AddControlDependency(node->input(1), graph_, node_map_.get());
822 *reduction_indices->add_input() = ctrl_dep;
823 node_map_->AddNode(const_name, reduction_indices);
824 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
825
826 node->set_input(1, reduction_indices->name());
827 node_map_->UpdateInput(node->name(), indices->name(),
828 reduction_indices->name());
829
830 return Status::OK();
831 }
832
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)833 Status ConstantFolding::MaterializeConstantValuedNode(
834 NodeDef* node, const GraphProperties& properties) {
835 if (disable_compressed_tensor_optimization_) {
836 return Status::OK();
837 }
838 // Nodes that generate constant-valued outputs can be represented compactly in
839 // compressed format, regardless of their shape.
840 const std::vector<OpInfo::TensorProperties>& output_props =
841 properties.GetOutputProperties(node->name());
842 if (output_props.size() != 1) return Status::OK();
843 const auto& output_shape = output_props[0].shape();
844 if (!PartialTensorShape(output_shape).IsFullyDefined()) {
845 return Status::OK();
846 }
847 if (IsFill(*node)) {
848 const auto output_dtype = output_props[0].dtype();
849 NodeDef* input_node = nullptr;
850 for (int i = 0; i < 2; ++i) {
851 input_node = node_map_->GetNode(NodeName(node->input(i)));
852 if (input_node == nullptr || !IsReallyConstant(*input_node)) {
853 return Status::OK();
854 }
855 }
856 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
857
858 // Copy the input tensor to the fill node, set the output shape and data
859 // type, and change the node type to Const.
860 TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
861 const TensorProto& input_tensor = input_node->attr().at("value").tensor();
862 if (!input_tensor.tensor_content().empty()) {
863 // Convert the value to repeated field format, so we can use the
864 // decompression mechanism to store only a single value in the constant
865 // node, even if the shape specified in the original Fill is large.
866 Tensor t;
867 if (!t.FromProto(input_tensor)) {
868 return errors::InvalidArgument(
869 "Could not construct Tensor form TensorProto in node: ",
870 input_node->name());
871 }
872 tensor->clear_tensor_content();
873 t.AsProtoField(tensor);
874 } else {
875 *tensor = input_tensor;
876 }
877 *(tensor->mutable_tensor_shape()) = output_shape;
878 (*node->mutable_attr())["dtype"].set_type(output_dtype);
879 node->mutable_attr()->erase("T");
880 node->mutable_attr()->erase("index_type");
881 node->set_op("Const");
882 for (int i = 0; i < 2; i++) {
883 // Change inputs to a control inputs.
884 const string ctrl_dep = AsControlDependency(node->input(i));
885 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
886 node->set_input(i, ctrl_dep);
887 }
888 graph_modified_ = true;
889 } else {
890 double value =
891 (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
892 if (value >= 0) {
893 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
894 value, properties, output_shape, node, graph_));
895 }
896 }
897 return Status::OK();
898 }
899
900 // Materialize output values inferred by the shape inference.
MaterializeOutputValues(NodeDef * node,const GraphProperties & properties)901 Status ConstantFolding::MaterializeOutputValues(
902 NodeDef* node, const GraphProperties& properties) {
903 const std::vector<OpInfo::TensorProperties>& output =
904 properties.GetOutputProperties(node->name());
905 if (output.size() != 1 || !output[0].has_value() ||
906 !IsFoldable(*node, &properties)) {
907 return Status::OK();
908 }
909
910 // If this is a trivial Identity node with a constant input, just route the
911 // input around it.
912 if (IsIdentity(*node)) {
913 NodeDef* input = node_map_->GetNode(node->input(0));
914 if (IsReallyConstant(*input)) {
915 std::vector<int> inputs_to_forward;
916 std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
917 graph_modified_ = ForwardInputs(node, inputs_to_forward);
918 return Status::OK();
919 }
920 }
921 // Repurpose the existing node to be the constant.
922 // Device placement is preserved.
923 TensorProto value_copy = output[0].value();
924 return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
925 node, graph_);
926 }
927
MaterializeConstants(const GraphProperties & properties)928 Status ConstantFolding::MaterializeConstants(
929 const GraphProperties& properties) {
930 const int node_count = graph_->node_size();
931 for (int i = 0; i < node_count; ++i) {
932 NodeDef& node = *graph_->mutable_node(i);
933 const string& op = node.op();
934 if (op == "BroadcastGradientArgs") {
935 TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
936 } else if (IsReduction(node)) {
937 TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
938 } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
939 TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
940 } else {
941 TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
942 }
943 }
944 return Status::OK();
945 }
946
IsFoldable(const NodeDef & node,const GraphProperties * properties)947 bool ConstantFolding::IsFoldable(const NodeDef& node,
948 const GraphProperties* properties) {
949 string key = strings::StrCat(node.name(), "/", node.op());
950 auto it = maybe_foldable_nodes_.find(key);
951 if (it == maybe_foldable_nodes_.end()) {
952 it = maybe_foldable_nodes_
953 .emplace(std::move(key), MaybeFoldable(node, properties))
954 .first;
955 }
956 if (!it->second) {
957 return false;
958 } else {
959 return IsFoldableUncached(node, properties);
960 }
961 }
962
IsFoldableUncached(const NodeDef & node,const GraphProperties * properties) const963 bool ConstantFolding::IsFoldableUncached(
964 const NodeDef& node, const GraphProperties* properties) const {
965 // Folding not applicable to ops with no inputs.
966 if (node.input().empty()) {
967 return false;
968 }
969 // We can only fold nodes if all their inputs are known statically, except in
970 // the case of a merge node that propagate the first inputs that becomes
971 // available, and therefore only requires a single constant input to be
972 // foldable.
973 bool merge_has_constant_input = false;
974 const bool is_merge = IsMerge(node);
975 for (const auto& input : node.input()) {
976 if (IsControlInput(input)) {
977 continue;
978 }
979 const NodeDef* input_node = node_map_->GetNode(input);
980 if (!input_node) {
981 return false;
982 }
983 bool is_const = IsReallyConstant(*input_node);
984 if (is_const) {
985 // Don't fold strings constants for now since this causes problems with
986 // checkpointing.
987 if (input_node->attr().count("dtype") == 0 ||
988 input_node->attr().at("dtype").type() == DT_STRING) {
989 return false;
990 }
991 // Special case: If a Merge node has at least one constant input that
992 // does not depend on a control input, we can fold it.
993 merge_has_constant_input |= !HasControlInputs(*input_node);
994 } else if (!is_merge) {
995 return false;
996 }
997 }
998 if (is_merge && !merge_has_constant_input) return false;
999 if (disable_compressed_tensor_optimization_ &&
1000 (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
1001 return false;
1002
1003 // If we know the output shapes, make sure that the outputs are small enough
1004 // to materialize.
1005 if (properties != nullptr && properties->HasOutputProperties(node.name())) {
1006 const std::vector<OpInfo::TensorProperties>& input_props =
1007 properties->GetInputProperties(node.name());
1008 const std::vector<OpInfo::TensorProperties>& output_props =
1009 properties->GetOutputProperties(node.name());
1010 // Compute total size of inputs.
1011 int64_t input_size_bytes = 0;
1012 for (const auto& input_prop : input_props) {
1013 const PartialTensorShape input_shape(input_prop.shape());
1014 if (input_shape.IsFullyDefined()) {
1015 input_size_bytes +=
1016 input_shape.num_elements() * DataTypeSize(input_prop.dtype());
1017 }
1018 }
1019 for (const auto& output_prop : output_props) {
1020 const PartialTensorShape output_shape(output_prop.shape());
1021 if (output_shape.IsFullyDefined()) {
1022 const int64_t num_bytes =
1023 output_shape.num_elements() * DataTypeSize(output_prop.dtype());
1024 if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
1025 // Do not fold nodes if the in-memory size of output is too large.
1026 // Notice that this is not exactly the same check used in
1027 // CreateNodeDef() where the actual encoded size is checked.
1028 return false;
1029 }
1030 }
1031 }
1032 }
1033
1034 return true;
1035 }
1036
MaybeFoldable(const NodeDef & node,const GraphProperties * properties) const1037 bool ConstantFolding::MaybeFoldable(const NodeDef& node,
1038 const GraphProperties* properties) const {
1039 // Skip constants, they're already folded
1040 if (IsConstant(node)) {
1041 return false;
1042 }
1043 // Don't fold stateful ops such as TruncatedNormal.
1044 if (!IsFreeOfSideEffect(node)) {
1045 return false;
1046 }
1047
1048 // Skips nodes that must be preserved except allowlisted nodes.
1049 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
1050 nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1051 return false;
1052 }
1053
1054 // Skip control flow nodes, they can't be folded.
1055 if (ModifiesFrameInfo(node)) {
1056 return false;
1057 }
1058
1059 // Skips ops that don't benefit from folding.
1060 if (IsPlaceholder(node)) {
1061 return false;
1062 }
1063 // `FakeParam` op is used as a placeholder in If branch function. It doesn't
1064 // have a valid output when executed.
1065 if (IsFakeParam(node)) {
1066 return false;
1067 }
1068
1069 if (node.op() == "AccumulateNV2") {
1070 return false;
1071 }
1072 // Removing LoopCond nodes can screw up the partitioner.
1073 if (node.op() == "LoopCond") {
1074 return false;
1075 }
1076
1077 if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
1078 return false;
1079 }
1080
1081 const string& op = node.op();
1082 if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
1083 op.find("Reader") != string::npos) {
1084 return false;
1085 }
1086 if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
1087 return false;
1088 }
1089
1090 // Don't fold nodes that contain TPU attributes.
1091 // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
1092 // properly forward custom attributes, b/119051778.
1093 if (HasTPUAttributes(node)) {
1094 return false;
1095 }
1096
1097 const OpDef* op_def = nullptr;
1098 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
1099 if (!status.ok()) {
1100 return false;
1101 }
1102 // Don't fold ops without outputs.
1103 if (op_def->output_arg_size() == 0) {
1104 return false;
1105 }
1106 // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
1107 // TODO(rmlarsen): Only do this for XLA_* devices.
1108 for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
1109 if (output_arg.type() == DT_VARIANT) {
1110 return false;
1111 }
1112 }
1113
1114 // Don't fold nodes that have no outgoing edges except allowlisted nodes.
1115 // Such nodes could be introduced by an earlier constant folding pass and are
1116 // preserved in case users want to fetch their values; re-processing them
1117 // would lead to an error of adding a duplicated node to graph.
1118 const auto& outputs = node_map_->GetOutputs(node.name());
1119 if (outputs.empty() &&
1120 nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1121 return false;
1122 }
1123 return true;
1124 }
1125
1126 namespace {
1127
1128 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
1129 case DTYPE: \
1130 t->add_##NAME##_val(static_cast<TYPE>(value)); \
1131 break;
1132
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)1133 Status CreateConstantTensorAttrValue(DataType type, double value,
1134 const TensorShapeProto& shape,
1135 AttrValue* attr_tensor) {
1136 TensorProto* t = attr_tensor->mutable_tensor();
1137 t->set_dtype(type);
1138 *t->mutable_tensor_shape() = shape;
1139 switch (type) {
1140 case DT_HALF:
1141 t->add_half_val(
1142 Eigen::numext::bit_cast<uint16>(static_cast<Eigen::half>(value)));
1143 break;
1144 case DT_BFLOAT16:
1145 t->add_half_val(
1146 Eigen::numext::bit_cast<uint16>(static_cast<bfloat16>(value)));
1147 break;
1148 SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
1149 SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
1150 SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
1151 SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
1152 SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
1153 SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
1154 SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
1155 SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
1156 SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
1157 SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
1158 SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
1159 SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
1160 SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
1161 SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
1162 SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1163 SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1164 default:
1165 return errors::InvalidArgument(
1166 "Unsupported type in CreateConstantTensorAttrValue: ",
1167 DataTypeString(type));
1168 }
1169 return Status::OK();
1170 }
1171
1172 #undef SET_TENSOR_CAL_CASE
1173
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1174 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1175 const GraphProperties& properties) {
1176 DataType dtype = DT_INVALID;
1177 if (node.attr().count("T") == 1) {
1178 dtype = node.attr().at("T").type();
1179 } else if (node.attr().count("dtype") == 1) {
1180 dtype = node.attr().at("dtype").type();
1181 } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1182 dtype = DT_BOOL;
1183 } else {
1184 auto output_props = properties.GetOutputProperties(node.name());
1185 if (!output_props.empty()) {
1186 dtype = output_props[0].dtype();
1187 }
1188 }
1189 return dtype;
1190 }
1191
1192 // Checks whether the shape of the const input of the Mul op is valid to perform
1193 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1194 bool IsValidConstShapeForMulConvPushDown(
1195 const string& data_format, const TensorShapeProto& filter_shape,
1196 const TensorShapeProto& mul_const_input_shape) {
1197 // If the const is a scalar, or it has fewer or same number of dimensions
1198 // than the filter and it only has single element, the optimization should
1199 // work.
1200 if (mul_const_input_shape.dim_size() <=
1201 static_cast<int>(data_format.size()) &&
1202 TensorShape(mul_const_input_shape).num_elements() == 1) {
1203 return true;
1204 }
1205
1206 // Otherwise, check the eligibility according to data format.
1207 if (data_format == "NHWC" || data_format == "NDHWC") {
1208 TensorShapeProto new_filter_shape;
1209 if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1210 &new_filter_shape)) {
1211 return false;
1212 }
1213 if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1214 return false;
1215 }
1216 // Only the last dimension could be larger than one, since broadcasting over
1217 // the last dimension (the output channel) will result in invalid filter.
1218 for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1219 if (mul_const_input_shape.dim(i).size() > 1) return false;
1220 }
1221 return true;
1222 } else if (data_format == "NCHW" || data_format == "NCDHW") {
1223 // TODO(laigd): support NCHW and NCDHW (b/111214513).
1224 return false;
1225 }
1226 return false;
1227 }
1228
1229 } // namespace
1230
1231 // static
CreateNodeDef(const string & name,const TensorValue & tensor,NodeDef * node,size_t original_size)1232 Status ConstantFolding::CreateNodeDef(const string& name,
1233 const TensorValue& tensor, NodeDef* node,
1234 size_t original_size) {
1235 node->set_name(name);
1236 node->set_op("Const");
1237
1238 AttrValue attr_type;
1239 attr_type.set_type(tensor->dtype());
1240 node->mutable_attr()->insert({"dtype", attr_type});
1241
1242 AttrValue attr_tensor;
1243 TensorProto* t = attr_tensor.mutable_tensor();
1244 bool optimized = false;
1245 size_t encoded_size;
1246 // Use the packed representation whenever possible to avoid generating large
1247 // graphdefs. Moreover, avoid repeating the last values if they're equal.
1248 if (tensor->NumElements() > 4) {
1249 #define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE) \
1250 { \
1251 const auto* val_ptr = tensor->flat<TYPE>().data(); \
1252 auto last = *val_ptr; \
1253 int64 last_index = 0; \
1254 for (int64 i = 0; i < tensor->NumElements(); ++i) { \
1255 TYPE cur = *val_ptr++; \
1256 if (PackedValuesNotEqual(cur, last)) { \
1257 last = cur; \
1258 last_index = i; \
1259 } \
1260 } \
1261 encoded_size = (last_index + 1) * sizeof(FIELDTYPE); \
1262 if (encoded_size < kint32max) { \
1263 optimized = true; \
1264 t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1); \
1265 const auto* src_ptr = tensor->flat<TYPE>().data(); \
1266 auto* dst_ptr = \
1267 t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
1268 std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr); \
1269 } \
1270 } \
1271 break
1272
1273 switch (tensor->dtype()) {
1274 case DT_FLOAT:
1275 POPULATE_TENSOR_PROTO(tensor, t, float, float);
1276 case DT_DOUBLE:
1277 POPULATE_TENSOR_PROTO(tensor, t, double, double);
1278 case DT_INT64:
1279 POPULATE_TENSOR_PROTO(tensor, t, int64_t, int64);
1280 case DT_UINT64:
1281 POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
1282 case DT_INT32:
1283 POPULATE_TENSOR_PROTO(tensor, t, int32_t, int);
1284 case DT_UINT32:
1285 POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
1286 case DT_INT16:
1287 POPULATE_TENSOR_PROTO(tensor, t, int16_t, int);
1288 case DT_UINT16:
1289 POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
1290 case DT_INT8:
1291 POPULATE_TENSOR_PROTO(tensor, t, int8_t, int);
1292 case DT_UINT8:
1293 POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
1294 case DT_BOOL:
1295 POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
1296 default:
1297 /* Do nothing. */
1298 break;
1299 }
1300 }
1301 if (optimized) {
1302 // Also specify type and shape.
1303 t->set_dtype(tensor->dtype());
1304 tensor->shape().AsProto(t->mutable_tensor_shape());
1305 } else {
1306 // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
1307 // DT_QUINT8
1308 tensor->AsProtoTensorContent(t);
1309 encoded_size = t->tensor_content().size();
1310 }
1311 node->mutable_attr()->insert({"value", attr_tensor});
1312
1313 if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
1314 return errors::InvalidArgument(
1315 strings::StrCat("Can't fold ", name, ", its size would be too large (",
1316 encoded_size, " >= ", kMaxConstantSize, " bytes)"));
1317 }
1318 return Status::OK();
1319 }
1320
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1321 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1322 const TensorVector& inputs,
1323 TensorVector* output) const {
1324 return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1325 resource_mgr_.get(), output);
1326 }
1327
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1328 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1329 std::vector<NodeDef>* outputs,
1330 bool* result_too_large) {
1331 TensorVector inputs;
1332 TensorVector output_tensors;
1333 auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1334 for (const auto& input : inputs) {
1335 delete input.tensor;
1336 }
1337 for (const auto& output : output_tensors) {
1338 if (output.tensor) {
1339 delete output.tensor;
1340 }
1341 }
1342 });
1343
1344 size_t total_inputs_size = 0;
1345 for (const auto& input : node.input()) {
1346 const TensorId input_tensor = ParseTensorName(input);
1347 if (input_tensor.index() < 0) {
1348 // Control dependency
1349 break;
1350 }
1351 const NodeDef* input_node = node_map_->GetNode(input);
1352 if (!IsReallyConstant(*input_node)) {
1353 return Status(error::INVALID_ARGUMENT,
1354 strings::StrCat("Can't fold ", node.name(), ", its ", input,
1355 " isn't constant"));
1356 }
1357 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1358 const TensorProto& raw_val = input_node->attr().at("value").tensor();
1359 if (raw_val.dtype() == DT_INVALID) {
1360 return Status(
1361 error::INVALID_ARGUMENT,
1362 strings::StrCat("A tensor in the input node, with TensorId of ",
1363 input_tensor.ToString(),
1364 " has a dtype of DT_INVALID."));
1365 }
1366 Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1367 if (!value->FromProto(raw_val)) {
1368 delete (value);
1369 return errors::InvalidArgument("Unable to make Tensor from proto for ",
1370 node.name(), " with shape ",
1371 raw_val.tensor_shape().DebugString());
1372 }
1373 inputs.emplace_back(value);
1374 total_inputs_size += value->TotalBytes();
1375 }
1376
1377 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1378 if (output_tensors.empty()) {
1379 return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1380 }
1381
1382 outputs->resize(output_tensors.size());
1383 for (size_t i = 0; i < output_tensors.size(); i++) {
1384 string node_name = OptimizedNodeName(node, "-folded");
1385 if (output_tensors.size() > 1) {
1386 node_name = strings::StrCat(node_name, "-", i);
1387 }
1388 if (output_tensors[i].tensor) {
1389 Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1390 total_inputs_size);
1391 if (!s.ok()) {
1392 *result_too_large = true;
1393 return s;
1394 }
1395 } else {
1396 // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1397 // switch that's not selected by the switch predicate).
1398 outputs->at(i) = NodeDef();
1399 }
1400 }
1401 return Status::OK();
1402 }
1403
FoldMergeNode(NodeDef * node,GraphDef * output_graph)1404 Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
1405 // Merge nodes are special, in the sense that they execute as soon as one of
1406 // their input is ready. We can therefore fold a merge node iff it has at
1407 // least one constant input without control dependency.
1408 // We still need to ensure that the nodes in the fanin of the merge node are
1409 // scheduled. We'll therefore add a control dependency from the merge node
1410 // to the folded constant. We end up with:
1411 // * the merge node and its inputs are preserved as is
1412 // * a new constant node C1, driven by the merge node through a control
1413 // dependency, initialized to the value of the folded input
1414 // * a new constant node C2, driven by the merge node through a control
1415 // dependency, initialized to the index of the folded input
1416 // * the fanout of the merge nodes is rewired to be driven by either C1 or
1417 // C2.
1418 for (int input_index = 0; input_index < node->input_size(); ++input_index) {
1419 const auto& input = node->input(input_index);
1420 if (IsControlInput(input)) {
1421 // Try the next input.
1422 continue;
1423 }
1424 NodeDef* input_node = node_map_->GetNode(input);
1425 if (!IsReallyConstant(*input_node)) {
1426 continue;
1427 }
1428 bool valid_input = true;
1429 for (const string& fanin_of_input : input_node->input()) {
1430 if (IsControlInput(fanin_of_input)) {
1431 valid_input = false;
1432 break;
1433 }
1434 }
1435 if (!valid_input) {
1436 // Try the next input
1437 continue;
1438 }
1439
1440 string const_out_name = OptimizedNodeName(*node, "_const");
1441 string const_index_name = OptimizedNodeName(*node, "_index");
1442 if (node_map_->GetNode(const_out_name) ||
1443 node_map_->GetNode(const_index_name)) {
1444 // Intended name already exists.
1445 return errors::AlreadyExists(
1446 strings::StrCat(const_out_name, " or ", const_index_name,
1447 " already present in the graph"));
1448 }
1449
1450 NodeDef* const_out = output_graph->add_node();
1451 *const_out = *input_node;
1452 const_out->set_name(const_out_name);
1453 const_out->set_device(node->device());
1454 *const_out->add_input() = AsControlDependency(*node);
1455 node_map_->AddNode(const_out->name(), const_out);
1456 node_map_->AddOutput(node->name(), const_out->name());
1457
1458 NodeDef* const_index = output_graph->add_node();
1459 const_index->set_op("Const");
1460 Tensor index(DT_INT32, TensorShape({}));
1461 index.flat<int32>()(0) = input_index;
1462 (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
1463 index.AsProtoTensorContent(
1464 (*const_index->mutable_attr())["value"].mutable_tensor());
1465 const_index->set_name(const_index_name);
1466 const_index->set_device(node->device());
1467 *const_index->add_input() = AsControlDependency(*node);
1468 node_map_->AddNode(const_index->name(), const_index);
1469 node_map_->AddOutput(node->name(), const_index->name());
1470
1471 // We make a copy because we mutate the nodes.
1472 auto outputs = node_map_->GetOutputs(node->name());
1473 for (NodeDef* output : outputs) {
1474 for (int i = 0; i < output->input_size(); i++) {
1475 int port;
1476 string node_name = ParseNodeName(output->input(i), &port);
1477 if (node_name == node->name()) {
1478 if (port == 0) {
1479 *output->mutable_input(i) = const_out->name();
1480 node_map_->AddOutput(const_out->name(), output->name());
1481 } else if (port == 1) {
1482 *output->mutable_input(i) = const_index->name();
1483 node_map_->AddOutput(const_index->name(), output->name());
1484 } else {
1485 // This is a control dependency (or an invalid edge since the
1486 // merge node has only 2 outputs): preserve them.
1487 }
1488 }
1489 }
1490 }
1491 return Status::OK();
1492 }
1493 return Status::OK();
1494 }
1495
FoldNode(NodeDef * node,GraphDef * output_graph,bool * result_too_large)1496 Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
1497 bool* result_too_large) {
1498 *result_too_large = false;
1499 if (IsMerge(*node)) {
1500 return FoldMergeNode(node, output_graph);
1501 }
1502
1503 std::vector<NodeDef> const_nodes;
1504 TF_RETURN_IF_ERROR(
1505 EvaluateOneFoldable(*node, &const_nodes, result_too_large));
1506 VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);
1507
1508 NodeDef* constant_output = nullptr;
1509 for (int i = 0, end = const_nodes.size(); i < end; i++) {
1510 NodeDef* const_node = &const_nodes[i];
1511 VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
1512 if (const_node->name().empty()) {
1513 // Dead output: we can't create a constant to encode its value, so we'll
1514 // just skip it. We'll preserve the edges that originate from that
1515 // output below to preserve the overall behavior of the graph wrt dead
1516 // edges.
1517 continue;
1518 }
1519
1520 // Returns `true` iff `const_node` already has control input named `input`.
1521 const auto is_duplicate_control_input = [&](const string& input) -> bool {
1522 auto it = absl::c_find(const_node->input(), input);
1523 return it != const_node->input().end();
1524 };
1525
1526 // Forward control dependencies.
1527 for (const string& input : node->input()) {
1528 // Forward control dependencies from folded node.
1529 if (IsControlInput(input)) {
1530 if (!is_duplicate_control_input(input)) {
1531 *const_node->add_input() = input;
1532 }
1533 }
1534
1535 // Forward control dependencies from constant inputs to folded node.
1536 if (!IsControlInput(input)) {
1537 NodeDef* input_node = node_map_->GetNode(input);
1538 for (const string& fanin_of_input : input_node->input()) {
1539 if (!is_duplicate_control_input(fanin_of_input)) {
1540 *const_node->add_input() = fanin_of_input;
1541 }
1542 }
1543 }
1544 }
1545
1546 // We rewrite the existing node if it only has a single output, and
1547 // create new nodes otherwise.
1548 if (const_nodes.size() == 1) {
1549 node->set_op("Const");
1550 // Note we need to clear the inputs in NodeMap before we clear the inputs
1551 // in the node, otherwise NodeMap would see empty inputs and effectively
1552 // does nothing.
1553 node_map_->RemoveInputs(node->name());
1554 node->clear_input();
1555 *node->mutable_input() = const_node->input();
1556 for (const auto& input : node->input()) {
1557 node_map_->AddOutput(NodeName(input), node->name());
1558 }
1559 *node->mutable_attr() = const_node->attr();
1560 break;
1561 } else {
1562 if (node_map_->GetNode(const_node->name())) {
1563 // Intended name already exists.
1564 return errors::AlreadyExists(strings::StrCat(
1565 const_node->name(), " already present in the graph"));
1566 }
1567 NodeDef* added_node = output_graph->add_node();
1568 *added_node = *const_node;
1569 added_node->set_device(node->device());
1570 node_map_->AddNode(added_node->name(), added_node);
1571 for (const auto& input : added_node->input()) {
1572 node_map_->AddOutput(NodeName(input), added_node->name());
1573 }
1574 // All the constant nodes encoding output values have the same control
1575 // dependencies (since these are the control dependencies of the node
1576 // we're trying to fold). Record one such constant node.
1577 constant_output = added_node;
1578 }
1579 }
1580
1581 if (const_nodes.size() > 1) {
1582 // We make a copy because we mutate the nodes.
1583 auto outputs = node_map_->GetOutputs(node->name());
1584 for (NodeDef* output : outputs) {
1585 for (int i = 0; i < output->input_size(); i++) {
1586 int port;
1587 string node_name = ParseNodeName(output->input(i), &port);
1588 if (node_name == node->name()) {
1589 if (port < 0) {
1590 // Propagate control dependencies if possible. If not, we'll just
1591 // preserve the existing control dependencies.
1592 if (constant_output != nullptr) {
1593 node_map_->UpdateInput(node_name, NodeName(output->input(i)),
1594 constant_output->name());
1595 *output->mutable_input(i) = AsControlDependency(*constant_output);
1596 }
1597 } else if (port < static_cast<int>(const_nodes.size()) &&
1598 !const_nodes[port].name().empty()) {
1599 // Replace alive outputs with the corresponding constant.
1600 node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
1601 const_nodes[port].name());
1602 *output->mutable_input(i) = const_nodes[port].name();
1603 } else {
1604 // Leave this edge alone.
1605 VLOG(3) << "Preserving edge from " << node->name() << ":" << port
1606 << "[" << node->op() << "] to " << output->name() << ":"
1607 << i << "[" << output->op() << "]";
1608 }
1609 }
1610 }
1611 }
1612 outputs = node_map_->GetOutputs(node->name());
1613 if (outputs.empty() && has_fetch_ &&
1614 nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
1615 node_map_->RemoveInputs(node->name());
1616 node->clear_input();
1617 }
1618 }
1619 return Status::OK();
1620 }
1621
FoldGraph(const GraphProperties & properties,GraphDef * output,absl::flat_hash_set<string> * nodes_to_not_simplify)1622 Status ConstantFolding::FoldGraph(
1623 const GraphProperties& properties, GraphDef* output,
1624 absl::flat_hash_set<string>* nodes_to_not_simplify) {
1625 std::unordered_set<string> processed_nodes;
1626 std::deque<NodeDef*> queue;
1627 for (int i = 0; i < graph_->node_size(); i++) {
1628 bool foldable = IsFoldable(graph_->node(i), &properties);
1629 VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
1630 if (foldable) {
1631 queue.push_back(graph_->mutable_node(i));
1632 }
1633 }
1634 while (!queue.empty()) {
1635 NodeDef* node = queue.front();
1636 queue.pop_front();
1637 if (processed_nodes.count(node->name())) {
1638 continue;
1639 }
1640 // We need to record a copy of output nodes before FoldNode() modifies it.
1641 // We also need to ensure that the fanout is sorted deterministically.
1642 std::vector<NodeDef*> fanout =
1643 node_map_->GetOutputsOrderedByNodeName(node->name());
1644 bool result_too_large = false;
1645 Status s = FoldNode(node, output, &result_too_large);
1646 processed_nodes.insert(node->name());
1647 if (!s.ok()) {
1648 VLOG(1) << "Failed to fold node " << node->DebugString()
1649 << "\nError message: " << s;
1650 if (result_too_large) {
1651 nodes_to_not_simplify->emplace(node->name());
1652 }
1653 } else {
1654 for (auto& output : fanout) {
1655 if (IsFoldable(*output, &properties)) {
1656 queue.push_back(output);
1657 }
1658 }
1659 }
1660 }
1661
1662 // Delete the newly created nodes that don't feed anything.
1663 std::vector<int> nodes_to_delete;
1664 for (int i = 0; i < output->node_size(); i++) {
1665 const auto& fanout = node_map_->GetOutputs(output->node(i).name());
1666 if (fanout.empty()) nodes_to_delete.push_back(i);
1667 }
1668 EraseNodesFromGraph(std::move(nodes_to_delete), output);
1669
1670 for (int i = 0; i < graph_->node_size(); ++i) {
1671 NodeDef* node = graph_->mutable_node(i);
1672 // If no fetch nodes is provided, we conservatively
1673 // move all nodes in the original graph to the output, in case users need
1674 // to fetch their values.
1675 const auto& fanout = node_map_->GetOutputs(node->name());
1676 if (!fanout.empty() || !has_fetch_ ||
1677 nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
1678 *(output->add_node()) = std::move(*node);
1679 }
1680 }
1681 return Status::OK();
1682 }
1683
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1684 bool ConstantFolding::IsSimplifiableReshape(
1685 const NodeDef& node, const GraphProperties& properties) const {
1686 if (!IsReshape(node)) {
1687 return false;
1688 }
1689 CHECK_LE(2, node.input_size());
1690 const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1691 if (!IsReallyConstant(*new_shape)) {
1692 return false;
1693 }
1694 TensorVector outputs;
1695 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1696 for (const auto& output : outputs) {
1697 delete output.tensor;
1698 }
1699 });
1700
1701 Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1702 if (!s.ok()) {
1703 return false;
1704 }
1705 CHECK_EQ(1, outputs.size());
1706
1707 const std::vector<OpInfo::TensorProperties>& props =
1708 properties.GetInputProperties(node.name());
1709 if (props.empty()) {
1710 return false;
1711 }
1712 const OpInfo::TensorProperties& prop = props[0];
1713 if (prop.dtype() == DT_INVALID) {
1714 return false;
1715 }
1716 const PartialTensorShape shape(prop.shape());
1717 if (!shape.IsFullyDefined()) {
1718 return false;
1719 }
1720
1721 PartialTensorShape new_dims;
1722 if (outputs[0]->dtype() == DT_INT32) {
1723 std::vector<int32> shp;
1724 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1725 int32_t dim = outputs[0]->flat<int32>()(i);
1726 shp.push_back(dim);
1727 }
1728 TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1729 } else {
1730 std::vector<int64> shp;
1731 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1732 int64_t dim = outputs[0]->flat<int64>()(i);
1733 shp.push_back(dim);
1734 }
1735 TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1736 }
1737
1738 return shape.IsCompatibleWith(new_dims);
1739 }
1740
// Expands to a `case ENUM_DTYPE:` label that returns the result of
// AllValuesAre<T>() applied to the "value" tensor attribute of the variable
// `node` in the enclosing scope, where T is the C++ type mapped to ENUM_DTYPE
// and the value compared against is T(CONSTANT). Only meaningful inside a
// `switch` over a DataType.
#define IS_VALUE_CASE(ENUM_DTYPE, CONSTANT)                \
  case ENUM_DTYPE:                                         \
    return AllValuesAre<EnumToDataType<ENUM_DTYPE>::Type>( \
        node.attr().at("value").tensor(),                  \
        EnumToDataType<ENUM_DTYPE>::Type(CONSTANT))

// Convenience wrappers comparing against one and zero respectively.
#define IS_ONES_CASE(ENUM_DTYPE) IS_VALUE_CASE(ENUM_DTYPE, 1)
#define IS_ZEROS_CASE(ENUM_DTYPE) IS_VALUE_CASE(ENUM_DTYPE, 0)
1748
IsOnes(const NodeDef & node) const1749 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1750 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1751 return false;
1752 }
1753 if (IsOnesLike(node)) return true;
1754 if (IsZerosLike(node)) return false;
1755 if (node.op() == "Fill") {
1756 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1757 return values != nullptr && IsOnes(*values);
1758 }
1759 if (node.op() != "Const") return false;
1760 if (node.attr().count("dtype") == 0) return false;
1761 const auto dtype = node.attr().at("dtype").type();
1762 switch (dtype) {
1763 IS_ONES_CASE(DT_BOOL);
1764 IS_ONES_CASE(DT_HALF);
1765 IS_ONES_CASE(DT_BFLOAT16);
1766 IS_ONES_CASE(DT_FLOAT);
1767 IS_ONES_CASE(DT_DOUBLE);
1768 IS_ONES_CASE(DT_COMPLEX64);
1769 IS_ONES_CASE(DT_COMPLEX128);
1770 IS_ONES_CASE(DT_UINT8);
1771 IS_ONES_CASE(DT_INT8);
1772 IS_ONES_CASE(DT_UINT16);
1773 IS_ONES_CASE(DT_INT16);
1774 IS_ONES_CASE(DT_INT32);
1775 IS_ONES_CASE(DT_INT64);
1776 IS_ONES_CASE(DT_QINT32);
1777 IS_ONES_CASE(DT_QINT16);
1778 IS_ONES_CASE(DT_QUINT16);
1779 IS_ONES_CASE(DT_QINT8);
1780 IS_ONES_CASE(DT_QUINT8);
1781 default:
1782 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1783 return false;
1784 }
1785 return false;
1786 }
1787
IsZeros(const NodeDef & node) const1788 bool ConstantFolding::IsZeros(const NodeDef& node) const {
1789 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1790 return false;
1791 }
1792 if (IsOnesLike(node)) return false;
1793 if (IsZerosLike(node)) return true;
1794 if (node.op() == "Fill") {
1795 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1796 return values != nullptr && IsZeros(*values);
1797 }
1798 if (!IsConstant(node)) return false;
1799 if (node.attr().count("dtype") == 0) return false;
1800 const auto dtype = node.attr().at("dtype").type();
1801 switch (dtype) {
1802 IS_ZEROS_CASE(DT_BOOL);
1803 IS_ZEROS_CASE(DT_HALF);
1804 IS_ZEROS_CASE(DT_BFLOAT16);
1805 IS_ZEROS_CASE(DT_FLOAT);
1806 IS_ZEROS_CASE(DT_DOUBLE);
1807 IS_ZEROS_CASE(DT_COMPLEX64);
1808 IS_ZEROS_CASE(DT_COMPLEX128);
1809 IS_ZEROS_CASE(DT_UINT8);
1810 IS_ZEROS_CASE(DT_INT8);
1811 IS_ZEROS_CASE(DT_UINT16);
1812 IS_ZEROS_CASE(DT_INT16);
1813 IS_ZEROS_CASE(DT_INT32);
1814 IS_ZEROS_CASE(DT_INT64);
1815 IS_ZEROS_CASE(DT_QINT32);
1816 IS_ZEROS_CASE(DT_QINT16);
1817 IS_ZEROS_CASE(DT_QUINT16);
1818 IS_ZEROS_CASE(DT_QINT8);
1819 IS_ZEROS_CASE(DT_QUINT8);
1820 default:
1821 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1822 return false;
1823 }
1824 return false;
1825 }
1826
ReplaceOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1827 bool ConstantFolding::ReplaceOperationWithBroadcastTo(
1828 int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1829 GraphDef* graph) {
1830 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1831 if (dtype == DT_INVALID) {
1832 return false;
1833 }
1834 const PartialTensorShape shape(
1835 properties.GetOutputProperties(node->name())[0].shape());
1836 if (!shape.IsFullyDefined()) {
1837 return false;
1838 }
1839 // Create constant node with shape.
1840 const string const_name = OptimizedNodeName(
1841 *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
1842 if (node_map_->GetNode(const_name) != nullptr) {
1843 return false;
1844 }
1845
1846 Tensor shape_t;
1847 if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
1848 return false;
1849 }
1850 NodeDef tmp;
1851 if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
1852 return false;
1853 }
1854 NodeDef* const_node = graph->add_node();
1855 const_node->Swap(&tmp);
1856 const_node->set_device(node->device());
1857 node_map_->AddNode(const_name, const_node);
1858 for (int i = 0; i < node->input_size(); ++i) {
1859 if (i != input_to_broadcast) {
1860 // Add a control input on the unused input.
1861 string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
1862 node_map_.get());
1863 *const_node->add_input() = ctrl_dep;
1864 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
1865 }
1866 }
1867
1868 // Rewrite `node` in-place to BroadcastTo.
1869 node->set_op("BroadcastTo");
1870 EraseRegularNodeAttributes(node);
1871 (*node->mutable_attr())["T"].set_type(dtype);
1872 (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
1873 // Set the designated input to BroadcastTo.
1874 node->mutable_input()->SwapElements(0, input_to_broadcast);
1875 // Keep all other inputs as control dependencies.
1876 for (int i = 1; i < node->input_size(); ++i) {
1877 if (IsControlInput(node->input(i))) {
1878 break;
1879 }
1880 const string ctrl_dep =
1881 AddControlDependency(node->input(i), graph, node_map_.get());
1882 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1883 node->set_input(i, ctrl_dep);
1884 }
1885 // Add the shape argument.
1886 *node->add_input() = const_node->name();
1887 node_map_->AddOutput(const_name, node->name());
1888 node->mutable_input()->SwapElements(1, node->input_size() - 1);
1889 return true;
1890 }
1891
1892 // Replace an operation with Identity.
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1893 void ConstantFolding::ReplaceOperationWithIdentity(
1894 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1895 GraphDef* graph) {
1896 if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
1897 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1898 if (dtype == DT_INVALID) return;
1899
1900 node->set_op("Identity");
1901 EraseRegularNodeAttributes(node);
1902 (*node->mutable_attr())["T"].set_type(dtype);
1903 // Propagate the designated input through the identity.
1904 node->mutable_input()->SwapElements(0, input_to_forward);
1905 // Add all other inputs as control dependencies.
1906 for (int i = 1; i < node->input_size(); ++i) {
1907 if (IsControlInput(node->input(i))) {
1908 break;
1909 }
1910 const string ctrl_dep =
1911 AddControlDependency(node->input(i), graph, node_map_.get());
1912 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1913 node->set_input(i, ctrl_dep);
1914 }
1915 graph_modified_ = true;
1916 }
1917
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1918 void ConstantFolding::ReplaceOperationWithSnapshot(
1919 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1920 GraphDef* graph) {
1921 // If the graph contains no ops that mutate their inputs, we can
1922 // use Identity instead of Snapshot.
1923 if (!graph_contains_assign_or_inplace_op_) {
1924 ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1925 return;
1926 }
1927
1928 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1929 if (dtype == DT_INVALID) return;
1930
1931 node->set_op("Snapshot");
1932 EraseRegularNodeAttributes(node);
1933 (*node->mutable_attr())["T"].set_type(dtype);
1934 // Propagate the designated input through the Snapshot.
1935 node->mutable_input()->SwapElements(0, input_to_forward);
1936 // Add all other inputs as control dependencies.
1937 for (int i = 1; i < node->input_size(); ++i) {
1938 if (IsControlInput(node->input(i))) {
1939 break;
1940 }
1941 const string ctrl_dep =
1942 AddControlDependency(node->input(i), graph, node_map_.get());
1943 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1944 node->set_input(i, ctrl_dep);
1945 }
1946 graph_modified_ = true;
1947 }
1948
1949 // Replace a node with NoOp. Change all inputs to control dependencies.
1950 // If the node has non-control outputs, no change will be performed.
ReplaceOperationWithNoOp(NodeDef * node,GraphProperties * properties,GraphDef * graph)1951 void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
1952 GraphProperties* properties,
1953 GraphDef* graph) {
1954 if (HasRegularOutputs(*node, *node_map_)) return;
1955 node->set_op("NoOp");
1956 EraseRegularNodeAttributes(node);
1957 EraseNodeOutputAttributes(node);
1958 // Erase attributes that describe output properties.
1959 properties->ClearOutputProperties(node->name());
1960 // Change all inputs to control dependencies.
1961 for (int i = 0; i < node->input_size(); ++i) {
1962 if (IsControlInput(node->input(i))) {
1963 break;
1964 }
1965 const string ctrl_dep =
1966 AddControlDependency(node->input(i), graph, node_map_.get());
1967 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1968 node->set_input(i, ctrl_dep);
1969 }
1970 DedupControlInputs(node);
1971 graph_modified_ = true;
1972 }
1973
ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1974 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
1975 int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1976 GraphDef* graph) {
1977 if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
1978 graph)) {
1979 return;
1980 }
1981 graph_modified_ = true;
1982 }
1983
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)1984 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
1985 GraphDef* graph) {
1986 node->set_op("Reciprocal");
1987 node->mutable_input()->SwapElements(0, 1);
1988 const string ctrl_dep =
1989 AddControlDependency(node->input(1), graph, node_map_.get());
1990 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1991 node->set_input(1, ctrl_dep);
1992 graph_modified_ = true;
1993 }
1994
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)1995 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
1996 GraphDef* graph) {
1997 node->set_op("Neg");
1998 node->mutable_input()->SwapElements(0, 1);
1999 const string ctrl_dep =
2000 AddControlDependency(node->input(1), graph, node_map_.get());
2001 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
2002 node->set_input(1, ctrl_dep);
2003 graph_modified_ = true;
2004 }
2005
ReplaceOperationWithConstantTensor(DataType dtype,TensorProto * value,NodeDef * node,GraphDef * graph)2006 Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
2007 TensorProto* value,
2008 NodeDef* node,
2009 GraphDef* graph) {
2010 if (dtype == DT_VARIANT) return Status::OK();
2011 node->set_op("Const");
2012 EraseRegularNodeAttributes(node);
2013 (*node->mutable_attr())["dtype"].set_type(dtype);
2014 (*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
2015 // Convert all inputs to control dependencies.
2016 for (int i = 0; i < node->input_size(); ++i) {
2017 if (IsControlInput(node->input(i))) {
2018 break;
2019 }
2020 const string ctrl_dep =
2021 AddControlDependency(node->input(i), graph, node_map_.get());
2022 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2023 node->set_input(i, ctrl_dep);
2024 }
2025 DedupControlInputs(node);
2026 graph_modified_ = true;
2027 return Status::OK();
2028 }
2029
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph)2030 Status ConstantFolding::ReplaceOperationWithConstant(
2031 double value, const GraphProperties& properties,
2032 const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
2033 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
2034 if (dtype == DT_VARIANT) return Status::OK();
2035 AttrValue tensor_attr;
2036 Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
2037 if (!s.ok()) {
2038 // Fail gracefully without mutating the graph.
2039 VLOG(1) << "Failed to replace node " << node->name() << " of type "
2040 << DataTypeString(dtype) << " with constant tensor of value "
2041 << value;
2042 return Status::OK();
2043 }
2044 return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
2045 node, graph);
2046 }
2047
SimplifyGraph(bool use_shape_info,GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)2048 Status ConstantFolding::SimplifyGraph(
2049 bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties,
2050 absl::flat_hash_set<string>* nodes_to_not_simplify) {
2051 for (int i = 0; i < optimized_graph->node_size(); ++i) {
2052 NodeDef* node = optimized_graph->mutable_node(i);
2053 // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
2054 // generalize to only restrict certain simplifications.
2055 if (nodes_to_not_simplify->find(node->name()) ==
2056 nodes_to_not_simplify->end()) {
2057 if (HasTPUAttributes(*node)) {
2058 nodes_to_not_simplify->insert(node->name());
2059 continue;
2060 }
2061
2062 TF_RETURN_IF_ERROR(
2063 SimplifyNode(use_shape_info, node, optimized_graph, properties));
2064 }
2065 }
2066 return Status::OK();
2067 }
2068
// Early-exit helpers used by SimplifyNode(). Each one expands to two
// statements, so it must be used as a complete statement inside a braced
// block, never as the unbraced body of an `if` or loop.

// Evaluates CALL, propagating a non-OK Status; otherwise returns OK as soon
// as `graph_modified_` is set.
#define RETURN_IF_ERROR_OR_MODIFIED(CALL) \
  TF_RETURN_IF_ERROR(CALL);               \
  if (graph_modified_) return Status::OK()

// Stores the result of CALL in `graph_modified_` and returns OK if it is
// true.
#define SET_AND_RETURN_IF_MODIFIED(CALL) \
  graph_modified_ = CALL;                \
  if (graph_modified_) return Status::OK()

// Evaluates CALL for its side effects and returns OK if `graph_modified_` is
// set afterwards.
#define RETURN_IF_MODIFIED(CALL) \
  CALL;                          \
  if (graph_modified_) return Status::OK()
2080
SimplifyNode(bool use_shape_info,NodeDef * node,GraphDef * optimized_graph,GraphProperties * properties)2081 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
2082 GraphDef* optimized_graph,
2083 GraphProperties* properties) {
2084 bool graph_modified_cached = graph_modified_;
2085 graph_modified_ = false;
2086
2087 RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
2088 RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
2089 *properties, use_shape_info, optimized_graph, node));
2090 RETURN_IF_MODIFIED(
2091 RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
2092 RETURN_IF_ERROR_OR_MODIFIED(
2093 RemoveReverse(*properties, use_shape_info, optimized_graph, node));
2094 RETURN_IF_ERROR_OR_MODIFIED(
2095 SimplifySlice(*properties, use_shape_info, optimized_graph, node));
2096 RETURN_IF_ERROR_OR_MODIFIED(
2097 SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
2098 RETURN_IF_ERROR_OR_MODIFIED(
2099 SimplifyTile(*properties, use_shape_info, optimized_graph, node));
2100 RETURN_IF_ERROR_OR_MODIFIED(
2101 SimplifyPad(*properties, use_shape_info, optimized_graph, node));
2102 RETURN_IF_MODIFIED(
2103 SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
2104 SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
2105 SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
2106 SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
2107 SET_AND_RETURN_IF_MODIFIED(
2108 SimplifyReduction(optimized_graph, *properties, node));
2109 SET_AND_RETURN_IF_MODIFIED(
2110 SimplifyReshape(*properties, use_shape_info, node));
2111 RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
2112 *properties, use_shape_info, optimized_graph, node));
2113 SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
2114 SET_AND_RETURN_IF_MODIFIED(
2115 ConstantPushDown(properties, optimized_graph, node));
2116 SET_AND_RETURN_IF_MODIFIED(
2117 MulConvPushDown(optimized_graph, node, *properties));
2118 SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
2119 SET_AND_RETURN_IF_MODIFIED(
2120 PartialAssocOpConstFolding(optimized_graph, properties, node));
2121 SET_AND_RETURN_IF_MODIFIED(
2122 MergeConcat(use_shape_info, properties, optimized_graph, node));
2123 SET_AND_RETURN_IF_MODIFIED(
2124 PartialConcatConstFolding(optimized_graph, properties, node));
2125 SET_AND_RETURN_IF_MODIFIED(
2126 ConstantPushDownBiasAdd(properties, optimized_graph, node));
2127 SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
2128 SET_AND_RETURN_IF_MODIFIED(
2129 SimplifySelect(*properties, optimized_graph, node));
2130 RETURN_IF_MODIFIED(
2131 RemoveRedundantVariableUpdates(properties, optimized_graph, node));
2132
2133 graph_modified_ = graph_modified_cached;
2134 return Status::OK();
2135 }
2136
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2137 void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
2138 GraphDef* optimized_graph,
2139 NodeDef* node) {
2140 if (node->attr().count("num_split") == 0) return;
2141 if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
2142 ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
2143 }
2144 if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
2145 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2146 }
2147 }
2148
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2149 Status ConstantFolding::RemoveShuffleOrTranspose(
2150 const GraphProperties& properties, bool use_shape_info,
2151 GraphDef* optimized_graph, NodeDef* node) {
2152 if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
2153 return Status::OK();
2154 Tensor permutation_tensor;
2155 if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
2156 properties.HasInputProperties(node->name())) {
2157 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2158 std::vector<int> permutation;
2159 for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
2160 if (permutation_tensor.dtype() == DT_INT64) {
2161 permutation.push_back(permutation_tensor.vec<int64>()(j));
2162 } else {
2163 permutation.push_back(permutation_tensor.vec<int>()(j));
2164 }
2165 }
2166 int permutation_size = permutation.size();
2167 if (permutation_size != shape.dim_size()) {
2168 // Number of elements in perm should be same as dim_size. Skip if not.
2169 return Status::OK();
2170 }
2171 // The node is replaceable iff
2172 // dim_size == 0 || all dims have size 1 ||
2173 // all dims with > 1 size are not permuted.
2174 bool replaceable = true;
2175 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2176 replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
2177 }
2178 if (replaceable) {
2179 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2180 }
2181 }
2182 return Status::OK();
2183 }
2184
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2185 void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
2186 bool use_shape_info,
2187 GraphDef* optimized_graph,
2188 NodeDef* node) {
2189 if (use_shape_info && IsRandomShuffle(*node) &&
2190 !properties.GetInputProperties(node->name()).empty()) {
2191 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2192 // The node is replaceable iff
2193 // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
2194 if (!shape.unknown_rank() &&
2195 (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
2196 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2197 }
2198 }
2199 }
2200
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2201 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
2202 bool use_shape_info,
2203 GraphDef* optimized_graph,
2204 NodeDef* node) {
2205 if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
2206 Tensor axis;
2207 if (properties.HasInputProperties(node->name()) &&
2208 GetTensorFromConstNode(node->input(1), &axis)) {
2209 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2210 if (shape.unknown_rank()) return Status::OK();
2211 std::set<int> target_axes;
2212 for (int j = 0; j < axis.NumElements(); ++j) {
2213 // value of axis can be negative.
2214 if (axis.dtype() == DT_INT64) {
2215 target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
2216 shape.dim_size());
2217 } else {
2218 target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2219 shape.dim_size());
2220 }
2221 }
2222
2223 // The node is replaceable iff
2224 // unknown_rank == false &&
2225 // (dim_size == 0 || all dims have size 1 ||
2226 // all dims with > 1 size are not in target_axes)
2227 bool replaceable = true;
2228 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2229 replaceable &=
2230 shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
2231 }
2232 if (replaceable) {
2233 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2234 }
2235 }
2236 return Status::OK();
2237 }
2238
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2239 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2240 bool use_shape_info,
2241 GraphDef* optimized_graph,
2242 NodeDef* node) {
2243 if (!use_shape_info || !IsSlice(*node)) return Status::OK();
2244 Tensor begin;
2245 Tensor size;
2246 if (properties.HasInputProperties(node->name()) &&
2247 GetTensorFromConstNode(node->input(1), &begin) &&
2248 GetTensorFromConstNode(node->input(2), &size)) {
2249 const auto& input = properties.GetInputProperties(node->name())[0];
2250 // The node is replaceable iff unknown_rank == false &&
2251 // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2252 bool replaceable = !input.shape().unknown_rank();
2253 for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2254 if (begin.dtype() == DT_INT32) {
2255 replaceable &= begin.vec<int>()(j) == 0;
2256 } else {
2257 replaceable &= begin.vec<int64>()(j) == 0;
2258 }
2259 if (size.dtype() == DT_INT32) {
2260 replaceable &= (size.vec<int>()(j) == -1 ||
2261 size.vec<int>()(j) == input.shape().dim(j).size());
2262 } else {
2263 replaceable &= (size.vec<int64>()(j) == -1 ||
2264 size.vec<int64>()(j) == input.shape().dim(j).size());
2265 }
2266 }
2267 if (replaceable) {
2268 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2269 }
2270 }
2271 return Status::OK();
2272 }
2273
SimplifyStridedSlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2274 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
2275 bool use_shape_info,
2276 GraphDef* optimized_graph,
2277 NodeDef* node) {
2278 if (use_shape_info && IsStridedSlice(*node) &&
2279 properties.GetInputProperties(node->name()).size() == 4) {
2280 TF_RETURN_IF_ERROR(
2281 CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
2282 if (node->attr().at("new_axis_mask").i() != 0 ||
2283 node->attr().at("shrink_axis_mask").i() != 0) {
2284 // Skip nodes with new/shrink axis mask, since they involve dimension
2285 // changes.
2286 return Status::OK();
2287 }
2288 const auto& input = properties.GetInputProperties(node->name())[0];
2289 for (int j = 0; j < input.shape().dim_size(); ++j) {
2290 // Skip if input shape is not fully determined.
2291 if (input.shape().dim(j).size() < 0) {
2292 return Status::OK();
2293 }
2294 }
2295
2296 std::vector<Tensor> input_tensors(3);
2297 for (int i = 1; i < 4; ++i) {
2298 if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
2299 return Status::OK();
2300 }
2301 }
2302
2303 const Tensor& begin = input_tensors[0];
2304 const Tensor& end = input_tensors[1];
2305 const Tensor& strides = input_tensors[2];
2306
2307 TF_RETURN_IF_ERROR(
2308 CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
2309 int begin_mask = node->attr().at("begin_mask").i();
2310 int end_mask = node->attr().at("end_mask").i();
2311 std::set<int> expanded_ellipsis_indices;
2312 int ellipsis_index = -1;
2313 for (int j = 0; j < input.shape().dim_size(); ++j) {
2314 // find the ellipsis_mask. If not found, insert one in the end if
2315 // necessary.
2316 if (node->attr().at("ellipsis_mask").i() & 1 << j ||
2317 (ellipsis_index == -1 && j >= strides.NumElements())) {
2318 ellipsis_index = j;
2319 }
2320 // insert the indices that are immediately after ellipsis_index if
2321 // necessary.
2322 if (ellipsis_index != -1 &&
2323 input.shape().dim_size() >
2324 strides.NumElements() + j - ellipsis_index) {
2325 expanded_ellipsis_indices.insert(j);
2326 }
2327 }
2328
2329 // The node is replaceable iff unknown_rank == false &&
2330 // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
2331 // && strides == 1) for all dimensions.
2332 bool replaceable = !input.shape().unknown_rank();
2333 for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2334 if (expanded_ellipsis_indices.find(j) !=
2335 expanded_ellipsis_indices.end()) {
2336 // ellipsis_mask is effective on current dimension.
2337 continue;
2338 }
2339 // when we have ellipsis_mask in between, input.shape().dim_size() will
2340 // be greater than strides.NumElements(), since we will insert
2341 // as many as expanded_ellipsis_indices.size() axes during computation.
2342 // We need to subtract this number from j.
2343 int i = j;
2344 int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
2345 if (ellipsis_index != -1 &&
2346 j >= ellipsis_index + expanded_ellipsis_indices_size) {
2347 i = j - expanded_ellipsis_indices_size;
2348 }
2349 int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
2350 : begin.vec<int64>()(i);
2351 int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
2352 int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
2353 : strides.vec<int64>()(i);
2354 replaceable &= (begin_mask & 1 << i || b == 0) &&
2355 (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
2356 s == 1;
2357 }
2358 if (replaceable) {
2359 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2360 }
2361 }
2362 return Status::OK();
2363 }
2364
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2365 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2366 bool use_shape_info,
2367 GraphDef* optimized_graph, NodeDef* node) {
2368 Tensor multiplies;
2369 if (use_shape_info && IsTile(*node) &&
2370 GetTensorFromConstNode(node->input(1), &multiplies)) {
2371 // The node is replaceable iff all values in multiplies are 1.
2372 bool replaceable = true;
2373 if (multiplies.dtype() == DT_INT32) {
2374 for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2375 replaceable &= multiplies.vec<int>()(j) == 1;
2376 }
2377 } else {
2378 for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
2379 replaceable &= multiplies.vec<int64>()(j) == 1;
2380 }
2381 }
2382 if (replaceable) {
2383 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2384 }
2385 }
2386 return Status::OK();
2387 }
2388
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2389 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2390 bool use_shape_info,
2391 GraphDef* optimized_graph, NodeDef* node) {
2392 if (!use_shape_info || !IsPad(*node)) return Status::OK();
2393
2394 Tensor paddings;
2395 if (GetTensorFromConstNode(node->input(1), &paddings)) {
2396 // The node is replaceable iff all values in paddings are 0.
2397 bool replaceable = true;
2398 if (paddings.dtype() == DT_INT32) {
2399 const auto flatten = paddings.flat<int32>();
2400 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2401 replaceable &= flatten(j) == 0;
2402 }
2403 } else {
2404 const auto flatten = paddings.flat<int64>();
2405 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2406 replaceable &= flatten(j) == 0;
2407 }
2408 }
2409 if (replaceable) {
2410 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2411 }
2412 }
2413 return Status::OK();
2414 }
2415
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2416 void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2417 bool use_shape_info,
2418 GraphDef* optimized_graph,
2419 NodeDef* node) {
2420 if (use_shape_info && IsSqueeze(*node) &&
2421 !properties.GetInputProperties(node->name()).empty()) {
2422 // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2423 // error to squeeze a dimension that is not 1, so we only need to check
2424 // whether the input has > 1 size for each dimension.
2425 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2426 // The node is replaceable iff
2427 // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2428 bool replaceable = !shape.unknown_rank();
2429 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2430 replaceable &= shape.dim(j).size() > 1;
2431 }
2432 if (replaceable) {
2433 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2434 }
2435 }
2436 }
2437
SimplifyPack(GraphDef * optimized_graph,NodeDef * node)2438 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
2439 const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
2440 if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
2441 node_map_->NodeExists(axis_node_name)) {
2442 return false;
2443 }
2444
2445 // It's unsafe to add a control dependency on the feed node, because it might
2446 // have been never executed otherwiwise.
2447 if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
2448 return false;
2449 }
2450
2451 // Create constant axis node.
2452 Tensor axis_t(DT_INT32, TensorShape({}));
2453 const int axis =
2454 node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
2455 NodeDef new_node;
2456 if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
2457 !CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
2458 return false;
2459 }
2460 NodeDef* axis_node = optimized_graph->add_node();
2461 *axis_node = std::move(new_node);
2462 axis_node->set_name(axis_node_name);
2463 node_map_->AddNode(axis_node->name(), axis_node);
2464 // Add a control dependency to make sure axis_node is in the right frame.
2465 const string ctrl_dep = ConstantFolding::AddControlDependency(
2466 node->input(0), optimized_graph, node_map_.get());
2467 axis_node->add_input(ctrl_dep);
2468 axis_node->set_device(node->device());
2469 node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
2470 node->set_op("ExpandDims");
2471 if (node->attr().count("axis") != 0) {
2472 node->mutable_attr()->erase("axis");
2473 }
2474 if (node->attr().count("N") != 0) {
2475 node->mutable_attr()->erase("N");
2476 }
2477 (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
2478 node->add_input(axis_node->name());
2479 node_map_->AddOutput(axis_node->name(), node->name());
2480 if (node->input_size() > 2) {
2481 node->mutable_input()->SwapElements(1, node->input_size() - 1);
2482 }
2483 return true;
2484 }
2485
SimplifyCase(GraphDef * optimized_graph,NodeDef * node)2486 bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
2487 if (node->op() != "Case") return false;
2488 const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
2489 if (output_idx_node == nullptr ||
2490 !CheckAttrExists(*output_idx_node, "value").ok()) {
2491 return false;
2492 }
2493 Tensor output_idx_t;
2494 if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
2495 return false;
2496 int output_idx = output_idx_t.scalar<int>()();
2497 const auto& func_list = node->attr().at("branches").list();
2498 if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
2499 NodeDef call_node = *node;
2500 call_node.set_op("PartitionedCall");
2501 call_node.clear_input();
2502 for (int i = 1; i < node->input_size(); ++i) {
2503 call_node.add_input(node->input(i));
2504 }
2505 auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
2506 *new_func = func_list.func(output_idx);
2507
2508 // Move the output shape of the branch to _output_shapes if it is known.
2509 const auto& output_shape_list =
2510 (*node->mutable_attr())["output_shapes"].list();
2511 if (output_shape_list.shape_size() > output_idx) {
2512 TensorShapeProto* new_output_shape =
2513 (*call_node.mutable_attr())["_output_shapes"]
2514 .mutable_list()
2515 ->add_shape();
2516 *new_output_shape =
2517 std::move(node->attr().at("output_shapes").list().shape(output_idx));
2518 }
2519
2520 call_node.mutable_attr()->erase("output_shapes");
2521 call_node.mutable_attr()->erase("branches");
2522
2523 *node = std::move(call_node);
2524 return true;
2525 }
2526
SimplifySelect(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2527 bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
2528 GraphDef* optimized_graph, NodeDef* node) {
2529 if (!IsSelect(*node)) return false;
2530 const std::vector<OpInfo::TensorProperties>& input_props =
2531 properties.GetInputProperties(node->name());
2532 if (input_props.size() < 3) return false;
2533 const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
2534 const bool is_all_true = IsOnes(*predicate_node);
2535 const bool is_all_false = IsZeros(*predicate_node);
2536 if (!is_all_true && !is_all_false) {
2537 return false;
2538 }
2539 const int live_input_idx = is_all_true ? 1 : 2;
2540 const int ignored_input_idx = is_all_true ? 2 : 1;
2541 const TensorShapeProto& predicate_shape = input_props[0].shape();
2542 const bool predicate_is_scalar =
2543 !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
2544 if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
2545 (ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
2546 predicate_is_scalar)) {
2547 // Replace node with Identity if no broadcasting is involved.
2548 node->set_op("Identity");
2549 *node->mutable_input(0) =
2550 AddControlDependency(node->input(0), optimized_graph, node_map_.get());
2551 *node->mutable_input(ignored_input_idx) = AddControlDependency(
2552 node->input(ignored_input_idx), optimized_graph, node_map_.get());
2553 node->mutable_input()->SwapElements(0, live_input_idx);
2554 } else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
2555 optimized_graph)) {
2556 return false;
2557 }
2558 DedupControlInputs(node);
2559 return true;
2560 }
2561
RemoveRedundantVariableUpdates(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)2562 void ConstantFolding::RemoveRedundantVariableUpdates(
2563 GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
2564 static const absl::flat_hash_set<string>* kVariableReadOps =
2565 new absl::flat_hash_set<string>{"AssignAddVariableOp",
2566 "AssignSubVariableOp",
2567 "AssignAdd",
2568 "AssignSub",
2569 "ScatterAdd",
2570 "ScatterSub",
2571 "ScatterMul",
2572 "ScatterDiv",
2573 "ScatterNdAdd",
2574 "ScatterNdSub",
2575 "ScatterNdMul",
2576 "ScatterNdDiv",
2577 "ResourceScatterAdd",
2578 "ResourceScatterSub",
2579 "ResourceScatterMul",
2580 "ResourceScatterDiv",
2581 "ResourceScatterNdAdd",
2582 "ResourceScatterNdSub",
2583 "ResourceScatterNdMul",
2584 "ResourceScatterNdDiv"};
2585 if (kVariableReadOps == nullptr ||
2586 kVariableReadOps->find(node->op()) == kVariableReadOps->end())
2587 return;
2588 const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
2589 const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
2590 if (delta_node == nullptr) return;
2591 const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
2592 absl::StrContains(node->op(), "Sub");
2593 if ((is_add_or_sub && IsZeros(*delta_node)) ||
2594 (!is_add_or_sub && IsOnes(*delta_node))) {
2595 VLOG(1) << "Removing redundant variable update: " << node->DebugString();
2596 if (absl::StrContains(node->op(), "Variable") ||
2597 absl::StrContains(node->op(), "Resource")) {
2598 ReplaceOperationWithNoOp(node, properties, optimized_graph);
2599 } else {
2600 ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
2601 optimized_graph);
2602 }
2603 }
2604 }
2605
MoveConstantsPastEnter(GraphDef * optimized_graph,NodeDef * node)2606 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
2607 NodeDef* node) {
2608 if (!IsEnter(*node) || node->input_size() == 0 ||
2609 node->attr().count("is_constant") == 0 ||
2610 !node->attr().at("is_constant").b()) {
2611 return false;
2612 }
2613 const string& node_name = node->name();
2614 const NodeDef* input = node_map_->GetNode(node->input(0));
2615 if (input == nullptr || !IsReallyConstant(*input) ||
2616 OptimizedNodeExists(*input, "_enter")) {
2617 return false;
2618 }
2619 // Find non-constant nodes that consume the output of *node.
2620 std::vector<NodeDef*> consumers;
2621 for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
2622 if (!IsConstant(*fanout)) {
2623 for (int i = 0; i < fanout->input_size(); ++i) {
2624 if (fanout->input(i) == node_name) {
2625 consumers.push_back(const_cast<NodeDef*>(fanout));
2626 break;
2627 }
2628 }
2629 }
2630 }
2631 if (consumers.empty()) {
2632 return false;
2633 }
2634 graph_modified_ = true;
2635 NodeDef* new_node = optimized_graph->add_node();
2636 *new_node = *input;
2637 new_node->set_name(OptimizedNodeName(*input, "_enter"));
2638 new_node->set_device(node->device());
2639 new_node->clear_input();
2640 new_node->add_input(AsControlDependency(node_name));
2641 node_map_->AddNode(new_node->name(), new_node);
2642 node_map_->AddOutput(node_name, new_node->name());
2643 for (NodeDef* consumer : consumers) {
2644 for (int i = 0; i < consumer->input_size(); ++i) {
2645 if (NodeName(consumer->input(i)) == node_name) {
2646 node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
2647 consumer->set_input(i, new_node->name());
2648 }
2649 }
2650 }
2651 return true;
2652 }
2653
SimplifySwitch(GraphDef * optimized_graph,NodeDef * node)2654 bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
2655 if (node->op() == "Switch" && node->input(0) == node->input(1) &&
2656 !OptimizedNodeExists(*node, "_const_false") &&
2657 !OptimizedNodeExists(*node, "_const_true")) {
2658 bool already_optimized = true;
2659 // If the optimization was already applied, the switch would have exactly
2660 // one Identity node consuming each of its outputs, each without any
2661 // non-control outputs.
2662 const auto& fanouts = node_map_->GetOutputs(node->name());
2663 if (fanouts.size() == 2) {
2664 for (const NodeDef* fanout : fanouts) {
2665 if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
2666 HasRegularOutputs(*fanout, *node_map_)) {
2667 already_optimized = false;
2668 break;
2669 }
2670 }
2671 }
2672 Tensor false_t(DT_BOOL, TensorShape({}));
2673 Tensor true_t(DT_BOOL, TensorShape({}));
2674 // Make sure we don't proceed if this switch node was already optimized.
2675 if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
2676 SetTensorValue(DT_BOOL, false, &false_t).ok()) {
2677 // Copy the set of consumers of the switch as they will be manipulated
2678 // below.
2679 std::vector<NodeDef*> consumers =
2680 node_map_->GetOutputsOrderedByNodeName(node->name());
2681 // Create constant false & true nodes.
2682 NodeDef tmp_false_node;
2683 tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
2684 if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
2685 &tmp_false_node)
2686 .ok()) {
2687 return false;
2688 }
2689 tmp_false_node.set_device(node->device());
2690 NodeDef tmp_true_node;
2691 tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
2692 if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
2693 &tmp_true_node)
2694 .ok()) {
2695 return false;
2696 }
2697 tmp_true_node.set_device(node->device());
2698
2699 // Add const nodes to graph.
2700 NodeDef* false_node = optimized_graph->add_node();
2701 false_node->Swap(&tmp_false_node);
2702 NodeDef* true_node = optimized_graph->add_node();
2703 true_node->Swap(&tmp_true_node);
2704
2705 // Add controls from the switch ports to the constants, and connect the
2706 // constants to the original switch outputs.
2707 const string false_port = node->name();
2708 const string true_port = strings::StrCat(node->name(), ":1");
2709 const string false_ctrl_dep =
2710 AddControlDependency(false_port, optimized_graph, node_map_.get());
2711 false_node->add_input(false_ctrl_dep);
2712 const string true_ctrl_dep =
2713 AddControlDependency(true_port, optimized_graph, node_map_.get());
2714 true_node->add_input(true_ctrl_dep);
2715
2716 node_map_->AddNode(false_node->name(), false_node);
2717 node_map_->AddNode(true_node->name(), true_node);
2718 node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
2719 node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
2720
2721 for (NodeDef* consumer : consumers) {
2722 for (int i = 0; i < consumer->input_size(); ++i) {
2723 const string& input = consumer->input(i);
2724 if (input == false_port) {
2725 consumer->set_input(i, false_node->name());
2726 node_map_->UpdateInput(consumer->name(), false_port,
2727 false_node->name());
2728 } else if (input == true_port) {
2729 consumer->set_input(i, true_node->name());
2730 node_map_->UpdateInput(consumer->name(), true_port,
2731 true_node->name());
2732 }
2733 }
2734 }
2735 return true;
2736 }
2737 }
2738 return false;
2739 }
2740
IsReductionWithConstantIndices(const NodeDef & node,bool * indices_is_empty) const2741 bool ConstantFolding::IsReductionWithConstantIndices(
2742 const NodeDef& node, bool* indices_is_empty) const {
2743 // Ensure its an appropriate Reduce node.
2744 if (!IsReduction(node) || node.input_size() < 2) {
2745 return false;
2746 }
2747 // Ensure that the axes to reduce by are constant.
2748 NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2749 if (!IsReallyConstant(*reductions_indices) ||
2750 !reductions_indices->attr().count("value")) {
2751 return false;
2752 }
2753 const TensorShapeProto& reduction_indices_shape =
2754 reductions_indices->attr().at("value").tensor().tensor_shape();
2755 *indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
2756 return true;
2757 }
2758
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2759 bool ConstantFolding::IsReductionCandidateForSimplification(
2760 const NodeDef& node, const GraphProperties& properties,
2761 TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2762 bool* is_single_element_op) const {
2763 // Get the properties of the input & output tensors and check if they both
2764 // contain a single element.
2765 if (!properties.HasInputProperties(node.name()) ||
2766 !properties.HasOutputProperties(node.name())) {
2767 return false;
2768 }
2769 const auto& input_props = properties.GetInputProperties(node.name())[0];
2770 const auto& output_props = properties.GetOutputProperties(node.name())[0];
2771 if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2772 !output_props.has_shape() || output_props.shape().unknown_rank()) {
2773 return false;
2774 }
2775 *input_tensor_shape = input_props.shape();
2776 *output_tensor_shape = output_props.shape();
2777 for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2778 if (input_tensor_shape->dim(i).size() < 0) {
2779 return false;
2780 }
2781 }
2782 for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2783 if (output_tensor_shape->dim(i).size() < 0) {
2784 return false;
2785 }
2786 }
2787 const int input_num_elements =
2788 TensorShape(*input_tensor_shape).num_elements();
2789 const int output_num_elements =
2790 TensorShape(*output_tensor_shape).num_elements();
2791 *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2792
2793 return true;
2794 }
2795
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2796 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2797 const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2798 const TensorVector& reduction_indices_vector) const {
2799 int output_size = reduction_indices_vector[0]->NumElements();
2800 if (output_size == 0) {
2801 return true;
2802 }
2803
2804 if (!keep_dims) {
2805 return false;
2806 }
2807 bool simplifiable = true;
2808 for (int i = 0; i < output_size; ++i) {
2809 int64_t dim;
2810 if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2811 dim = reduction_indices_vector[0]->flat<int32>()(i);
2812 } else {
2813 dim = reduction_indices_vector[0]->flat<int64>()(i);
2814 }
2815 if (dim < 0) {
2816 dim += input_shape.dim_size();
2817 }
2818 if (dim < 0 || dim >= input_shape.dim_size() ||
2819 input_shape.dim(dim).size() != 1) {
2820 simplifiable = false;
2821 break;
2822 }
2823 }
2824 return simplifiable;
2825 }
2826
ReplaceReductionWithIdentity(NodeDef * node) const2827 bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
2828 // Replace the reduction node with an identity node, that can be further
2829 // optimized by other passes.
2830 DataType output_type;
2831 if (node->attr().count("T") != 0) {
2832 output_type = node->attr().at("T").type();
2833 } else if (IsAny(*node) || IsAll(*node)) {
2834 output_type = DT_BOOL;
2835 } else {
2836 return false;
2837 }
2838 node->set_op("Identity");
2839 EraseRegularNodeAttributes(node);
2840 (*node->mutable_attr())["T"].set_type(output_type);
2841 *node->mutable_input(1) = AsControlDependency(node->input(1));
2842 return true;
2843 }
2844
SimplifyReduction(GraphDef * optimized_graph,const GraphProperties & properties,NodeDef * node)2845 bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
2846 const GraphProperties& properties,
2847 NodeDef* node) {
2848 bool indices_is_empty = false;
2849 if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
2850 return false;
2851 }
2852 if (indices_is_empty) {
2853 return ReplaceReductionWithIdentity(node);
2854 }
2855 bool is_single_element_op = false;
2856 TensorShapeProto input_tensor_shape, output_tensor_shape;
2857 if (!IsReductionCandidateForSimplification(
2858 *node, properties, &input_tensor_shape, &output_tensor_shape,
2859 &is_single_element_op)) {
2860 return false;
2861 }
2862
2863 // Get the reduction indices.
2864 string reduction_indices_input = node->input(1);
2865 NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
2866 TensorVector reduction_indices_vector;
2867 auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
2868 for (const auto& out : reduction_indices_vector) {
2869 delete out.tensor;
2870 }
2871 });
2872 if (!EvaluateNode(*reduction_indices, TensorVector(),
2873 &reduction_indices_vector)
2874 .ok() ||
2875 reduction_indices_vector.size() != 1) {
2876 return false;
2877 }
2878
2879 bool keep_dims =
2880 node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
2881 bool simplifiable_to_reshape =
2882 is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
2883 bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
2884 *node, input_tensor_shape, keep_dims, reduction_indices_vector);
2885
2886 if (simplifiable_to_reshape) {
2887 // Const node to output shape.
2888 const int new_num_dimensions = output_tensor_shape.dim_size();
2889 Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
2890 for (int i = 0; i < new_num_dimensions; i++) {
2891 tensor.flat<int>()(i) = 1;
2892 }
2893 TensorValue shape_value(&tensor);
2894 NodeDef* shape_node = optimized_graph->add_node();
2895 if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
2896 shape_node)
2897 .ok()) {
2898 return false;
2899 }
2900 shape_node->set_device(node->device());
2901 node_map_->AddNode(shape_node->name(), shape_node);
2902 // Control dependency to ensure shape_node is in the correct frame.
2903 shape_node->add_input(AsControlDependency(reduction_indices_input));
2904 node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
2905 // Optimize node to Reshape.
2906 node->set_op("Reshape");
2907 node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
2908 node->set_input(1, shape_node->name());
2909 node->mutable_attr()->erase("keep_dims");
2910 node->mutable_attr()->erase("Tidx");
2911 AttrValue attr_type_indices;
2912 attr_type_indices.set_type(DT_INT32);
2913 (*node->mutable_attr())["Tshape"] = attr_type_indices;
2914 return true;
2915 } else if (simplifiable_to_identity) {
2916 return ReplaceReductionWithIdentity(node);
2917 }
2918 return false;
2919 }
2920
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2921 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2922 bool use_shape_info, NodeDef* node) {
2923 if (!use_shape_info || node->attr().count("T") == 0 ||
2924 !IsSimplifiableReshape(*node, properties)) {
2925 return false;
2926 }
2927 DataType output_type = node->attr().at("T").type();
2928 node->set_op("Identity");
2929 EraseRegularNodeAttributes(node);
2930 (*node->mutable_attr())["T"].set_type(output_type);
2931 *node->mutable_input(1) = AsControlDependency(node->input(1));
2932 return true;
2933 }
2934
// Simplifies binary arithmetic / logical nodes where one operand is known to
// be all zeros or all ones (x is input 0, y is input 1):
//   1 * y, 0 + y            -> y
//   0 - y                   -> Neg(y)
//   1 / y                   -> Reciprocal(y)   (floating / complex types only)
//   x * 1, x / 1, x +/- 0   -> x
//   x OR true, true OR y    -> constant        (fully defined output shape)
//   0 * y, x * 0, matmul with a zero operand, and (aggressive mode only)
//   0 / y                   -> constant zeros, or the forwarded / broadcast
//                              zero operand when the output shape is only
//                              partially known.
//
// Nothing is done unless use_shape_info is set and both input and output
// properties exist for the node. Returns InvalidArgument if either of the
// first two inputs cannot be resolved in the node map, and propagates errors
// from the helpers wrapped in TF_RETURN_IF_ERROR; otherwise returns OK,
// whether or not the node was rewritten.
Status ConstantFolding::SimplifyArithmeticOperations(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node) {
  // LogicalAnd / LogicalOr are grouped with multiplication / addition, whose
  // identities they share.
  const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsAnyMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  const bool is_any_div = IsAnyDiv(*node) && !IsFloorDiv(*node);
  // Simplify arithmetic operations with ones or zeros.
  if (use_shape_info &&
      (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
      properties.HasInputProperties(node->name()) &&
      properties.HasOutputProperties(node->name())) {
    const NodeDef* x = node_map_->GetNode(node->input(0));
    const NodeDef* y = node_map_->GetNode(node->input(1));
    if (x == nullptr || y == nullptr) {
      return errors::InvalidArgument("Invalid inputs to node: ",
                                     node->DebugString());
    }
    const TensorShapeProto& output_shape =
        properties.GetOutputProperties(node->name())[0].shape();

    // Simplify element-wise multiplication by ones or addition/subtraction
    // of zeros.
    const TensorShapeProto& y_shape =
        properties.GetInputProperties(node->name())[1].shape();
    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    // An operand can only be forwarded as-is when its shape equals the output
    // shape; otherwise it has to be broadcast.
    const bool y_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, y_shape);
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);

    const bool x_is_zero = IsZeros(*x);
    // IsOnes is skipped when x is already known to be all zeros.
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
      // 1 * y = y or 0 + y = y.
      if (y_matches_output_shape) {
        ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      } else if (x_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                              optimized_graph);
      }
      return Status::OK();
    }

    if (y_matches_output_shape && (is_sub && x_is_zero)) {
      // Replace 0 - y with Neg(y).
      ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
      return Status::OK();
    }

    // Replace 1 / y with Reciprocal op.
    if (y_matches_output_shape && is_any_div && x_is_one) {
      TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
      DataType type = node->attr().at("T").type();
      // Other types fall through to the remaining rewrites below.
      if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
        ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
        return Status::OK();
      }
    }

    const bool y_is_zero = IsZeros(*y);
    // IsOnes is skipped when y is already known to be all zeros.
    const bool y_is_one = y_is_zero ? false : IsOnes(*y);
    if (((is_mul || is_any_div) && y_is_one) ||
        ((is_add || is_sub) && y_is_zero)) {
      // x * 1 = x or x / 1 = x or x +/- 0 = x
      if (x_matches_output_shape) {
        ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      } else if (y_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                              optimized_graph);
      }
      return Status::OK();
    }

    // x OR true = true OR y = true.
    const PartialTensorShape shp(output_shape);
    if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph));
      return Status::OK();
    }

    // Simplify multiplication and matmul by zeros.
    // Also optimize zeros divided by a tensor, but only if we are in
    // aggressive mode, since we might get rid of divisions by zero.
    const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
    bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
    if ((x_is_zero || y_is_zero) &&
        (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
      if (shp.IsFullyDefined()) {
        // Record this before the node is replaced by a constant below.
        bool is_quantized = IsQuantizedMatMul(*node);
        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
            0, properties, output_shape, node, optimized_graph));
        if (is_quantized && graph_modified_) {
          TF_RETURN_IF_ERROR(
              AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
        }
        return Status::OK();
      }
      // Even if an input shape is only partially known, we may know that it
      // matches the output shape and thus forward or broadcast the
      // corresponding zero input.
      if ((is_mul || is_any_div) && x_is_zero) {
        if (x_matches_output_shape) {
          ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        } else if (y_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                                optimized_graph);
        }
        return Status::OK();
      } else if (is_mul && y_is_zero) {
        if (y_matches_output_shape) {
          ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        } else if (x_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                                optimized_graph);
        }
        return Status::OK();
      }
    }
  }
  return Status::OK();
}
3060
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)3061 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
3062 NodeDef* node) {
3063 // Strength reduce floating point division by a constant Div(x, const) to
3064 // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
3065 // will be constant folded to Mul(x, 1.0/const).
3066 if (node->input_size() >= 2 &&
3067 (IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
3068 const string& const_input = node->input(1);
3069 const NodeDef* denom = node_map_->GetNode(const_input);
3070 CHECK(denom != nullptr);
3071 if (!IsReallyConstant(*denom)) {
3072 return false;
3073 }
3074 if (node->attr().count("T") == 0) {
3075 return false;
3076 }
3077 DataType type = node->attr().at("T").type();
3078 // Skip integer division.
3079 if (IsDiv(*node) &&
3080 !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
3081 return false;
3082 }
3083 // Insert new reciprocal op and change node from Div to Mul.
3084 NodeDef* reciprocal_node = optimized_graph->add_node();
3085 reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
3086 reciprocal_node->set_op("Reciprocal");
3087 reciprocal_node->set_device(node->device());
3088 reciprocal_node->add_input(const_input);
3089 (*reciprocal_node->mutable_attr())["T"].set_type(type);
3090
3091 // Re-wire inputs and outputs.
3092 if (IsXdivy(*node)) {
3093 node->set_op("MulNoNan");
3094 node->set_input(1, node->input(0));
3095 node->set_input(0, reciprocal_node->name());
3096 } else {
3097 node->set_op("Mul");
3098 node->set_input(1, reciprocal_node->name());
3099 }
3100 node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
3101 node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
3102
3103 return true;
3104 }
3105
3106 return false;
3107 }
3108
PrepareConstantPushDown(const NodeDef & parent,const GraphProperties & properties,bool must_have_properties,ConstantPushDownContext * ctx) const3109 bool ConstantFolding::PrepareConstantPushDown(
3110 const NodeDef& parent, const GraphProperties& properties,
3111 bool must_have_properties, ConstantPushDownContext* ctx) const {
3112 if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
3113 return false;
3114 }
3115 NodeDef* left_child = node_map_->GetNode(parent.input(0));
3116 NodeDef* right_child = node_map_->GetNode(parent.input(1));
3117
3118 // Sanity check for missing children.
3119 if (left_child == nullptr || right_child == nullptr) {
3120 return false;
3121 }
3122
3123 ctx->left_child_is_const = IsReallyConstant(*left_child);
3124 ctx->right_child_is_const = IsReallyConstant(*right_child);
3125 ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
3126 ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
3127
3128 // Nothing to do unless the parent has a constant child node.
3129 if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
3130 return false;
3131 }
3132
3133 // Don't move nodes across devices.
3134 if (parent.device() != ctx->op_child->device() ||
3135 parent.device() != ctx->const_child->device()) {
3136 return false;
3137 }
3138
3139 // Make sure that it is safe to change the value of the child node result.
3140 if (ctx->op_child->input_size() < 2 ||
3141 nodes_to_preserve_.find(ctx->op_child->name()) !=
3142 nodes_to_preserve_.end() ||
3143 NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
3144 return false;
3145 }
3146
3147 // Don't apply reassociation to floating point types of low precision.
3148 // The danger of significant numerical changes is too high.
3149 if (!CheckAttrExists(parent, "T").ok()) return false;
3150 DataType dtype = parent.attr().at("T").type();
3151 if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
3152 return false;
3153 }
3154
3155 // Don't rewrite the tree if it might create cycles.
3156 // TODO(rmlarsen): Add back handling of control dependency from op to C.
3157 const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
3158 if (child_output.find(ctx->const_child) != child_output.end()) {
3159 return false;
3160 }
3161
3162 // Get leaf nodes.
3163 ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
3164 ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
3165 ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
3166 ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
3167
3168 if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
3169 // Child is already foldable, leave it alone.
3170 return false;
3171 }
3172
3173 // Don't move nodes across devices.
3174 if (parent.device() != ctx->left_leaf->device() ||
3175 parent.device() != ctx->right_leaf->device()) {
3176 return false;
3177 }
3178
3179 // Get shape and type information.
3180 ctx->parent_input_props = &properties.GetInputProperties(parent.name());
3181 ctx->op_child_input_props =
3182 &properties.GetInputProperties(ctx->op_child->name());
3183 if (must_have_properties && (ctx->parent_input_props == nullptr ||
3184 ctx->parent_input_props->size() < 2 ||
3185 ctx->op_child_input_props == nullptr ||
3186 ctx->op_child_input_props->size() < 2)) {
3187 return false;
3188 }
3189
3190 VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
3191 << parent.op() << "(" << left_child->op() << ", " << right_child->op()
3192 << ")";
3193
3194 return true;
3195 }
3196
ConstantPushDownBiasAdd(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3197 bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
3198 GraphDef* optimized_graph,
3199 NodeDef* node) {
3200 // This implements constant push-down for BiasAdd. In the following "CV" is a
3201 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
3202 // vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3203 // non-constant matrix, and "BA" is BiasAdd.
3204 // For a valid input graph, the following 4 rewrites are legal:
3205 //
3206 // 1) + +
3207 // / \ / \
3208 // BA CV -- > BA V
3209 // / \ / \
3210 // M V M CV
3211 //
3212 // 2) + +
3213 // / \ / \
3214 // BA CM -- > BA M
3215 // / \ / \
3216 // M V CM V
3217 //
3218 // 3) BA BA
3219 // / \ / \
3220 // + CV -- > + V
3221 // / \ / \
3222 // M V M CV
3223 //
3224 // 4) BA BA = parent
3225 // / \ / \
3226 // BA CV -- > BA V = children
3227 // / \ / \
3228 // M V M CV = leaves
3229 //
3230 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3231
3232 const bool parent_is_bias_add = IsBiasAdd(*node);
3233 if (!parent_is_bias_add && !IsAdd(*node)) return false;
3234 ConstantPushDownContext ctx;
3235 if (!PrepareConstantPushDown(*node, *properties,
3236 /*must_have_properties=*/true, &ctx)) {
3237 return false;
3238 }
3239 // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3240 // >= 2 and the leaves must be vectors, we cannot swap them.
3241 if (ctx.left_child_is_const && parent_is_bias_add) return false;
3242 const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
3243 if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
3244
3245 // Get properties to validate rank and dtype constraints.
3246 if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
3247 (*ctx.parent_input_props)[0].shape().unknown_rank() ||
3248 (*ctx.parent_input_props)[1].shape().unknown_rank() ||
3249 (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
3250 (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
3251 return false;
3252 }
3253
3254 // Now get the ranks and types of the 3 leaf nodes.
3255 const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
3256 const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
3257 // At least one leaf must be a vector.
3258 if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
3259 const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3260 const int matrix_idx = 1 - vector_idx;
3261
3262 const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
3263 const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
3264 if (vector_rank != 1) return false; // this should never happen.
3265 const DataType vector_type = vector_prop.dtype();
3266
3267 const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
3268 const int matrix_rank = matrix_prop.shape().dim_size();
3269 const DataType matrix_type = matrix_prop.dtype();
3270
3271 const int const_idx = ctx.left_child_is_const ? 0 : 1;
3272 const auto& const_prop = (*ctx.parent_input_props)[const_idx];
3273 const int const_rank = const_prop.shape().dim_size();
3274 const DataType const_type = const_prop.dtype();
3275
3276 int input_to_swap = -1;
3277
3278 if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
3279 const_type == matrix_type) {
3280 // Case 2:
3281 input_to_swap = matrix_idx;
3282 } else if (const_rank == 1 && const_type == vector_type) {
3283 // Case 1, 3, and, 4:
3284 input_to_swap = vector_idx;
3285 }
3286 if (input_to_swap == -1) return false;
3287 const NodeDef* leaf_to_swap =
3288 node_map_->GetNode(ctx.op_child->input(input_to_swap));
3289 if (IsConstant(*leaf_to_swap)) return false;
3290
3291 node_map_->UpdateInput(node->name(), node->input(const_idx),
3292 ctx.op_child->input(input_to_swap));
3293 node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
3294 if (ctx.op_child->input(input_to_swap) !=
3295 ctx.op_child->input(1 - input_to_swap)) {
3296 node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
3297 ctx.op_child->name());
3298 }
3299 std::swap(*node->mutable_input(const_idx),
3300 *ctx.op_child->mutable_input(input_to_swap));
3301 properties->ClearInputProperties(node->name());
3302 properties->ClearInputProperties(ctx.op_child->name());
3303
3304 return true;
3305 }
3306
ConstantPushDown(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3307 bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
3308 GraphDef* optimized_graph,
3309 NodeDef* node) {
3310 // Consider the transformation
3311 //
3312 // + + = parent
3313 // / \ / \
3314 // C + -- > X + = children
3315 // / \ / \
3316 // X Y C Y = leaves
3317 //
3318 // where C is constant, X is non-constant, Y may be constant or non-constant,
3319 // and '+' denotes an associative and commutative operator like addition or
3320 // multiplication. This optimization pushes constants down in the tree to
3321 // canonicalize it. Moreover, in cases where the child node has a second
3322 // constant input Y we will create a leaf node that can be folded, e.g.
3323 //
3324 // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
3325 //
3326 // We also handle the non-commutative cases of subtraction and division
3327 // by rotating the tree locally, e.g.
3328 // Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
3329 // Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
3330
3331 // Get parent op type.
3332 const bool is_add = IsAdd(*node);
3333 const bool is_mul = IsMul(*node);
3334 const bool is_sub = IsSub(*node);
3335 const bool is_div = IsDiv(*node);
3336 if (!(is_add || is_sub || is_mul || is_div)) return false;
3337 const bool is_symmetric = is_add || is_mul;
3338
3339 ConstantPushDownContext ctx;
3340 if (!PrepareConstantPushDown(*node, *properties,
3341 /*must_have_properties=*/false, &ctx)) {
3342 return false;
3343 }
3344
3345 // Get child op type.
3346 const bool is_child_add = IsAdd(*ctx.op_child);
3347 const bool is_child_mul = IsMul(*ctx.op_child);
3348 const bool is_child_sub = IsSub(*ctx.op_child);
3349 const bool is_child_div = IsDiv(*ctx.op_child);
3350 const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
3351 const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
3352 if (!is_add_sub && !is_mul_div) {
3353 return false;
3354 }
3355 const bool is_child_symmetric = is_child_add || is_child_mul;
3356
3357 if (!CheckAttrExists(*node, "T").ok()) return false;
3358 DataType dtype = node->attr().at("T").type();
3359 if (!(is_symmetric && is_child_symmetric) &&
3360 !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
3361 return false;
3362 }
3363
3364 const NodeDef* y_node =
3365 ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
3366 if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
3367 !ctx.op_child_input_props->empty()) {
3368 // If we know the shapes of the nodes being swapped, make sure we don't push
3369 // down a larger node and create more work by broadcasting earlier in the
3370 // expressions tree.
3371 const PartialTensorShape c_shape(
3372 (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
3373 const PartialTensorShape x_shape(
3374 (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
3375
3376 if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
3377 c_shape.num_elements() > x_shape.num_elements()) {
3378 return false;
3379 } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
3380 c_shape.dims() > 0) {
3381 for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
3382 if (x_shape.dim_size(idx) >= 0 &&
3383 c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
3384 return false;
3385 }
3386 }
3387 }
3388 }
3389
3390 // Get the node names corresponding to X, Y, and C.
3391 const string input_x =
3392 ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
3393 const string input_y = input_x == ctx.op_child->input(0)
3394 ? ctx.op_child->input(1)
3395 : ctx.op_child->input(0);
3396 const string input_c =
3397 ctx.left_child_is_const ? node->input(0) : node->input(1);
3398 const string input_op =
3399 ctx.left_child_is_const ? node->input(1) : node->input(0);
3400 VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
3401
3402 // Now we have identified the nodes to swap, update the nodemap accordingly.
3403 node_map_->UpdateInput(node->name(), input_c, input_x);
3404 node_map_->AddOutput(input_c, ctx.op_child->name());
3405 if (input_x != input_y) {
3406 node_map_->RemoveOutput(input_x, ctx.op_child->name());
3407 }
3408 properties->ClearInputProperties(node->name());
3409 properties->ClearInputProperties(ctx.op_child->name());
3410
3411 if (is_symmetric && is_child_symmetric) {
3412 // Easy case (only commutative ops). We always write this as one of
3413 // +
3414 // / \
3415 // X +
3416 // / \
3417 // C Y
3418 node->set_input(0, input_x);
3419 node->set_input(1, input_op);
3420 ctx.op_child->set_input(0, input_c);
3421 ctx.op_child->set_input(1, input_y);
3422 } else {
3423 // More complicated case: When there are non-commutative operations like
3424 // subtractions or divisions involved, we may have to rotate the tree
3425 // and/or change op types. There are 6 non-trivial cases depending on
3426 // the effective generalized "sign" of each of the three terms C, Y, and X.
3427 // Here are the final trees we want to generate for those 6 cases:
3428 //
3429 // (CYX signs): ++- +-- -+- --+ +-+ -++
3430 //
3431 // - - - - + +
3432 // / \ / \ / \ / \ / \ / \
3433 // + X - X - X X + X - X -
3434 // / \ / \ / \ / \ / \ / \
3435 // C Y C Y Y C Y C C Y Y C
3436 //
3437
3438 // First, let's determine the effective sign of each term in the original
3439 // expression
3440 auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
3441 bool leaf_negated = !is_child_symmetric && is_right_leaf;
3442 bool child_negated = !is_symmetric && (ctx.left_child_is_const);
3443 return leaf_negated != child_negated;
3444 };
3445 const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
3446 const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
3447 bool neg_c = !is_symmetric && !ctx.left_child_is_const;
3448 bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
3449 bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
3450 // Rewrite the parent node.
3451 node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
3452 node->set_input(0, neg_x ? input_op : input_x);
3453 node->set_input(1, neg_x ? input_x : input_op);
3454 // Rewrite the child node.
3455 ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
3456 ctx.op_child->set_input(0, neg_c ? input_y : input_c);
3457 ctx.op_child->set_input(1, neg_c ? input_c : input_y);
3458 }
3459 return true;
3460 }
3461
MulConvPushDown(GraphDef * optimized_graph,NodeDef * node,const GraphProperties & properties)3462 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
3463 const GraphProperties& properties) {
3464 // Push down multiplication on ConvND.
3465 // * ConvND
3466 // / \ / \
3467 // ConvND C2 -- > X *
3468 // / \ / \
3469 // X C1 C1 C2
3470 //
3471 // where C1 and C2 are constants and X is non-constant.
3472 //
3473 // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
3474
3475 if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
3476
3477 NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
3478 NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
3479 // One child must be constant, and the second must be Conv op.
3480 const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
3481 const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
3482 if (!left_child_is_constant && !right_child_is_constant) {
3483 return false;
3484 }
3485 NodeDef* conv_node =
3486 left_child_is_constant ? mul_right_child : mul_left_child;
3487 if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
3488 return false;
3489 }
3490 if (node->device() != mul_left_child->device() ||
3491 node->device() != mul_right_child->device()) {
3492 return false;
3493 }
3494
3495 // Make sure that it is safe to change the value of the convolution
3496 // output.
3497 if (conv_node->input_size() < 2 ||
3498 NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
3499 nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
3500 return false;
3501 }
3502
3503 // Identify the nodes to swap.
3504 NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
3505 NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
3506 const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
3507 const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
3508 if (!conv_left_is_constant && !conv_right_is_constant) {
3509 // At least one of the convolution inputs should be constant.
3510 return false;
3511 }
3512 if (conv_left_is_constant && conv_right_is_constant) {
3513 // Leverage regular constant folding to handle this.
3514 return false;
3515 }
3516 const auto& mul_props = properties.GetOutputProperties(node->name());
3517 const auto& conv_props = properties.GetOutputProperties(conv_node->name());
3518 if (mul_props.empty() || conv_props.empty()) {
3519 return false;
3520 }
3521 const auto& mul_shape = mul_props[0].shape();
3522 const auto& conv_shape = conv_props[0].shape();
3523 if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
3524 return false;
3525 }
3526
3527 const auto& input_props = properties.GetInputProperties(conv_node->name());
3528 if (input_props.size() < 2) {
3529 return false;
3530 }
3531 const auto& filter_shape = input_props[1].shape();
3532
3533 NodeDef* const_node =
3534 left_child_is_constant ? mul_left_child : mul_right_child;
3535 const auto& const_props = properties.GetOutputProperties(const_node->name());
3536 if (const_props.empty()) {
3537 return false;
3538 }
3539 const auto& const_shape = const_props[0].shape();
3540 if (!IsValidConstShapeForMulConvPushDown(
3541 conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
3542 return false;
3543 }
3544
3545 string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
3546 if (node_map_->NodeExists(mul_new_name)) {
3547 return false;
3548 }
3549 // Make sure we don't introduce loops in the graph by removing control
3550 // dependencies from the conv2d node to c2.
3551 string conv_const_input =
3552 conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
3553 if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
3554 node_map_.get())) {
3555 // Add a control dep from c1 to c2 to ensure c2 is in the right frame
3556 MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
3557 node_map_.get());
3558 }
3559
3560 conv_node->set_name(node->name());
3561 node->set_name(mul_new_name);
3562 if (conv_left_is_constant) {
3563 node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
3564 conv_node->set_input(0, mul_new_name);
3565 } else {
3566 node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
3567 conv_node->set_input(1, mul_new_name);
3568 }
3569 NodeDef* conv_const_node =
3570 conv_left_is_constant ? conv_left_child : conv_right_child;
3571 if (left_child_is_constant) {
3572 node->set_input(1, conv_const_node->name());
3573 } else {
3574 node->set_input(0, conv_const_node->name());
3575 }
3576 node_map_->AddNode(mul_new_name, node);
3577
3578 return true;
3579 }
3580
PartialConstPropThroughIdentityN(NodeDef * node)3581 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3582 // Partial constant propagation through IdentityN.
3583 if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
3584 !HasRegularInputs(*node))
3585 return false;
3586
3587 std::vector<int> inputs_to_forward;
3588 for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3589 const string& input = node->input(input_idx);
3590 if (IsControlInput(input)) {
3591 return false;
3592 }
3593 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3594 if (input_node == nullptr) {
3595 LOG(ERROR) << "Bad input: " << input;
3596 return false;
3597 }
3598 // Forward constant inputs to outputs and add a control dependency on
3599 // the IdentityN node.
3600 if (IsReallyConstant(*input_node)) {
3601 inputs_to_forward.push_back(input_idx);
3602 }
3603 }
3604 return ForwardInputs(node, inputs_to_forward);
3605 }
3606
PartialAssocOpConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3607 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
3608 GraphProperties* properties,
3609 NodeDef* node) {
3610 // Partial constant folding for associative operators:
3611 // Split AddN/AccumulateNV2 to enable partial
3612 // folding of ops when more than one but not all inputs are constant.
3613 // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
3614 // addition is commutative.
3615 if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
3616
3617 const int num_non_control_inputs = NumNonControlInputs(*node);
3618 if (num_non_control_inputs <= 2) return false;
3619 const int num_control_inputs = node->input_size() - num_non_control_inputs;
3620 std::vector<int> const_inputs;
3621 std::vector<int> nonconst_inputs;
3622 for (int i = 0; i < node->input_size(); ++i) {
3623 const string& input = node->input(i);
3624 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3625 if (input_node == nullptr) return false;
3626 if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
3627 const_inputs.push_back(i);
3628 } else {
3629 // Non-const and control inputs.
3630 nonconst_inputs.push_back(i);
3631 }
3632 }
3633 // Promote AccumulateNV2 with all constant inputs to AddN, since it is
3634 // a fake node that cannot be constant folded by itself.
3635 int const_inputs_size = const_inputs.size();
3636 if (const_inputs_size == num_non_control_inputs &&
3637 node->op() == "AccumulateNV2") {
3638 node->set_op("AddN");
3639 node->mutable_attr()->erase("shape");
3640 return true;
3641 }
3642 const string new_node_name = OptimizedNodeName(
3643 *node, strings::StrCat("_partial_split_", const_inputs_size));
3644 if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
3645 !node_map_->NodeExists(new_node_name)) {
3646 NodeDef* added_node = optimized_graph->add_node();
3647 *added_node = *node;
3648 // Always use AddN for the constant node, since AccumulateNV2 is a fake
3649 // node that cannot be constant folded, since it does not have a kernel.
3650 added_node->set_op("AddN");
3651 added_node->mutable_attr()->erase("shape");
3652 added_node->set_name(new_node_name);
3653 node_map_->AddNode(added_node->name(), added_node);
3654 added_node->clear_input();
3655 for (int i : const_inputs) {
3656 added_node->add_input(node->input(i));
3657 node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3658 added_node->name());
3659 }
3660
3661 // Overwrite the first const input with the added node.
3662 node->set_input(const_inputs[0], added_node->name());
3663 node_map_->AddOutput(added_node->name(), node->name());
3664 nonconst_inputs.push_back(const_inputs[0]);
3665 // Compact the remaining inputs to the original node.
3666 std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
3667 int idx = 0;
3668 for (int i : nonconst_inputs) {
3669 if (idx != i) {
3670 node->set_input(idx, node->input(i));
3671 }
3672 ++idx;
3673 }
3674 node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
3675 const_inputs.size() - 1);
3676 (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
3677 properties->ClearInputProperties(node->name());
3678 (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
3679 return true;
3680 }
3681 return false;
3682 }
3683
PartialConcatConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3684 bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
3685 GraphProperties* properties,
3686 NodeDef* node) {
3687 // Partial constant folding for Concat which is not commutative, so
3688 // we have to preserve order and can only push consecutive runs of constant
3689 // inputs into sub-nodes.
3690 if (!IsConcat(*node) ||
3691 node->name().rfind("_partial_split_") != string::npos) {
3692 return false;
3693 }
3694 const int num_non_control_inputs = NumNonControlInputs(*node);
3695 if (num_non_control_inputs <= 3) return false;
3696 int axis_arg = -1;
3697 int begin = 0;
3698 int end = num_non_control_inputs;
3699 if (node->op() == "Concat") {
3700 begin = 1;
3701 axis_arg = 0;
3702 } else if (node->op() == "ConcatV2") {
3703 end = num_non_control_inputs - 1;
3704 axis_arg = num_non_control_inputs - 1;
3705 } else {
3706 return false;
3707 }
3708
3709 // We search for consecutive runs of constant inputs in the range
3710 // [begin:end[ and push then down into child nodes.
3711 std::vector<std::pair<int, int>> constant_input_runs;
3712 int first = begin;
3713 int last = begin;
3714 while (last < end) {
3715 while (first < end && !IsReallyConstant(*node_map_->GetNode(
3716 NodeName(node->input(first))))) {
3717 ++first;
3718 }
3719 // Invariant: node[first] is constant || first >= end.
3720 last = first + 1;
3721 while (last < end &&
3722 IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
3723 ++last;
3724 }
3725 // Invariant: node[last] is not constant || last >= end
3726 // Discard intervals shorter than 2 elements.
3727 if (first < end && (last - first) > 1) {
3728 constant_input_runs.emplace_back(first, last);
3729 }
3730 first = last;
3731 }
3732
3733 // Skip if all inputs are constant, and let constant folding take over.
3734 if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3735 constant_input_runs[0].first == begin &&
3736 constant_input_runs[0].second == end)) {
3737 return false;
3738 }
3739 std::set<int> inputs_to_delete;
3740 for (auto interval : constant_input_runs) {
3741 // Push the constant inputs in the interval to a child node than can be
3742 // constant folded.
3743 string new_node_name = OptimizedNodeName(*node, "_partial_split");
3744 do {
3745 new_node_name += strings::StrCat("_", interval.first);
3746 } while (node_map_->NodeExists(new_node_name));
3747
3748 NodeDef* added_node = optimized_graph->add_node();
3749 *added_node = *node;
3750 added_node->set_op("ConcatV2");
3751 added_node->set_name(new_node_name);
3752 node_map_->AddNode(added_node->name(), added_node);
3753 added_node->clear_input();
3754 for (int i = interval.first; i < interval.second; ++i) {
3755 added_node->add_input(node->input(i));
3756 node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
3757 if (i != interval.first) {
3758 inputs_to_delete.insert(i);
3759 }
3760 }
3761 added_node->add_input(node->input(axis_arg));
3762 (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
3763 node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
3764
3765 // Overwrite the first constant input with the result of the added
3766 // child node.
3767 node->set_input(interval.first, added_node->name());
3768 }
3769 if (!inputs_to_delete.empty()) {
3770 // Fix up the inputs to the original node.
3771 protobuf::RepeatedPtrField<string> tmp;
3772 tmp.Swap(node->mutable_input());
3773 for (int i = 0; i < tmp.size(); ++i) {
3774 if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
3775 node->add_input(tmp.Get(i));
3776 }
3777 }
3778 (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
3779 properties->ClearInputProperties(node->name());
3780 }
3781 return true;
3782 }
3783
GetConcatAxis(const NodeDef & node,int * axis)3784 bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
3785 if (node.op() != "ConcatV2") {
3786 return false;
3787 }
3788 int axis_idx = node.input_size() - 1;
3789 while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
3790 --axis_idx;
3791 }
3792 if (axis_idx <= 0) {
3793 return false;
3794 }
3795 Tensor axis_tensor;
3796 if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
3797 return false;
3798 }
3799 *axis = axis_tensor.dtype() == DT_INT64
3800 ? static_cast<int>(axis_tensor.scalar<int64>()())
3801 : axis_tensor.scalar<int32>()();
3802 return true;
3803 }
3804
MergeConcat(bool use_shape_info,GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3805 bool ConstantFolding::MergeConcat(bool use_shape_info,
3806 GraphProperties* properties,
3807 GraphDef* optimized_graph, NodeDef* node) {
3808 // We only optimize for ConcatV2.
3809 int axis;
3810 if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
3811 nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3812 node_map_->GetOutputs(node->name()).size() != 1) {
3813 return false;
3814 }
3815
3816 // If all inputs are constant, don't merge and let folding take case of it.
3817 const int num_regular_inputs = NumNonControlInputs(*node);
3818 bool all_inputs_are_const = true;
3819 for (int i = 0; i < num_regular_inputs - 1; ++i) {
3820 const NodeDef* input_node = node_map_->GetNode(node->input(i));
3821 if (!IsReallyConstant(*input_node)) {
3822 all_inputs_are_const = false;
3823 break;
3824 }
3825 }
3826 if (all_inputs_are_const) return false;
3827
3828 NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3829 int parent_axis;
3830 if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
3831 return false;
3832 }
3833
3834 // Make a pass over the parent inputs to see if any of them have explicit
3835 // device() fields set, and if different inputs are on different tasks. If
3836 // so, this concat of concats may have been carefully constructed to be a
3837 // two-stage concat, and we don't want to undo that here.
3838 string task, device;
3839 absl::flat_hash_set<string> unique_input_tasks;
3840 const int n_parent_inputs = NumNonControlInputs(*parent);
3841 // Iterate over the real inputs to concatenate [0..n_parent_inputs - 1). The
3842 // input at n_parent_inputs - 1 is the concat axis argument for a ConcatV2
3843 // node, which we don't want to consider here.
3844 for (int i = 0; i < n_parent_inputs - 1; ++i) {
3845 const NodeDef* input_node = node_map_->GetNode(parent->input(i));
3846 if (!input_node->device().empty() &&
3847 tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
3848 &task, &device)) {
3849 unique_input_tasks.insert(task);
3850 if (unique_input_tasks.size() >= 2) {
3851 // More than one input task represented in the device specifications
3852 // of the parent's input nodes. Don't mess with this.
3853 return false;
3854 }
3855 }
3856 }
3857
3858 protobuf::RepeatedPtrField<string> parent_inputs;
3859 parent_inputs.Swap(parent->mutable_input());
3860 // TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
3861 // collapse it into the parent multiple times? Probably not.
3862 for (const auto& input : parent_inputs) {
3863 if (IsSameInput(input, node->name())) {
3864 for (int j = 0; j < num_regular_inputs - 1; ++j) {
3865 // Add tensor inputs to first child concat tensors (except the final
3866 // axis input) to the parent's inputs.
3867 parent->add_input(node->input(j));
3868 node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
3869 }
3870 } else {
3871 parent->add_input(input);
3872 }
3873 }
3874 // Forward Add control inputs
3875 const int num_inputs = node->input_size();
3876 for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
3877 parent->add_input(node->input(i));
3878 node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
3879 node->mutable_input()->RemoveLast();
3880 }
3881 (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3882 DedupControlInputs(parent);
3883 ReplaceOperationWithNoOp(node, properties, optimized_graph);
3884
3885 return true;
3886 }
3887
AddQuantizedMatMulMinMaxOutConstNodes(NodeDef * node,GraphDef * optimized_graph)3888 Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
3889 NodeDef* node, GraphDef* optimized_graph) {
3890 auto add_quantized_out = [this, node, optimized_graph](
3891 const string& out_const_name, int index) {
3892 NodeDef* out_node = optimized_graph->add_node();
3893 graph_modified_ = true;
3894 Tensor value(DT_FLOAT, TensorShape({}));
3895 const bool is_min = index == 1;
3896 const DataType type_attr = node->attr().at("dtype").type();
3897
3898 value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
3899 : QuantizedTypeMaxAsFloat(type_attr);
3900 TF_RETURN_IF_ERROR(
3901 CreateNodeDef(out_const_name, TensorValue(&value), out_node));
3902 node_map_->AddNode(out_const_name, out_node);
3903 out_node->set_device(node->device());
3904 // Copy all inputs from node.
3905 out_node->mutable_input()->CopyFrom(node->input());
3906 for (const string& input : out_node->input()) {
3907 node_map_->AddOutput(NodeName(input), out_const_name);
3908 }
3909
3910 // Update output nodes consuming node:index to new const node.
3911 string old_input = absl::StrCat(node->name(), ":", index);
3912 int old_node_count = 0;
3913 // We make a copy since the set might change.
3914 auto outputs = node_map_->GetOutputs(node->name());
3915 for (const auto& output : outputs) {
3916 for (int i = 0; i < output->input_size(); ++i) {
3917 if (output->input(i) == old_input) {
3918 output->set_input(i, out_const_name);
3919 node_map_->AddOutput(out_const_name, output->name());
3920 } else if (NodeName(output->input(i)) == node->name()) {
3921 ++old_node_count;
3922 }
3923 }
3924 if (old_node_count == 0) {
3925 node_map_->RemoveOutput(node->name(), output->name());
3926 }
3927 }
3928
3929 return Status::OK();
3930 };
3931 const string min_out_const_name =
3932 OptimizedNodeName(*node, "-quantized_matmul_min_out");
3933 const string max_out_const_name =
3934 OptimizedNodeName(*node, "-quantized_matmul_max_out");
3935 if (node_map_->GetNode(min_out_const_name) == nullptr &&
3936 node_map_->GetNode(max_out_const_name) == nullptr) {
3937 TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
3938 TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
3939 } else {
3940 return errors::Internal(absl::Substitute(
3941 "Can't create Const for QuantizedMatMul min_out/max_out of "
3942 "node '$0' because of node name conflict",
3943 node->name()));
3944 }
3945 return Status::OK();
3946 }
3947
RunOptimizationPass(Cluster * cluster,GrapplerItem * item,GraphDef * optimized_graph)3948 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3949 GrapplerItem* item,
3950 GraphDef* optimized_graph) {
3951 graph_ = &item->graph;
3952 node_map_.reset(new NodeMap(graph_));
3953 nodes_allowlist_.clear();
3954 // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3955 // has a single fanout, it would be rewritten as a constant with the same
3956 // node name, and therefore users are still able to fetch it. This is not
3957 // the case if the node has multiple fanouts, and constant folding would
3958 // replace the node with multiple constants (each for one fanout) with
3959 // new names, and as a result users would not be able to fetch the node any
3960 // more with the original node name.
3961 for (const auto& fetch : item->fetch) {
3962 const NodeDef* fetch_node = node_map_->GetNode(fetch);
3963 if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
3964 nodes_allowlist_.insert(fetch_node->name());
3965 }
3966 }
3967
3968 GraphProperties properties(*item);
3969 // It's possible to feed a placeholder with a tensor of any shape: make sure
3970 // that the shape inference deals with this conservatively unless we're in
3971 // aggressive mode.
3972 const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3973 Status s = properties.InferStatically(assume_valid_feeds,
3974 /*aggressive_shape_inference=*/false,
3975 /*include_input_tensor_values=*/false,
3976 /*include_output_tensor_values=*/true);
3977
3978 const bool can_use_shape_info = s.ok();
3979 VLOG(1) << "can_use_shape_info = " << can_use_shape_info;
3980
3981 absl::flat_hash_set<string> nodes_to_not_simplify;
3982 if (can_use_shape_info) {
3983 TF_RETURN_IF_ERROR(MaterializeShapes(properties));
3984 TF_RETURN_IF_ERROR(MaterializeConstants(properties));
3985 TF_RETURN_IF_ERROR(
3986 FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
3987 } else {
3988 *optimized_graph = *graph_;
3989 }
3990 node_map_.reset(new NodeMap(optimized_graph));
3991 TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
3992 &properties, &nodes_to_not_simplify));
3993
3994 return Status::OK();
3995 }
3996
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3997 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
3998 GraphDef* optimized_graph) {
3999 // TensorFlow flushes denormals to zero and rounds to nearest, so we do
4000 // the same here.
4001 port::ScopedFlushDenormal flush;
4002 port::ScopedSetRound round(FE_TONEAREST);
4003 nodes_to_preserve_ = item.NodesToPreserve();
4004 for (const auto& feed : item.feed) {
4005 feed_nodes_.insert(NodeName(feed.first));
4006 }
4007
4008 if (cpu_device_ == nullptr) {
4009 owned_device_.reset(new DeviceSimple());
4010 cpu_device_ = owned_device_.get();
4011 }
4012
4013 graph_contains_assign_or_inplace_op_ = false;
4014 for (const NodeDef& node : item.graph.node()) {
4015 if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
4016 graph_contains_assign_or_inplace_op_ = true;
4017 break;
4018 }
4019 }
4020
4021 has_fetch_ = !item.fetch.empty();
4022 GrapplerItem item_to_optimize = item;
4023 *optimized_graph = GraphDef();
4024 item_to_optimize.graph.Swap(optimized_graph);
4025 int64_t node_count;
4026 do {
4027 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4028 graph_modified_ = false;
4029 item_to_optimize.graph.Swap(optimized_graph);
4030 optimized_graph->Clear();
4031 node_count = item_to_optimize.graph.node_size();
4032 TF_RETURN_IF_ERROR(
4033 RunOptimizationPass(cluster, &item_to_optimize, optimized_graph));
4034 } while (graph_modified_ || optimized_graph->node_size() != node_count);
4035 *optimized_graph->mutable_library() = item.graph.library();
4036 *optimized_graph->mutable_versions() = item.graph.versions();
4037
4038 return Status::OK();
4039 }
4040
4041 } // namespace grappler
4042 } // namespace tensorflow
4043