1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
19
#include <climits>
#include <cmath>
#include <cstring>
21
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/grappler/clusters/cluster.h"
37 #include "tensorflow/core/grappler/costs/graph_properties.h"
38 #include "tensorflow/core/grappler/grappler_item.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/stringpiece.h"
45 #include "tensorflow/core/lib/gtl/cleanup.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/lib/strings/strcat.h"
49 #include "tensorflow/core/platform/cpu_info.h"
50 #include "tensorflow/core/platform/denormal.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/setround.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/bcast.h"
56 #include "tensorflow/core/util/saved_tensor_slice_util.h"
57
58 namespace tensorflow {
59 namespace grappler {
60 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
61
62 // We only fold/materialize constants smaller than 100kB.
63 const int64 kMaxConstantSize = 100 * 1024;
64
65 namespace {
66 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)67 bool AllValuesAre(const TensorProto& proto, const T& value) {
68 Tensor tensor;
69 if (!tensor.FromProto(proto)) {
70 return false;
71 }
72 auto values = tensor.flat<T>();
73 for (int i = 0; i < tensor.NumElements(); ++i) {
74 if (values(i) != value) {
75 return false;
76 }
77 }
78 return true;
79 }
80
81 // Add new_input as a control input to node if it does not already depend on it.
82 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
83 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)84 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
85 GraphDef* graph, NodeMap* node_map) {
86 bool already_exists = false;
87 for (const string& input : node->input()) {
88 if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
89 already_exists = true;
90 break;
91 }
92 }
93 if (!already_exists) {
94 const string ctrl_dep =
95 ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
96 node->add_input(ctrl_dep);
97 node_map->AddOutput(NodeName(ctrl_input), node->name());
98 }
99 return !already_exists;
100 }
101
102 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)103 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
104 GraphDef* graph, NodeMap* node_map) {
105 bool removed_input = false;
106 bool update_node_map = true;
107 const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
108 for (int i = 0; i < node->input_size(); ++i) {
109 const string& input = node->input(i);
110 if (old_input_ctrl_dep == input) {
111 if (IsControlInput(input)) {
112 node->mutable_input()->SwapElements(i, node->input_size() - 1);
113 node->mutable_input()->RemoveLast();
114 removed_input = true;
115 } else {
116 // There is a non-control input from the same node.
117 // Don't remove the output from the NodeMap.
118 update_node_map = false;
119 }
120 }
121 }
122 if (update_node_map) {
123 node_map->RemoveOutput(NodeName(old_input), node->name());
124 }
125 return removed_input;
126 }
127
HasTPUAttributes(const NodeDef & node)128 bool HasTPUAttributes(const NodeDef& node) {
129 AttrSlice attrs(node);
130 for (const auto& attr : attrs) {
131 if (attr.first.find("_tpu_") != attr.first.npos) {
132 return true;
133 }
134 }
135 return false;
136 }
137
// Returns true iff `a` and `b` differ. For general types this is plain
// operator!=.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

// float: compares the object representations rather than the numeric values,
// so 0.0f and -0.0f count as different while two NaNs with identical bits
// count as equal. memcmp is used instead of casting to an integer reference,
// which would read a float through an int32_t lvalue (strict-aliasing UB).
template <>
bool PackedValuesNotEqual(float a, float b) {
  return std::memcmp(&a, &b, sizeof(float)) != 0;
}

// double: same bitwise comparison as the float specialization.
template <>
bool PackedValuesNotEqual(double a, double b) {
  return std::memcmp(&a, &b, sizeof(double)) != 0;
}
152
QuantizedTypeMinAsFloat(DataType data_type)153 float QuantizedTypeMinAsFloat(DataType data_type) {
154 switch (data_type) {
155 case DT_QINT8:
156 return Eigen::NumTraits<qint8>::lowest();
157 case DT_QUINT8:
158 return Eigen::NumTraits<quint8>::lowest();
159 case DT_QINT16:
160 return Eigen::NumTraits<qint16>::lowest();
161 case DT_QUINT16:
162 return Eigen::NumTraits<quint16>::lowest();
163 case DT_QINT32:
164 return Eigen::NumTraits<qint32>::lowest();
165 default:
166 return 0.0f;
167 }
168 }
169
QuantizedTypeMaxAsFloat(DataType data_type)170 float QuantizedTypeMaxAsFloat(DataType data_type) {
171 switch (data_type) {
172 case DT_QINT8:
173 return Eigen::NumTraits<qint8>::highest();
174 case DT_QUINT8:
175 return Eigen::NumTraits<quint8>::highest();
176 case DT_QINT16:
177 return Eigen::NumTraits<qint16>::highest();
178 case DT_QUINT16:
179 return Eigen::NumTraits<quint16>::highest();
180 case DT_QINT32:
181 return Eigen::NumTraits<qint32>::highest();
182 default:
183 return 0.0f;
184 }
185 }
186
187 } // namespace
188
ConstantFolding(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_emulation)189 ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
190 DeviceBase* cpu_device,
191 bool disable_compressed_tensor_optimization,
192 bool fold_quantization_emulation)
193 : opt_level_(opt_level),
194 cpu_device_(cpu_device),
195 disable_compressed_tensor_optimization_(
196 disable_compressed_tensor_optimization),
197 fold_quantization_emulation_(fold_quantization_emulation) {
198 resource_mgr_.reset(new ResourceMgr());
199 }
200
ConstantFolding(DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_ops)201 ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
202 bool disable_compressed_tensor_optimization,
203 bool fold_quantization_ops)
204 : ConstantFolding(RewriterConfig::ON, cpu_device,
205 disable_compressed_tensor_optimization,
206 fold_quantization_ops) {}
207
208 // static
AddControlDependency(const string & input_name,GraphDef * graph,NodeMap * node_map)209 string ConstantFolding::AddControlDependency(const string& input_name,
210 GraphDef* graph,
211 NodeMap* node_map) {
212 if (IsControlInput(input_name)) {
213 return input_name;
214 }
215 const NodeDef& node = *node_map->GetNode(input_name);
216 if (!IsSwitch(node)) {
217 return AsControlDependency(node);
218 } else {
219 // We can't anchor control dependencies directly on the switch node: unlike
220 // other nodes only one of the outputs of the switch node will be generated
221 // when the switch node is executed, and we need to make sure the control
222 // dependency is only triggered when the corresponding output is triggered.
223 // We start by looking for an identity node connected to the output of the
224 // switch node, and use it to anchor the control dependency.
225 for (const NodeDef* output : node_map->GetOutputs(node.name())) {
226 if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
227 if (IsSameInput(node.input(0), input_name)) {
228 return AsControlDependency(*output);
229 }
230 }
231 }
232 // We haven't found an existing node where we can anchor the control
233 // dependency: add a new identity node.
234 int port = 0;
235 string ctrl_dep_name = ParseNodeName(input_name, &port);
236 strings::StrAppend(&ctrl_dep_name, "_", port);
237 ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
238 const DataType output_type = node.attr().at("T").type();
239
240 NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
241 if (added_node == nullptr) {
242 added_node = graph->add_node();
243 added_node->set_name(ctrl_dep_name);
244 added_node->set_op("Identity");
245 added_node->set_device(node.device());
246
247 (*added_node->mutable_attr())["T"].set_type(output_type);
248 *added_node->add_input() = input_name;
249 node_map->AddNode(added_node->name(), added_node);
250 node_map->AddOutput(node.name(), added_node->name());
251 }
252 return AsControlDependency(*added_node);
253 }
254 }
255
256 // Forward inputs at the given indices to outputs and add a control dependency
257 // on node.
ForwardInputs(NodeDef * node,absl::Span<const int> inputs_to_forward)258 bool ConstantFolding::ForwardInputs(NodeDef* node,
259 absl::Span<const int> inputs_to_forward) {
260 for (int input_idx : inputs_to_forward) {
261 if (input_idx < 0 || input_idx >= node->input_size()) {
262 return false;
263 }
264 }
265
266 const auto& tmp = node_map_->GetOutputs(node->name());
267 const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
268 bool updated_graph = false;
269 for (int input_idx : inputs_to_forward) {
270 const string& input = node->input(input_idx);
271 if (IsControlInput(input) && consumers.size() > 1) {
272 continue;
273 }
274 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
275 if (input_node == nullptr) {
276 LOG(ERROR) << "Bad input: " << input;
277 break;
278 }
279 // Update each consumer.
280 for (NodeDef* consumer : consumers) {
281 bool add_dep = false;
282 for (int consumer_input_idx = 0;
283 consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
284 const string& consumer_input = consumer->input(consumer_input_idx);
285 if (IsControlInput(consumer_input)) {
286 break;
287 }
288 // It is illegal to add control dependencies to _Retval nodes, so we
289 // can't bypass value producing `node` and forward inputs to `consumer`.
290 if (IsRetval(*consumer)) {
291 break;
292 }
293 int output_idx;
294 const string input_node_name =
295 ParseNodeName(consumer_input, &output_idx);
296 if (input_node_name == node->name() && output_idx == input_idx) {
297 consumer->set_input(consumer_input_idx, input);
298 // We will keep the input from the node through a control
299 // dependency, so we only need to add the consumer as an output
300 // for the input node.
301 node_map_->AddOutput(NodeName(input), consumer->name());
302 add_dep = true;
303 }
304 }
305 if (add_dep) {
306 consumer->add_input(AsControlDependency(node->name()));
307 updated_graph = true;
308 }
309 }
310 }
311
312 if (updated_graph) {
313 for (NodeDef* consumer : consumers) {
314 DedupControlInputs(consumer);
315 }
316 }
317 return updated_graph;
318 }
319
320 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64 value,const DataType & type,const int index,Tensor * tensor)321 static Status PutValueIntoTensor(const int64 value, const DataType& type,
322 const int index, Tensor* tensor) {
323 if (type == DT_INT32) {
324 if (value >= INT_MAX) {
325 return Status(error::INVALID_ARGUMENT, "int32 overflow");
326 }
327 tensor->flat<int32>()(index) = static_cast<int32>(value);
328 } else {
329 tensor->flat<int64>()(index) = value;
330 }
331 return Status::OK();
332 }
333
334 // Writes the given tensor shape into the given tensor.
335 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)336 static Status ConvertShapeToConstant(const string& op, const DataType& type,
337 const PartialTensorShape& shp,
338 Tensor* tensor) {
339 if (op == "Shape" || op == "ShapeN") {
340 *tensor = Tensor(type, TensorShape({shp.dims()}));
341 for (int i = 0; i < shp.dims(); ++i) {
342 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
343 }
344 } else if (op == "Size") {
345 int64 size = 1;
346 for (int i = 0; i < shp.dims(); ++i) {
347 size *= shp.dim_size(i);
348 }
349 *tensor = Tensor(type, TensorShape({}));
350 TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
351 } else {
352 CHECK_EQ(op, "Rank");
353 *tensor = Tensor(type, TensorShape({}));
354 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
355 }
356 return Status::OK();
357 }
358
359 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const360 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
361 StringPiece suffix) const {
362 return node_map_->NodeExists(OptimizedNodeName(node, suffix));
363 }
364
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const365 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
366 StringPiece suffix) const {
367 return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
368 kConstantFoldingConst);
369 }
370
IsReallyConstant(const NodeDef & node) const371 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
372 if (!IsConstant(node)) {
373 return false;
374 }
375 // If the node is fed it's not constant anymore.
376 return feed_nodes_.find(node.name()) == feed_nodes_.end();
377 }
378
379 // TODO(rmlarsen): Refactor to shared util.
GetTensorFromConstNode(const string & node_name_or_input,Tensor * tensor)380 bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
381 Tensor* tensor) {
382 const NodeDef* node = node_map_->GetNode(node_name_or_input);
383 return node != nullptr && IsReallyConstant(*node) &&
384 CheckAttrExists(*node, "value").ok() &&
385 tensor->FromProto(node->attr().at("value").tensor());
386 }
387
388 // Materialize the shapes using constants whenever possible.
MaterializeShapes(const GraphProperties & properties)389 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
390 // We may add some nodes to the graph to encode control dependencies and hold
391 // the materialized shapes: there is no need to process these added nodes, so
392 // only iterate over the nodes of the input graph.
393 const int node_count = graph_->node_size();
394 for (int node_idx = 0; node_idx < node_count; ++node_idx) {
395 NodeDef* node = graph_->mutable_node(node_idx);
396 const string op = node->op();
397 if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
398 op != "TensorArraySizeV3") {
399 continue;
400 }
401 const std::vector<OpInfo::TensorProperties>& output =
402 properties.GetOutputProperties(node->name());
403 const std::vector<OpInfo::TensorProperties>& input =
404 properties.GetInputProperties(node->name());
405 if (input.empty() || output.empty()) {
406 continue;
407 }
408
409 if (op == "Shape" || op == "Size" || op == "Rank") {
410 CHECK_EQ(1, output.size());
411 CHECK_EQ(1, input.size());
412
413 const DataType type = output[0].dtype();
414 CHECK(type == DT_INT32 || type == DT_INT64);
415 const PartialTensorShape shape(input[0].shape());
416
417 if ((op != "Rank" && !shape.IsFullyDefined()) ||
418 (op == "Rank" && shape.unknown_rank())) {
419 continue;
420 }
421
422 Tensor constant_value(type);
423 if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
424 continue;
425 }
426
427 // TODO(rmlarsen): Remove this workaround for b/150861569
428 // The bug involves an expression of the form Shape(ExpandDims(x)
429 // with an incorrectly inferred zero-size first dimension.
430 if (op == "Shape") {
431 if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
432 }
433
434 // Repurpose the existing node to be the constant.
435 // Device placement is preserved.
436 graph_modified_ = true;
437 node->set_op("Const");
438 EraseRegularNodeAttributes(node);
439 (*node->mutable_attr())["dtype"].set_type(type);
440 constant_value.AsProtoTensorContent(
441 (*node->mutable_attr())["value"].mutable_tensor());
442
443 // Turn the data input into a control dependency: this is needed to
444 // ensure that the constant value will only be run in the
445 // cases where the shape/rank/size would have been run in
446 // the original graph.
447 string ctrl_dep =
448 AddControlDependency(node->input(0), graph_, node_map_.get());
449 node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
450 node->set_input(0, ctrl_dep);
451 // Done with the Shape/Size/Rank node, move to the next node.
452 continue;
453 }
454
455 if (op == "TensorArraySizeV3") {
456 const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
457 if (array->input_size() == 0 ||
458 (array->attr().count("dynamic_size") != 0 &&
459 array->attr().at("dynamic_size").b())) {
460 continue;
461 }
462 const NodeDef* array_size =
463 CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
464 if (IsReallyConstant(*array_size)) {
465 // Don't materialize 0 sizes to avoid triggering incorrect static
466 // checks. A 0 sized array that can't grow isn't useful anyway.
467 if (array_size->attr().count("value") == 0) {
468 continue;
469 }
470 const TensorProto& raw_val = array_size->attr().at("value").tensor();
471 if (raw_val.dtype() != DT_INT32) {
472 continue;
473 }
474 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
475 if (!value.FromProto(raw_val)) {
476 continue;
477 }
478 if (value.flat<int32>()(0) == 0) {
479 continue;
480 }
481
482 graph_modified_ = true;
483 node->set_op("Const");
484 *node->mutable_attr() = array_size->attr();
485 node->set_input(0, AsControlDependency(NodeName(node->input(0))));
486 node->set_input(1, AddControlDependency(NodeName(node->input(1)),
487 graph_, node_map_.get()));
488 }
489 continue;
490 }
491
492 // Handle ShapeN materialization case.
493 // It's possible that not all input tensors have known shapes.
494 CHECK_EQ(op, "ShapeN");
495 CHECK_EQ(input.size(), output.size());
496 const NodeDef* const shape_n_node = node;
497 for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
498 ++port_idx) {
499 const DataType type = output[port_idx].dtype();
500 CHECK(type == DT_INT32 || type == DT_INT64);
501 const PartialTensorShape shape(input[port_idx].shape());
502 if (!shape.IsFullyDefined()) {
503 continue;
504 }
505 Tensor constant_value(type);
506 auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
507 if (!status.ok()) {
508 continue;
509 }
510
511 // We make a copy because we mutate the nodes.
512 auto fanouts = node_map_->GetOutputs(shape_n_node->name());
513 // Find all nodes consuming this shape and connect them through the new
514 // constant node instead.
515 for (NodeDef* output : fanouts) {
516 // Track whether there are any direct edges left between shape_n_node
517 // and this output node after the transformation.
518 bool direct_edges_exist = false;
519 for (int k = 0; k < output->input_size(); ++k) {
520 int port;
521 const string node_name = ParseNodeName(output->input(k), &port);
522 if (node_name == shape_n_node->name() && port == port_idx) {
523 // Create a const node as ShapeN's output if not already.
524 const string const_name = OptimizedNodeName(
525 *shape_n_node, strings::StrCat("-matshapes-", port_idx));
526 if (node_map_->GetNode(const_name) == nullptr) {
527 NodeDef* added_node = graph_->add_node();
528 added_node->set_name(const_name);
529 added_node->set_op("Const");
530 added_node->set_device(shape_n_node->device());
531 node_map_->AddNode(added_node->name(), added_node);
532 (*added_node->mutable_attr())["dtype"].set_type(type);
533 constant_value.AsProtoTensorContent(
534 (*added_node->mutable_attr())["value"].mutable_tensor());
535 // We add a control dependency to the original ShapeN node,
536 // so that the node will only be run if all inputs of the
537 // original ShapeN node are run.
538 string ctrl_dep = AddControlDependency(shape_n_node->name(),
539 graph_, node_map_.get());
540 *added_node->add_input() = ctrl_dep;
541 node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
542 }
543 *output->mutable_input(k) = const_name;
544 node_map_->AddOutput(const_name, output->name());
545 graph_modified_ = true;
546 }
547 if (node_name == shape_n_node->name() && port != port_idx) {
548 direct_edges_exist = true;
549 }
550 }
551 if (!direct_edges_exist) {
552 node_map_->RemoveOutput(node->name(), output->name());
553 }
554 }
555 }
556 }
557
558 return Status::OK();
559 }
560
561 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64 * min_id)562 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
563 BCast::Vec* shape, int64* min_id) {
564 if (shape_node.op() == "Shape") {
565 const std::vector<OpInfo::TensorProperties>& prop1 =
566 properties.GetInputProperties(shape_node.name());
567 if (prop1.size() != 1) {
568 return false;
569 }
570 const TensorShapeProto& shp = prop1[0].shape();
571 if (shp.unknown_rank()) {
572 return false;
573 }
574 for (const auto& dim : shp.dim()) {
575 shape->push_back(dim.size());
576 *min_id = std::min<int64>(*min_id, dim.size());
577 }
578 } else {
579 if (shape_node.attr().count("value") == 0) {
580 return false;
581 }
582 const TensorProto& raw_val = shape_node.attr().at("value").tensor();
583 if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
584 return false;
585 }
586 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
587 if (!value.FromProto(raw_val)) {
588 return false;
589 }
590 for (int j = 0; j < value.NumElements(); ++j) {
591 if (raw_val.dtype() == DT_INT64) {
592 shape->push_back(value.vec<int64>()(j));
593 } else {
594 shape->push_back(value.vec<int>()(j));
595 }
596 }
597 }
598 return true;
599 }
600 } // namespace
601
MaterializeBroadcastGradientArgs(const NodeDef & node,const GraphProperties & properties)602 Status ConstantFolding::MaterializeBroadcastGradientArgs(
603 const NodeDef& node, const GraphProperties& properties) {
604 const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
605 const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
606 if (shape_node1 == nullptr ||
607 (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
608 shape_node2 == nullptr ||
609 (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
610 return Status::OK();
611 }
612
613 // Don't optimize this again if it was already optimized and folded.
614 if (OptimizedNodeExists(node, "-folded-1") ||
615 OptimizedNodeExists(node, "-folded-2")) {
616 return Status::OK();
617 }
618 int64 min_id = 0;
619 BCast::Vec shape1;
620 if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
621 return Status::OK();
622 }
623 BCast::Vec shape2;
624 if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
625 return Status::OK();
626 }
627 // A value of -1 means we don't known anything about the dimension. Replace
628 // the -1 values with unique dimension ids since we don't want two '-1'
629 // dimensions to be considered equal.
630 for (auto& id : shape1) {
631 if (id == -1) {
632 id = --min_id;
633 }
634 }
635 for (auto& id : shape2) {
636 if (id == -1) {
637 id = --min_id;
638 }
639 }
640
641 // Beware: the reduction dimensions computed by the BCast class are valid iff
642 // we assume that two distinct symbolic dimensions can't be equal and a
643 // symbolic dimension can't be equal to 1. This is often but not always true,
644 // so to make this optimization safe we filter out these cases.
645 const int common_dims = std::min(shape1.size(), shape2.size());
646 for (int i = 0; i < common_dims; ++i) {
647 if (shape1[i] >= 0 && shape2[i] >= 0) {
648 continue;
649 }
650 if (shape1[i] != shape2[i]) {
651 // We're either dealing with 2 different symbolic dimensions or a symbolic
652 // and a know dimensions. We can't be sure whether both are equal or not,
653 // so we can't be sure whether we'll be broadcasting or not.
654 return Status::OK();
655 }
656 }
657 // These extra dims could be equal to 1, in which case there is no
658 // broadcasting. It could also be greater than 1, in which case there would
659 // be broadcasting. Since we don't know, we'll just punt.
660 for (int i = common_dims, end = shape1.size(); i < end; ++i) {
661 if (shape1[i] < 0) {
662 return Status::OK();
663 }
664 }
665 for (int i = common_dims, end = shape2.size(); i < end; ++i) {
666 if (shape2[i] < 0) {
667 return Status::OK();
668 }
669 }
670
671 BCast bcast(shape1, shape2);
672 if (!bcast.IsValid()) {
673 return Status::OK();
674 }
675
676 BCast::Vec reduce_dims[2];
677 reduce_dims[0] = bcast.grad_x_reduce_idx();
678 reduce_dims[1] = bcast.grad_y_reduce_idx();
679
680 TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
681 const DataType type = node.attr().at("T").type();
682 NodeDef* out[2];
683 for (int j = 0; j < 2; ++j) {
684 int reduction_indices = reduce_dims[j].size();
685 Tensor value(type, TensorShape({reduction_indices}));
686 for (int i = 0; i < reduction_indices; ++i) {
687 if (type == DT_INT32) {
688 value.vec<int32>()(i) = reduce_dims[j][i];
689 } else {
690 value.vec<int64>()(i) = reduce_dims[j][i];
691 }
692 }
693 string const_name =
694 OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
695 out[j] = node_map_->GetNode(const_name);
696 if (out[j] == nullptr) {
697 out[j] = graph_->add_node();
698 TF_RETURN_IF_ERROR(
699 CreateNodeDef(const_name, TensorValue(&value), out[j]));
700 out[j]->set_device(node.device());
701 node_map_->AddNode(const_name, out[j]);
702 string ctrl_dep =
703 AddControlDependency(node.name(), graph_, node_map_.get());
704 *out[j]->add_input() = ctrl_dep;
705 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
706 }
707 }
708
709 // We make a copy here since we might mutate the set.
710 const auto outputs = node_map_->GetOutputs(node.name());
711 for (NodeDef* output : outputs) {
712 for (int k = 0; k < output->input_size(); ++k) {
713 int port;
714 string node_name = ParseNodeName(output->input(k), &port);
715 if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
716 *output->mutable_input(k) = out[port]->name();
717 node_map_->UpdateInput(output->name(), node_name, out[port]->name());
718 }
719 }
720 }
721
722 return Status::OK();
723 }
724
MaterializeReductionIndices(NodeDef * node,const GraphProperties & properties)725 Status ConstantFolding::MaterializeReductionIndices(
726 NodeDef* node, const GraphProperties& properties) {
727 if (node->input_size() < 2) {
728 return Status::OK();
729 }
730 const NodeDef* indices = node_map_->GetNode(node->input(1));
731 if (!indices || IsReallyConstant(*indices)) {
732 // The reduction indices are already constant, there's nothing to do.
733 return Status::OK();
734 }
735
736 const std::vector<OpInfo::TensorProperties>& input_props =
737 properties.GetInputProperties(node->name());
738 if (input_props.size() != 2) {
739 return Status::OK();
740 }
741 const OpInfo::TensorProperties& input_prop = input_props[0];
742 if (input_prop.shape().unknown_rank()) {
743 // We can't do anything if we don't know the rank of the input.
744 return Status::OK();
745 }
746 const int input_rank = input_prop.shape().dim_size();
747 if (input_rank < 1) {
748 // Unexpected graph, don't try to change it.
749 return Status::OK();
750 }
751 const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
752 DataType dtype = reduction_indices_prop.dtype();
753 if (dtype != DT_INT32 && dtype != DT_INT64) {
754 return Status::OK();
755 }
756 PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
757 const int num_reduction_indices = reduction_indices_shape.num_elements();
758
759 const std::vector<OpInfo::TensorProperties>& output_props =
760 properties.GetOutputProperties(node->name());
761 if (output_props.size() != 1) {
762 return Status::OK();
763 }
764 const OpInfo::TensorProperties& output_prop = output_props[0];
765 const int output_rank =
766 output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
767
768 bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
769 if (!full_reduction) {
770 // A full reduction will generate a tensor of one of the shapes
771 // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
772 // elements in the output of the reduction, we may deduce it from reshape
773 // nodes following it.
774 for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
775 full_reduction = false;
776 if (!IsReshape(*fanout)) {
777 return Status::OK();
778 }
779 const std::vector<OpInfo::TensorProperties>& reshape_props =
780 properties.GetOutputProperties(fanout->name());
781 if (reshape_props.size() != 1) {
782 return Status::OK();
783 }
784 const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
785 PartialTensorShape shape(reshape_prop.shape());
786 if (shape.num_elements() != 1) {
787 return Status::OK();
788 } else {
789 full_reduction = true;
790 }
791 }
792 if (!full_reduction) {
793 return Status::OK();
794 }
795 }
796
797 // We know it's a full reduction. We can generate the full set of indices to
798 // reduce as a constant node.
799 string const_name = OptimizedNodeName(*node, "-reduction_indices");
800 if (node_map_->GetNode(const_name)) {
801 return Status::OK();
802 }
803 NodeDef* reduction_indices = graph_->add_node();
804 Tensor value(dtype, TensorShape({input_rank}));
805 for (int i = 0; i < input_rank; ++i) {
806 if (dtype == DT_INT32) {
807 value.vec<int32>()(i) = i;
808 } else {
809 value.vec<int64>()(i) = i;
810 }
811 }
812 TF_RETURN_IF_ERROR(
813 CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
814
815 reduction_indices->set_device(node->device());
816 string ctrl_dep =
817 AddControlDependency(node->input(1), graph_, node_map_.get());
818 *reduction_indices->add_input() = ctrl_dep;
819 node_map_->AddNode(const_name, reduction_indices);
820 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
821
822 node->set_input(1, reduction_indices->name());
823 node_map_->UpdateInput(node->name(), indices->name(),
824 reduction_indices->name());
825
826 return Status::OK();
827 }
828
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)829 Status ConstantFolding::MaterializeConstantValuedNode(
830 NodeDef* node, const GraphProperties& properties) {
831 if (disable_compressed_tensor_optimization_) {
832 return Status::OK();
833 }
834 // Nodes that generate constant-valued outputs can be represented compactly in
835 // compressed format, regardless of their shape.
836 const std::vector<OpInfo::TensorProperties>& output_props =
837 properties.GetOutputProperties(node->name());
838 if (output_props.size() != 1) return Status::OK();
839 const auto& output_shape = output_props[0].shape();
840 if (!PartialTensorShape(output_shape).IsFullyDefined()) {
841 return Status::OK();
842 }
843 if (IsFill(*node)) {
844 const auto output_dtype = output_props[0].dtype();
845 NodeDef* input_node = nullptr;
846 for (int i = 0; i < 2; ++i) {
847 input_node = node_map_->GetNode(NodeName(node->input(i)));
848 if (input_node == nullptr || !IsReallyConstant(*input_node)) {
849 return Status::OK();
850 }
851 }
852 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
853
854 // Copy the input tensor to the fill node, set the output shape and data
855 // type, and change the node type to Const.
856 TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
857 const TensorProto& input_tensor = input_node->attr().at("value").tensor();
858 if (!input_tensor.tensor_content().empty()) {
859 // Convert the value to repeated field format, so we can use the
860 // decompression mechanism to store only a single value in the constant
861 // node, even if the shape specified in the original Fill is large.
862 Tensor t;
863 if (!t.FromProto(input_tensor)) {
864 return errors::InvalidArgument(
865 "Could not construct Tensor form TensorProto in node: ",
866 input_node->name());
867 }
868 tensor->clear_tensor_content();
869 t.AsProtoField(tensor);
870 } else {
871 *tensor = input_tensor;
872 }
873 *(tensor->mutable_tensor_shape()) = output_shape;
874 (*node->mutable_attr())["dtype"].set_type(output_dtype);
875 node->mutable_attr()->erase("T");
876 node->mutable_attr()->erase("index_type");
877 node->set_op("Const");
878 for (int i = 0; i < 2; i++) {
879 // Change inputs to a control inputs.
880 const string ctrl_dep = AsControlDependency(node->input(i));
881 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
882 node->set_input(i, ctrl_dep);
883 }
884 graph_modified_ = true;
885 } else {
886 double value =
887 (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
888 if (value >= 0) {
889 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
890 value, properties, output_shape, node, graph_));
891 }
892 }
893 return Status::OK();
894 }
895
896 // Materialize output values inferred by the shape inference.
MaterializeOutputValues(NodeDef * node,const GraphProperties & properties)897 Status ConstantFolding::MaterializeOutputValues(
898 NodeDef* node, const GraphProperties& properties) {
899 const std::vector<OpInfo::TensorProperties>& output =
900 properties.GetOutputProperties(node->name());
901 if (output.size() != 1 || !output[0].has_value() ||
902 !IsFoldable(*node, &properties)) {
903 return Status::OK();
904 }
905
906 // If this is a trivial Identity node with a constant input, just route the
907 // input around it.
908 if (IsIdentity(*node)) {
909 NodeDef* input = node_map_->GetNode(node->input(0));
910 if (IsReallyConstant(*input)) {
911 std::vector<int> inputs_to_forward;
912 std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
913 graph_modified_ = ForwardInputs(node, inputs_to_forward);
914 return Status::OK();
915 }
916 }
917 // Repurpose the existing node to be the constant.
918 // Device placement is preserved.
919 TensorProto value_copy = output[0].value();
920 return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
921 node, graph_);
922 }
923
MaterializeConstants(const GraphProperties & properties)924 Status ConstantFolding::MaterializeConstants(
925 const GraphProperties& properties) {
926 const int node_count = graph_->node_size();
927 for (int i = 0; i < node_count; ++i) {
928 NodeDef& node = *graph_->mutable_node(i);
929 const string& op = node.op();
930 if (op == "BroadcastGradientArgs") {
931 TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
932 } else if (IsReduction(node)) {
933 TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
934 } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
935 TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
936 } else {
937 TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
938 }
939 }
940 return Status::OK();
941 }
942
IsFoldable(const NodeDef & node,const GraphProperties * properties)943 bool ConstantFolding::IsFoldable(const NodeDef& node,
944 const GraphProperties* properties) {
945 string key = strings::StrCat(node.name(), "/", node.op());
946 auto it = maybe_foldable_nodes_.find(key);
947 if (it == maybe_foldable_nodes_.end()) {
948 it = maybe_foldable_nodes_
949 .emplace(std::move(key), MaybeFoldable(node, properties))
950 .first;
951 }
952 if (!it->second) {
953 return false;
954 } else {
955 return IsFoldableUncached(node, properties);
956 }
957 }
958
IsFoldableUncached(const NodeDef & node,const GraphProperties * properties) const959 bool ConstantFolding::IsFoldableUncached(
960 const NodeDef& node, const GraphProperties* properties) const {
961 // Folding not applicable to ops with no inputs.
962 if (node.input().empty()) {
963 return false;
964 }
965 // We can only fold nodes if all their inputs are known statically, except in
966 // the case of a merge node that propagate the first inputs that becomes
967 // available, and therefore only requires a single constant input to be
968 // foldable.
969 bool merge_has_constant_input = false;
970 const bool is_merge = IsMerge(node);
971 for (const auto& input : node.input()) {
972 if (IsControlInput(input)) {
973 continue;
974 }
975 const NodeDef* input_node = node_map_->GetNode(input);
976 if (!input_node) {
977 return false;
978 }
979 bool is_const = IsReallyConstant(*input_node);
980 if (is_const) {
981 // Don't fold strings constants for now since this causes problems with
982 // checkpointing.
983 if (input_node->attr().count("dtype") == 0 ||
984 input_node->attr().at("dtype").type() == DT_STRING) {
985 return false;
986 }
987 // Special case: If a Merge node has at least one constant input that
988 // does not depend on a control input, we can fold it.
989 merge_has_constant_input |= !HasControlInputs(*input_node);
990 } else if (!is_merge) {
991 return false;
992 }
993 }
994 if (is_merge && !merge_has_constant_input) return false;
995 if (disable_compressed_tensor_optimization_ &&
996 (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
997 return false;
998
999 // If we know the output shapes, make sure that the outputs are small enough
1000 // to materialize.
1001 if (properties != nullptr && properties->HasOutputProperties(node.name())) {
1002 const std::vector<OpInfo::TensorProperties>& input_props =
1003 properties->GetInputProperties(node.name());
1004 const std::vector<OpInfo::TensorProperties>& output_props =
1005 properties->GetOutputProperties(node.name());
1006 // Compute total size of inputs.
1007 int64 input_size_bytes = 0;
1008 for (const auto& input_prop : input_props) {
1009 const PartialTensorShape input_shape(input_prop.shape());
1010 if (input_shape.IsFullyDefined()) {
1011 input_size_bytes +=
1012 input_shape.num_elements() * DataTypeSize(input_prop.dtype());
1013 }
1014 }
1015 for (const auto& output_prop : output_props) {
1016 const PartialTensorShape output_shape(output_prop.shape());
1017 if (output_shape.IsFullyDefined()) {
1018 const int64 num_bytes =
1019 output_shape.num_elements() * DataTypeSize(output_prop.dtype());
1020 if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
1021 // Do not fold nodes if the in-memory size of output is too large.
1022 // Notice that this is not exactly the same check used in
1023 // CreateNodeDef() where the actual encoded size is checked.
1024 return false;
1025 }
1026 }
1027 }
1028 }
1029
1030 return true;
1031 }
1032
MaybeFoldable(const NodeDef & node,const GraphProperties * properties) const1033 bool ConstantFolding::MaybeFoldable(const NodeDef& node,
1034 const GraphProperties* properties) const {
1035 // Skip constants, they're already folded
1036 if (IsConstant(node)) {
1037 return false;
1038 }
1039 // Don't fold stateful ops such as TruncatedNormal.
1040 if (!IsFreeOfSideEffect(node)) {
1041 return false;
1042 }
1043
1044 // Skips nodes that must be preserved except allowlisted nodes.
1045 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
1046 nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1047 return false;
1048 }
1049
1050 // Skip control flow nodes, they can't be folded.
1051 if (ModifiesFrameInfo(node)) {
1052 return false;
1053 }
1054
1055 // Skips ops that don't benefit from folding.
1056 if (IsPlaceholder(node)) {
1057 return false;
1058 }
1059 // `FakeParam` op is used as a placeholder in If branch function. It doesn't
1060 // have a valid output when executed.
1061 if (IsFakeParam(node)) {
1062 return false;
1063 }
1064
1065 if (node.op() == "AccumulateNV2") {
1066 return false;
1067 }
1068 // Removing LoopCond nodes can screw up the partitioner.
1069 if (node.op() == "LoopCond") {
1070 return false;
1071 }
1072
1073 if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
1074 return false;
1075 }
1076
1077 const string& op = node.op();
1078 if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
1079 op.find("Reader") != string::npos) {
1080 return false;
1081 }
1082 if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
1083 return false;
1084 }
1085
1086 // Don't fold nodes that contain TPU attributes.
1087 // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
1088 // properly forward custom attributes, b/119051778.
1089 if (HasTPUAttributes(node)) {
1090 return false;
1091 }
1092
1093 const OpDef* op_def = nullptr;
1094 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
1095 if (!status.ok()) {
1096 return false;
1097 }
1098 // Don't fold ops without outputs.
1099 if (op_def->output_arg_size() == 0) {
1100 return false;
1101 }
1102 // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
1103 // TODO(rmlarsen): Only do this for XLA_* devices.
1104 for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
1105 if (output_arg.type() == DT_VARIANT) {
1106 return false;
1107 }
1108 }
1109
1110 // Don't fold nodes that have no outgoing edges except allowlisted nodes.
1111 // Such nodes could be introduced by an earlier constant folding pass and are
1112 // preserved in case users want to fetch their values; re-processing them
1113 // would lead to an error of adding a duplicated node to graph.
1114 const auto& outputs = node_map_->GetOutputs(node.name());
1115 if (outputs.empty() &&
1116 nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1117 return false;
1118 }
1119 return true;
1120 }
1121
1122 namespace {
1123
1124 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
1125 case DTYPE: \
1126 t->add_##NAME##_val(static_cast<TYPE>(value)); \
1127 break;
1128
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)1129 Status CreateConstantTensorAttrValue(DataType type, double value,
1130 const TensorShapeProto& shape,
1131 AttrValue* attr_tensor) {
1132 TensorProto* t = attr_tensor->mutable_tensor();
1133 t->set_dtype(type);
1134 *t->mutable_tensor_shape() = shape;
1135 switch (type) {
1136 case DT_HALF:
1137 t->add_half_val(static_cast<Eigen::half>(value).x);
1138 break;
1139 case DT_BFLOAT16:
1140 t->add_half_val(static_cast<bfloat16>(value).value);
1141 break;
1142 SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
1143 SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
1144 SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
1145 SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
1146 SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
1147 SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
1148 SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
1149 SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
1150 SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
1151 SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
1152 SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
1153 SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
1154 SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
1155 SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
1156 SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1157 SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1158 default:
1159 return errors::InvalidArgument(
1160 "Unsupported type in CreateConstantTensorAttrValue: ",
1161 DataTypeString(type));
1162 }
1163 return Status::OK();
1164 }
1165
1166 #undef SET_TENSOR_CAL_CASE
1167
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1168 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1169 const GraphProperties& properties) {
1170 DataType dtype = DT_INVALID;
1171 if (node.attr().count("T") == 1) {
1172 dtype = node.attr().at("T").type();
1173 } else if (node.attr().count("dtype") == 1) {
1174 dtype = node.attr().at("dtype").type();
1175 } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1176 dtype = DT_BOOL;
1177 } else {
1178 auto output_props = properties.GetOutputProperties(node.name());
1179 if (!output_props.empty()) {
1180 dtype = output_props[0].dtype();
1181 }
1182 }
1183 return dtype;
1184 }
1185
1186 // Checks whether the shape of the const input of the Mul op is valid to perform
1187 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1188 bool IsValidConstShapeForMulConvPushDown(
1189 const string& data_format, const TensorShapeProto& filter_shape,
1190 const TensorShapeProto& mul_const_input_shape) {
1191 // If the const is a scalar, or it has fewer or same number of dimensions
1192 // than the filter and it only has single element, the optimization should
1193 // work.
1194 if (mul_const_input_shape.dim_size() <=
1195 static_cast<int>(data_format.size()) &&
1196 TensorShape(mul_const_input_shape).num_elements() == 1) {
1197 return true;
1198 }
1199
1200 // Otherwise, check the eligibility according to data format.
1201 if (data_format == "NHWC" || data_format == "NDHWC") {
1202 TensorShapeProto new_filter_shape;
1203 if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1204 &new_filter_shape)) {
1205 return false;
1206 }
1207 if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1208 return false;
1209 }
1210 // Only the last dimension could be larger than one, since broadcasting over
1211 // the last dimension (the output channel) will result in invalid filter.
1212 for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1213 if (mul_const_input_shape.dim(i).size() > 1) return false;
1214 }
1215 return true;
1216 } else if (data_format == "NCHW" || data_format == "NCDHW") {
1217 // TODO(laigd): support NCHW and NCDHW (b/111214513).
1218 return false;
1219 }
1220 return false;
1221 }
1222
1223 } // namespace
1224
1225 // static
CreateNodeDef(const string & name,const TensorValue & tensor,NodeDef * node,size_t original_size)1226 Status ConstantFolding::CreateNodeDef(const string& name,
1227 const TensorValue& tensor, NodeDef* node,
1228 size_t original_size) {
1229 node->set_name(name);
1230 node->set_op("Const");
1231
1232 AttrValue attr_type;
1233 attr_type.set_type(tensor->dtype());
1234 node->mutable_attr()->insert({"dtype", attr_type});
1235
1236 AttrValue attr_tensor;
1237 TensorProto* t = attr_tensor.mutable_tensor();
1238 bool optimized = false;
1239 size_t encoded_size;
1240 // Use the packed representation whenever possible to avoid generating large
1241 // graphdefs. Moreover, avoid repeating the last values if they're equal.
1242 if (tensor->NumElements() > 4) {
1243 #define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE) \
1244 { \
1245 const auto* val_ptr = tensor->flat<TYPE>().data(); \
1246 auto last = *val_ptr; \
1247 int64 last_index = 0; \
1248 for (int64 i = 0; i < tensor->NumElements(); ++i) { \
1249 TYPE cur = *val_ptr++; \
1250 if (PackedValuesNotEqual(cur, last)) { \
1251 last = cur; \
1252 last_index = i; \
1253 } \
1254 } \
1255 encoded_size = (last_index + 1) * sizeof(FIELDTYPE); \
1256 if (encoded_size < kint32max) { \
1257 optimized = true; \
1258 t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1); \
1259 const auto* src_ptr = tensor->flat<TYPE>().data(); \
1260 auto* dst_ptr = \
1261 t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
1262 std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr); \
1263 } \
1264 } \
1265 break
1266
1267 switch (tensor->dtype()) {
1268 case DT_FLOAT:
1269 POPULATE_TENSOR_PROTO(tensor, t, float, float);
1270 case DT_DOUBLE:
1271 POPULATE_TENSOR_PROTO(tensor, t, double, double);
1272 case DT_INT64:
1273 POPULATE_TENSOR_PROTO(tensor, t, int64, int64);
1274 case DT_UINT64:
1275 POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
1276 case DT_INT32:
1277 POPULATE_TENSOR_PROTO(tensor, t, int32, int);
1278 case DT_UINT32:
1279 POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
1280 case DT_INT16:
1281 POPULATE_TENSOR_PROTO(tensor, t, int16, int);
1282 case DT_UINT16:
1283 POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
1284 case DT_INT8:
1285 POPULATE_TENSOR_PROTO(tensor, t, int8, int);
1286 case DT_UINT8:
1287 POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
1288 case DT_BOOL:
1289 POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
1290 default:
1291 /* Do nothing. */
1292 break;
1293 }
1294 }
1295 if (optimized) {
1296 // Also specify type and shape.
1297 t->set_dtype(tensor->dtype());
1298 tensor->shape().AsProto(t->mutable_tensor_shape());
1299 } else {
1300 // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
1301 // DT_QUINT8
1302 tensor->AsProtoTensorContent(t);
1303 encoded_size = t->tensor_content().size();
1304 }
1305 node->mutable_attr()->insert({"value", attr_tensor});
1306
1307 if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
1308 return errors::InvalidArgument(
1309 strings::StrCat("Can't fold ", name, ", its size would be too large (",
1310 encoded_size, " >= ", kMaxConstantSize, " bytes)"));
1311 }
1312 return Status::OK();
1313 }
1314
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1315 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1316 const TensorVector& inputs,
1317 TensorVector* output) const {
1318 return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1319 resource_mgr_.get(), output);
1320 }
1321
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1322 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1323 std::vector<NodeDef>* outputs,
1324 bool* result_too_large) {
1325 TensorVector inputs;
1326 TensorVector output_tensors;
1327 auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1328 for (const auto& input : inputs) {
1329 delete input.tensor;
1330 }
1331 for (const auto& output : output_tensors) {
1332 if (output.tensor) {
1333 delete output.tensor;
1334 }
1335 }
1336 });
1337
1338 size_t total_inputs_size = 0;
1339 for (const auto& input : node.input()) {
1340 const TensorId input_tensor = ParseTensorName(input);
1341 if (input_tensor.index() < 0) {
1342 // Control dependency
1343 break;
1344 }
1345 const NodeDef* input_node = node_map_->GetNode(input);
1346 if (!IsReallyConstant(*input_node)) {
1347 return Status(error::INVALID_ARGUMENT,
1348 strings::StrCat("Can't fold ", node.name(), ", its ", input,
1349 " isn't constant"));
1350 }
1351 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1352 const TensorProto& raw_val = input_node->attr().at("value").tensor();
1353 Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1354 CHECK(value->FromProto(raw_val))
1355 << "Unable to make Tensor from proto for " << node.name()
1356 << " with shape " << raw_val.tensor_shape().DebugString();
1357 inputs.emplace_back(value);
1358 total_inputs_size += value->TotalBytes();
1359 }
1360
1361 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1362 if (output_tensors.empty()) {
1363 return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1364 }
1365
1366 outputs->resize(output_tensors.size());
1367 for (size_t i = 0; i < output_tensors.size(); i++) {
1368 string node_name = OptimizedNodeName(node, "-folded");
1369 if (output_tensors.size() > 1) {
1370 node_name = strings::StrCat(node_name, "-", i);
1371 }
1372 if (output_tensors[i].tensor) {
1373 Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1374 total_inputs_size);
1375 if (!s.ok()) {
1376 *result_too_large = true;
1377 return s;
1378 }
1379 } else {
1380 // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1381 // switch that's not selected by the switch predicate).
1382 outputs->at(i) = NodeDef();
1383 }
1384 }
1385 return Status::OK();
1386 }
1387
FoldMergeNode(NodeDef * node,GraphDef * output_graph)1388 Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
1389 // Merge nodes are special, in the sense that they execute as soon as one of
1390 // their input is ready. We can therefore fold a merge node iff it has at
1391 // least one constant input without control dependency.
1392 // We still need to ensure that the nodes in the fanin of the merge node are
1393 // scheduled. We'll therefore add a control dependency from the merge node
1394 // to the folded constant. We end up with:
1395 // * the merge node and its inputs are preserved as is
1396 // * a new constant node C1, driven by the merge node through a control
1397 // dependency, initialized to the value of the folded input
1398 // * a new constant node C2, driven by the merge node through a control
1399 // dependency, initialized to the index of the folded input
1400 // * the fanout of the merge nodes is rewired to be driven by either C1 or
1401 // C2.
1402 for (int input_index = 0; input_index < node->input_size(); ++input_index) {
1403 const auto& input = node->input(input_index);
1404 if (IsControlInput(input)) {
1405 // Try the next input.
1406 continue;
1407 }
1408 NodeDef* input_node = node_map_->GetNode(input);
1409 if (!IsReallyConstant(*input_node)) {
1410 continue;
1411 }
1412 bool valid_input = true;
1413 for (const string& fanin_of_input : input_node->input()) {
1414 if (IsControlInput(fanin_of_input)) {
1415 valid_input = false;
1416 break;
1417 }
1418 }
1419 if (!valid_input) {
1420 // Try the next input
1421 continue;
1422 }
1423
1424 string const_out_name = OptimizedNodeName(*node, "_const");
1425 string const_index_name = OptimizedNodeName(*node, "_index");
1426 if (node_map_->GetNode(const_out_name) ||
1427 node_map_->GetNode(const_index_name)) {
1428 // Intended name already exists.
1429 return errors::AlreadyExists(
1430 strings::StrCat(const_out_name, " or ", const_index_name,
1431 " already present in the graph"));
1432 }
1433
1434 NodeDef* const_out = output_graph->add_node();
1435 *const_out = *input_node;
1436 const_out->set_name(const_out_name);
1437 const_out->set_device(node->device());
1438 *const_out->add_input() = AsControlDependency(*node);
1439 node_map_->AddNode(const_out->name(), const_out);
1440 node_map_->AddOutput(node->name(), const_out->name());
1441
1442 NodeDef* const_index = output_graph->add_node();
1443 const_index->set_op("Const");
1444 Tensor index(DT_INT32, TensorShape({}));
1445 index.flat<int32>()(0) = input_index;
1446 (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
1447 index.AsProtoTensorContent(
1448 (*const_index->mutable_attr())["value"].mutable_tensor());
1449 const_index->set_name(const_index_name);
1450 const_index->set_device(node->device());
1451 *const_index->add_input() = AsControlDependency(*node);
1452 node_map_->AddNode(const_index->name(), const_index);
1453 node_map_->AddOutput(node->name(), const_index->name());
1454
1455 // We make a copy because we mutate the nodes.
1456 auto outputs = node_map_->GetOutputs(node->name());
1457 for (NodeDef* output : outputs) {
1458 for (int i = 0; i < output->input_size(); i++) {
1459 int port;
1460 string node_name = ParseNodeName(output->input(i), &port);
1461 if (node_name == node->name()) {
1462 if (port == 0) {
1463 *output->mutable_input(i) = const_out->name();
1464 node_map_->AddOutput(const_out->name(), output->name());
1465 } else if (port == 1) {
1466 *output->mutable_input(i) = const_index->name();
1467 node_map_->AddOutput(const_index->name(), output->name());
1468 } else {
1469 // This is a control dependency (or an invalid edge since the
1470 // merge node has only 2 outputs): preserve them.
1471 }
1472 }
1473 }
1474 }
1475 return Status::OK();
1476 }
1477 return Status::OK();
1478 }
1479
FoldNode(NodeDef * node,GraphDef * output_graph,bool * result_too_large)1480 Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
1481 bool* result_too_large) {
1482 *result_too_large = false;
1483 if (IsMerge(*node)) {
1484 return FoldMergeNode(node, output_graph);
1485 }
1486
1487 std::vector<NodeDef> const_nodes;
1488 TF_RETURN_IF_ERROR(
1489 EvaluateOneFoldable(*node, &const_nodes, result_too_large));
1490 VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);
1491
1492 NodeDef* constant_output = nullptr;
1493 for (int i = 0, end = const_nodes.size(); i < end; i++) {
1494 NodeDef* const_node = &const_nodes[i];
1495 VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
1496 if (const_node->name().empty()) {
1497 // Dead output: we can't create a constant to encode its value, so we'll
1498 // just skip it. We'll preserve the edges that originate from that
1499 // output below to preserve the overall behavior of the graph wrt dead
1500 // edges.
1501 continue;
1502 }
1503
1504 // Returns `true` iff `const_node` already has control input named `input`.
1505 const auto is_duplicate_control_input = [&](const string& input) -> bool {
1506 auto it = absl::c_find(const_node->input(), input);
1507 return it != const_node->input().end();
1508 };
1509
1510 // Forward control dependencies.
1511 for (const string& input : node->input()) {
1512 // Forward control dependencies from folded node.
1513 if (IsControlInput(input)) {
1514 if (!is_duplicate_control_input(input)) {
1515 *const_node->add_input() = input;
1516 }
1517 }
1518
1519 // Forward control dependencies from constant inputs to folded node.
1520 if (!IsControlInput(input)) {
1521 NodeDef* input_node = node_map_->GetNode(input);
1522 for (const string& fanin_of_input : input_node->input()) {
1523 if (!is_duplicate_control_input(fanin_of_input)) {
1524 *const_node->add_input() = fanin_of_input;
1525 }
1526 }
1527 }
1528 }
1529
1530 // We rewrite the existing node if it only has a single output, and
1531 // create new nodes otherwise.
1532 if (const_nodes.size() == 1) {
1533 node->set_op("Const");
1534 // Note we need to clear the inputs in NodeMap before we clear the inputs
1535 // in the node, otherwise NodeMap would see empty inputs and effectively
1536 // does nothing.
1537 node_map_->RemoveInputs(node->name());
1538 node->clear_input();
1539 *node->mutable_input() = const_node->input();
1540 for (const auto& input : node->input()) {
1541 node_map_->AddOutput(NodeName(input), node->name());
1542 }
1543 *node->mutable_attr() = const_node->attr();
1544 break;
1545 } else {
1546 if (node_map_->GetNode(const_node->name())) {
1547 // Intended name already exists.
1548 return errors::AlreadyExists(strings::StrCat(
1549 const_node->name(), " already present in the graph"));
1550 }
1551 NodeDef* added_node = output_graph->add_node();
1552 *added_node = *const_node;
1553 added_node->set_device(node->device());
1554 node_map_->AddNode(added_node->name(), added_node);
1555 for (const auto& input : added_node->input()) {
1556 node_map_->AddOutput(NodeName(input), added_node->name());
1557 }
1558 // All the constant nodes encoding output values have the same control
1559 // dependencies (since these are the control dependencies of the node
1560 // we're trying to fold). Record one such constant node.
1561 constant_output = added_node;
1562 }
1563 }
1564
1565 if (const_nodes.size() > 1) {
1566 // We make a copy because we mutate the nodes.
1567 auto outputs = node_map_->GetOutputs(node->name());
1568 for (NodeDef* output : outputs) {
1569 for (int i = 0; i < output->input_size(); i++) {
1570 int port;
1571 string node_name = ParseNodeName(output->input(i), &port);
1572 if (node_name == node->name()) {
1573 if (port < 0) {
1574 // Propagate control dependencies if possible. If not, we'll just
1575 // preserve the existing control dependencies.
1576 if (constant_output != nullptr) {
1577 node_map_->UpdateInput(node_name, NodeName(output->input(i)),
1578 constant_output->name());
1579 *output->mutable_input(i) = AsControlDependency(*constant_output);
1580 }
1581 } else if (port < static_cast<int>(const_nodes.size()) &&
1582 !const_nodes[port].name().empty()) {
1583 // Replace alive outputs with the corresponding constant.
1584 node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
1585 const_nodes[port].name());
1586 *output->mutable_input(i) = const_nodes[port].name();
1587 } else {
1588 // Leave this edge alone.
1589 VLOG(3) << "Preserving edge from " << node->name() << ":" << port
1590 << "[" << node->op() << "] to " << output->name() << ":"
1591 << i << "[" << output->op() << "]";
1592 }
1593 }
1594 }
1595 }
1596 outputs = node_map_->GetOutputs(node->name());
1597 if (outputs.empty() && has_fetch_ &&
1598 nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
1599 node_map_->RemoveInputs(node->name());
1600 node->clear_input();
1601 }
1602 }
1603 return Status::OK();
1604 }
1605
FoldGraph(const GraphProperties & properties,GraphDef * output,absl::flat_hash_set<string> * nodes_to_not_simplify)1606 Status ConstantFolding::FoldGraph(
1607 const GraphProperties& properties, GraphDef* output,
1608 absl::flat_hash_set<string>* nodes_to_not_simplify) {
1609 std::unordered_set<string> processed_nodes;
1610 std::deque<NodeDef*> queue;
1611 for (int i = 0; i < graph_->node_size(); i++) {
1612 bool foldable = IsFoldable(graph_->node(i), &properties);
1613 VLOG(2) << "foldable(" << graph_->node(i).name() << ") = " << foldable;
1614 if (foldable) {
1615 queue.push_back(graph_->mutable_node(i));
1616 }
1617 }
1618 while (!queue.empty()) {
1619 NodeDef* node = queue.front();
1620 queue.pop_front();
1621 if (processed_nodes.count(node->name())) {
1622 continue;
1623 }
1624 // We need to record a copy of output nodes before FoldNode() modifies it.
1625 // We also need to ensure that the fanout is sorted deterministically.
1626 std::vector<NodeDef*> fanout =
1627 node_map_->GetOutputsOrderedByNodeName(node->name());
1628 bool result_too_large = false;
1629 Status s = FoldNode(node, output, &result_too_large);
1630 processed_nodes.insert(node->name());
1631 if (!s.ok()) {
1632 VLOG(1) << "Failed to fold node " << node->DebugString()
1633 << "\nError message: " << s;
1634 if (result_too_large) {
1635 nodes_to_not_simplify->emplace(node->name());
1636 }
1637 } else {
1638 for (auto& output : fanout) {
1639 if (IsFoldable(*output, &properties)) {
1640 queue.push_back(output);
1641 }
1642 }
1643 }
1644 }
1645
1646 // Delete the newly created nodes that don't feed anything.
1647 std::vector<int> nodes_to_delete;
1648 for (int i = 0; i < output->node_size(); i++) {
1649 const auto& fanout = node_map_->GetOutputs(output->node(i).name());
1650 if (fanout.empty()) nodes_to_delete.push_back(i);
1651 }
1652 EraseNodesFromGraph(std::move(nodes_to_delete), output);
1653
1654 for (int i = 0; i < graph_->node_size(); ++i) {
1655 NodeDef* node = graph_->mutable_node(i);
1656 // If no fetch nodes is provided, we conservatively
1657 // move all nodes in the original graph to the output, in case users need
1658 // to fetch their values.
1659 const auto& fanout = node_map_->GetOutputs(node->name());
1660 if (!fanout.empty() || !has_fetch_ ||
1661 nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
1662 *(output->add_node()) = std::move(*node);
1663 }
1664 }
1665 return Status::OK();
1666 }
1667
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1668 bool ConstantFolding::IsSimplifiableReshape(
1669 const NodeDef& node, const GraphProperties& properties) const {
1670 if (!IsReshape(node)) {
1671 return false;
1672 }
1673 CHECK_LE(2, node.input_size());
1674 const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1675 if (!IsReallyConstant(*new_shape)) {
1676 return false;
1677 }
1678 TensorVector outputs;
1679 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1680 for (const auto& output : outputs) {
1681 delete output.tensor;
1682 }
1683 });
1684
1685 Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1686 if (!s.ok()) {
1687 return false;
1688 }
1689 CHECK_EQ(1, outputs.size());
1690
1691 const std::vector<OpInfo::TensorProperties>& props =
1692 properties.GetInputProperties(node.name());
1693 if (props.empty()) {
1694 return false;
1695 }
1696 const OpInfo::TensorProperties& prop = props[0];
1697 if (prop.dtype() == DT_INVALID) {
1698 return false;
1699 }
1700 const PartialTensorShape shape(prop.shape());
1701 if (!shape.IsFullyDefined()) {
1702 return false;
1703 }
1704
1705 PartialTensorShape new_dims;
1706 if (outputs[0]->dtype() == DT_INT32) {
1707 std::vector<int32> shp;
1708 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1709 int32 dim = outputs[0]->flat<int32>()(i);
1710 shp.push_back(dim);
1711 }
1712 TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1713 } else {
1714 std::vector<int64> shp;
1715 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1716 int64 dim = outputs[0]->flat<int64>()(i);
1717 shp.push_back(dim);
1718 }
1719 TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1720 }
1721
1722 return shape.IsCompatibleWith(new_dims);
1723 }
1724
// Expands to a `case DTYPE:` label that returns whether every element of the
// "value" tensor attribute of the `node` in scope equals VALUE converted to
// the C++ type that EnumToDataType maps DTYPE to. It is meant to be used
// inside a `switch` over a dtype in a function returning bool, and it reads
// `node.attr().at("value")`, so the attribute must exist.
#define IS_VALUE_CASE(DTYPE, VALUE)                   \
  case DTYPE:                                         \
    return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
        node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))

// Shorthands for IS_VALUE_CASE testing against the constants 1 and 0.
#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
1732
IsOnes(const NodeDef & node) const1733 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1734 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1735 return false;
1736 }
1737 if (IsOnesLike(node)) return true;
1738 if (IsZerosLike(node)) return false;
1739 if (node.op() == "Fill") {
1740 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1741 return values != nullptr && IsOnes(*values);
1742 }
1743 if (node.op() != "Const") return false;
1744 if (node.attr().count("dtype") == 0) return false;
1745 const auto dtype = node.attr().at("dtype").type();
1746 switch (dtype) {
1747 IS_ONES_CASE(DT_BOOL);
1748 IS_ONES_CASE(DT_HALF);
1749 IS_ONES_CASE(DT_BFLOAT16);
1750 IS_ONES_CASE(DT_FLOAT);
1751 IS_ONES_CASE(DT_DOUBLE);
1752 IS_ONES_CASE(DT_COMPLEX64);
1753 IS_ONES_CASE(DT_COMPLEX128);
1754 IS_ONES_CASE(DT_UINT8);
1755 IS_ONES_CASE(DT_INT8);
1756 IS_ONES_CASE(DT_UINT16);
1757 IS_ONES_CASE(DT_INT16);
1758 IS_ONES_CASE(DT_INT32);
1759 IS_ONES_CASE(DT_INT64);
1760 IS_ONES_CASE(DT_QINT32);
1761 IS_ONES_CASE(DT_QINT16);
1762 IS_ONES_CASE(DT_QUINT16);
1763 IS_ONES_CASE(DT_QINT8);
1764 IS_ONES_CASE(DT_QUINT8);
1765 default:
1766 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1767 return false;
1768 }
1769 return false;
1770 }
1771
IsZeros(const NodeDef & node) const1772 bool ConstantFolding::IsZeros(const NodeDef& node) const {
1773 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1774 return false;
1775 }
1776 if (IsOnesLike(node)) return false;
1777 if (IsZerosLike(node)) return true;
1778 if (node.op() == "Fill") {
1779 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1780 return values != nullptr && IsZeros(*values);
1781 }
1782 if (!IsConstant(node)) return false;
1783 if (node.attr().count("dtype") == 0) return false;
1784 const auto dtype = node.attr().at("dtype").type();
1785 switch (dtype) {
1786 IS_ZEROS_CASE(DT_BOOL);
1787 IS_ZEROS_CASE(DT_HALF);
1788 IS_ZEROS_CASE(DT_BFLOAT16);
1789 IS_ZEROS_CASE(DT_FLOAT);
1790 IS_ZEROS_CASE(DT_DOUBLE);
1791 IS_ZEROS_CASE(DT_COMPLEX64);
1792 IS_ZEROS_CASE(DT_COMPLEX128);
1793 IS_ZEROS_CASE(DT_UINT8);
1794 IS_ZEROS_CASE(DT_INT8);
1795 IS_ZEROS_CASE(DT_UINT16);
1796 IS_ZEROS_CASE(DT_INT16);
1797 IS_ZEROS_CASE(DT_INT32);
1798 IS_ZEROS_CASE(DT_INT64);
1799 IS_ZEROS_CASE(DT_QINT32);
1800 IS_ZEROS_CASE(DT_QINT16);
1801 IS_ZEROS_CASE(DT_QUINT16);
1802 IS_ZEROS_CASE(DT_QINT8);
1803 IS_ZEROS_CASE(DT_QUINT8);
1804 default:
1805 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1806 return false;
1807 }
1808 return false;
1809 }
1810
ReplaceOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1811 bool ConstantFolding::ReplaceOperationWithBroadcastTo(
1812 int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1813 GraphDef* graph) {
1814 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1815 if (dtype == DT_INVALID) {
1816 return false;
1817 }
1818 const PartialTensorShape shape(
1819 properties.GetOutputProperties(node->name())[0].shape());
1820 if (!shape.IsFullyDefined()) {
1821 return false;
1822 }
1823 // Create constant node with shape.
1824 const string const_name = OptimizedNodeName(
1825 *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
1826 if (node_map_->GetNode(const_name) != nullptr) {
1827 return false;
1828 }
1829
1830 Tensor shape_t;
1831 if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
1832 return false;
1833 }
1834 NodeDef tmp;
1835 if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
1836 return false;
1837 }
1838 NodeDef* const_node = graph->add_node();
1839 const_node->Swap(&tmp);
1840 const_node->set_device(node->device());
1841 node_map_->AddNode(const_name, const_node);
1842 for (int i = 0; i < node->input_size(); ++i) {
1843 if (i != input_to_broadcast) {
1844 // Add a control input on the unused input.
1845 string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
1846 node_map_.get());
1847 *const_node->add_input() = ctrl_dep;
1848 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
1849 }
1850 }
1851
1852 // Rewrite `node` in-place to BroadcastTo.
1853 node->set_op("BroadcastTo");
1854 EraseRegularNodeAttributes(node);
1855 (*node->mutable_attr())["T"].set_type(dtype);
1856 (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
1857 // Set the designated input to BroadcastTo.
1858 node->mutable_input()->SwapElements(0, input_to_broadcast);
1859 // Keep all other inputs as control dependencies.
1860 for (int i = 1; i < node->input_size(); ++i) {
1861 if (IsControlInput(node->input(i))) {
1862 break;
1863 }
1864 const string ctrl_dep =
1865 AddControlDependency(node->input(i), graph, node_map_.get());
1866 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1867 node->set_input(i, ctrl_dep);
1868 }
1869 // Add the shape argument.
1870 *node->add_input() = const_node->name();
1871 node_map_->AddOutput(const_name, node->name());
1872 node->mutable_input()->SwapElements(1, node->input_size() - 1);
1873 return true;
1874 }
1875
1876 // Replace an operation with Identity.
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1877 void ConstantFolding::ReplaceOperationWithIdentity(
1878 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1879 GraphDef* graph) {
1880 if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
1881 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1882 if (dtype == DT_INVALID) return;
1883
1884 node->set_op("Identity");
1885 EraseRegularNodeAttributes(node);
1886 (*node->mutable_attr())["T"].set_type(dtype);
1887 // Propagate the designated input through the identity.
1888 node->mutable_input()->SwapElements(0, input_to_forward);
1889 // Add all other inputs as control dependencies.
1890 for (int i = 1; i < node->input_size(); ++i) {
1891 if (IsControlInput(node->input(i))) {
1892 break;
1893 }
1894 const string ctrl_dep =
1895 AddControlDependency(node->input(i), graph, node_map_.get());
1896 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1897 node->set_input(i, ctrl_dep);
1898 }
1899 graph_modified_ = true;
1900 }
1901
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1902 void ConstantFolding::ReplaceOperationWithSnapshot(
1903 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1904 GraphDef* graph) {
1905 // If the graph contains no ops that mutate their inputs, we can
1906 // use Identity instead of Snapshot.
1907 if (!graph_contains_assign_or_inplace_op_) {
1908 ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1909 return;
1910 }
1911
1912 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1913 if (dtype == DT_INVALID) return;
1914
1915 node->set_op("Snapshot");
1916 EraseRegularNodeAttributes(node);
1917 (*node->mutable_attr())["T"].set_type(dtype);
1918 // Propagate the designated input through the Snapshot.
1919 node->mutable_input()->SwapElements(0, input_to_forward);
1920 // Add all other inputs as control dependencies.
1921 for (int i = 1; i < node->input_size(); ++i) {
1922 if (IsControlInput(node->input(i))) {
1923 break;
1924 }
1925 const string ctrl_dep =
1926 AddControlDependency(node->input(i), graph, node_map_.get());
1927 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1928 node->set_input(i, ctrl_dep);
1929 }
1930 graph_modified_ = true;
1931 }
1932
1933 // Replace a node with NoOp. Change all inputs to control dependencies.
1934 // If the node has non-control outputs, no change will be performed.
ReplaceOperationWithNoOp(NodeDef * node,GraphProperties * properties,GraphDef * graph)1935 void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
1936 GraphProperties* properties,
1937 GraphDef* graph) {
1938 if (HasRegularOutputs(*node, *node_map_)) return;
1939 node->set_op("NoOp");
1940 EraseRegularNodeAttributes(node);
1941 EraseNodeOutputAttributes(node);
1942 // Erase attributes that describe output properties.
1943 properties->ClearOutputProperties(node->name());
1944 // Change all inputs to control dependencies.
1945 for (int i = 0; i < node->input_size(); ++i) {
1946 if (IsControlInput(node->input(i))) {
1947 break;
1948 }
1949 const string ctrl_dep =
1950 AddControlDependency(node->input(i), graph, node_map_.get());
1951 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1952 node->set_input(i, ctrl_dep);
1953 }
1954 DedupControlInputs(node);
1955 graph_modified_ = true;
1956 }
1957
ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1958 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
1959 int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1960 GraphDef* graph) {
1961 if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
1962 graph)) {
1963 return;
1964 }
1965 graph_modified_ = true;
1966 }
1967
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)1968 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
1969 GraphDef* graph) {
1970 node->set_op("Reciprocal");
1971 node->mutable_input()->SwapElements(0, 1);
1972 const string ctrl_dep =
1973 AddControlDependency(node->input(1), graph, node_map_.get());
1974 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1975 node->set_input(1, ctrl_dep);
1976 graph_modified_ = true;
1977 }
1978
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)1979 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
1980 GraphDef* graph) {
1981 node->set_op("Neg");
1982 node->mutable_input()->SwapElements(0, 1);
1983 const string ctrl_dep =
1984 AddControlDependency(node->input(1), graph, node_map_.get());
1985 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1986 node->set_input(1, ctrl_dep);
1987 graph_modified_ = true;
1988 }
1989
ReplaceOperationWithConstantTensor(DataType dtype,TensorProto * value,NodeDef * node,GraphDef * graph)1990 Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
1991 TensorProto* value,
1992 NodeDef* node,
1993 GraphDef* graph) {
1994 if (dtype == DT_VARIANT) return Status::OK();
1995 node->set_op("Const");
1996 EraseRegularNodeAttributes(node);
1997 (*node->mutable_attr())["dtype"].set_type(dtype);
1998 (*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
1999 // Convert all inputs to control dependencies.
2000 for (int i = 0; i < node->input_size(); ++i) {
2001 if (IsControlInput(node->input(i))) {
2002 break;
2003 }
2004 const string ctrl_dep =
2005 AddControlDependency(node->input(i), graph, node_map_.get());
2006 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2007 node->set_input(i, ctrl_dep);
2008 }
2009 DedupControlInputs(node);
2010 graph_modified_ = true;
2011 return Status::OK();
2012 }
2013
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph)2014 Status ConstantFolding::ReplaceOperationWithConstant(
2015 double value, const GraphProperties& properties,
2016 const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
2017 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
2018 if (dtype == DT_VARIANT) return Status::OK();
2019 AttrValue tensor_attr;
2020 Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
2021 if (!s.ok()) {
2022 // Fail gracefully without mutating the graph.
2023 VLOG(1) << "Failed to replace node " << node->name() << " of type "
2024 << DataTypeString(dtype) << " with constant tensor of value "
2025 << value;
2026 return Status::OK();
2027 }
2028 return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
2029 node, graph);
2030 }
2031
SimplifyGraph(bool use_shape_info,GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)2032 Status ConstantFolding::SimplifyGraph(
2033 bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties,
2034 absl::flat_hash_set<string>* nodes_to_not_simplify) {
2035 for (int i = 0; i < optimized_graph->node_size(); ++i) {
2036 NodeDef* node = optimized_graph->mutable_node(i);
2037 // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
2038 // generalize to only restrict certain simplifications.
2039 if (nodes_to_not_simplify->find(node->name()) ==
2040 nodes_to_not_simplify->end()) {
2041 if (HasTPUAttributes(*node)) {
2042 nodes_to_not_simplify->insert(node->name());
2043 continue;
2044 }
2045
2046 TF_RETURN_IF_ERROR(
2047 SimplifyNode(use_shape_info, node, optimized_graph, properties));
2048 }
2049 }
2050 return Status::OK();
2051 }
2052
// Helper macros for chaining simplification steps inside a member function
// that returns Status. Each one runs a step and returns Status::OK() from the
// enclosing function as soon as graph_modified_ is set.
//
// The bodies are wrapped in do { ... } while (0) so that every macro expands
// to exactly one statement; without the wrapper, using a macro as the body of
// an unbraced `if`/`else`/loop would silently detach the second statement.

// EXPR yields a Status: propagate an error, otherwise return if modified.
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR)     \
  do {                                        \
    TF_RETURN_IF_ERROR(EXPR);                 \
    if (graph_modified_) return Status::OK(); \
  } while (0)

// EXPR yields a bool that is stored into graph_modified_.
#define SET_AND_RETURN_IF_MODIFIED(EXPR)      \
  do {                                        \
    graph_modified_ = (EXPR);                 \
    if (graph_modified_) return Status::OK(); \
  } while (0)

// EXPR is evaluated for its side effects (it may set graph_modified_ itself).
#define RETURN_IF_MODIFIED(EXPR)              \
  do {                                        \
    EXPR;                                     \
    if (graph_modified_) return Status::OK(); \
  } while (0)
2064
// Tries the individual simplifications below on `node`, in order, and stops at
// the first one that modifies the graph.
//
// graph_modified_ is saved and cleared on entry. Each step is wrapped in one
// of the *_IF_MODIFIED macros, which return Status::OK() from this function as
// soon as graph_modified_ is true (and, for the ERROR variant, propagate a
// non-OK Status). If no step modified the graph, the saved value of
// graph_modified_ is restored before returning Status::OK().
Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
                                     GraphDef* optimized_graph,
                                     GraphProperties* properties) {
  // Remember the flag so an earlier modification is not lost when none of the
  // steps below fires.
  bool graph_modified_cached = graph_modified_;
  graph_modified_ = false;

  // Steps returning void set graph_modified_ themselves (RETURN_IF_MODIFIED),
  // steps returning Status use RETURN_IF_ERROR_OR_MODIFIED, and steps returning
  // bool have their result assigned to graph_modified_
  // (SET_AND_RETURN_IF_MODIFIED).
  RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
      *properties, use_shape_info, optimized_graph, node));
  RETURN_IF_MODIFIED(
      RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      RemoveReverse(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifySlice(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyTile(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyPad(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_MODIFIED(
      SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifyReduction(optimized_graph, *properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifyReshape(*properties, use_shape_info, node));
  RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
      *properties, use_shape_info, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      ConstantPushDown(properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      MulConvPushDown(optimized_graph, node, *properties));
  SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
  SET_AND_RETURN_IF_MODIFIED(
      PartialAssocOpConstFolding(optimized_graph, properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      MergeConcat(use_shape_info, properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      PartialConcatConstFolding(optimized_graph, properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      ConstantPushDownBiasAdd(properties, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifySelect(*properties, optimized_graph, node));
  RETURN_IF_MODIFIED(
      RemoveRedundantVariableUpdates(properties, optimized_graph, node));

  // Nothing fired for this node: restore the flag observed on entry.
  graph_modified_ = graph_modified_cached;
  return Status::OK();
}
2120
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2121 void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
2122 GraphDef* optimized_graph,
2123 NodeDef* node) {
2124 if (node->attr().count("num_split") == 0) return;
2125 if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
2126 ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
2127 }
2128 if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
2129 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2130 }
2131 }
2132
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2133 Status ConstantFolding::RemoveShuffleOrTranspose(
2134 const GraphProperties& properties, bool use_shape_info,
2135 GraphDef* optimized_graph, NodeDef* node) {
2136 if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
2137 return Status::OK();
2138 Tensor permutation_tensor;
2139 if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
2140 properties.HasInputProperties(node->name())) {
2141 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2142 std::vector<int> permutation;
2143 for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
2144 if (permutation_tensor.dtype() == DT_INT64) {
2145 permutation.push_back(permutation_tensor.vec<int64>()(j));
2146 } else {
2147 permutation.push_back(permutation_tensor.vec<int>()(j));
2148 }
2149 }
2150 int permutation_size = permutation.size();
2151 if (permutation_size != shape.dim_size()) {
2152 // Number of elements in perm should be same as dim_size. Skip if not.
2153 return Status::OK();
2154 }
2155 // The node is replaceable iff
2156 // dim_size == 0 || all dims have size 1 ||
2157 // all dims with > 1 size are not permuted.
2158 bool replaceable = true;
2159 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2160 replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
2161 }
2162 if (replaceable) {
2163 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2164 }
2165 }
2166 return Status::OK();
2167 }
2168
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2169 void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
2170 bool use_shape_info,
2171 GraphDef* optimized_graph,
2172 NodeDef* node) {
2173 if (use_shape_info && IsRandomShuffle(*node) &&
2174 !properties.GetInputProperties(node->name()).empty()) {
2175 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2176 // The node is replaceable iff
2177 // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
2178 if (!shape.unknown_rank() &&
2179 (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
2180 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2181 }
2182 }
2183 }
2184
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2185 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
2186 bool use_shape_info,
2187 GraphDef* optimized_graph,
2188 NodeDef* node) {
2189 if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
2190 Tensor axis;
2191 if (properties.HasInputProperties(node->name()) &&
2192 GetTensorFromConstNode(node->input(1), &axis)) {
2193 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2194 if (shape.unknown_rank()) return Status::OK();
2195 std::set<int> target_axes;
2196 for (int j = 0; j < axis.NumElements(); ++j) {
2197 // value of axis can be negative.
2198 if (axis.dtype() == DT_INT64) {
2199 target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
2200 shape.dim_size());
2201 } else {
2202 target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2203 shape.dim_size());
2204 }
2205 }
2206
2207 // The node is replaceable iff
2208 // unknown_rank == false &&
2209 // (dim_size == 0 || all dims have size 1 ||
2210 // all dims with > 1 size are not in target_axes)
2211 bool replaceable = true;
2212 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2213 replaceable &=
2214 shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
2215 }
2216 if (replaceable) {
2217 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2218 }
2219 }
2220 return Status::OK();
2221 }
2222
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2223 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2224 bool use_shape_info,
2225 GraphDef* optimized_graph,
2226 NodeDef* node) {
2227 if (!use_shape_info || !IsSlice(*node)) return Status::OK();
2228 Tensor begin;
2229 Tensor size;
2230 if (properties.HasInputProperties(node->name()) &&
2231 GetTensorFromConstNode(node->input(1), &begin) &&
2232 GetTensorFromConstNode(node->input(2), &size)) {
2233 const auto& input = properties.GetInputProperties(node->name())[0];
2234 // The node is replaceable iff unknown_rank == false &&
2235 // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2236 bool replaceable = !input.shape().unknown_rank();
2237 for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2238 if (begin.dtype() == DT_INT32) {
2239 replaceable &= begin.vec<int>()(j) == 0;
2240 } else {
2241 replaceable &= begin.vec<int64>()(j) == 0;
2242 }
2243 if (size.dtype() == DT_INT32) {
2244 replaceable &= (size.vec<int>()(j) == -1 ||
2245 size.vec<int>()(j) == input.shape().dim(j).size());
2246 } else {
2247 replaceable &= (size.vec<int64>()(j) == -1 ||
2248 size.vec<int64>()(j) == input.shape().dim(j).size());
2249 }
2250 }
2251 if (replaceable) {
2252 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2253 }
2254 }
2255 return Status::OK();
2256 }
2257
// Replaces a StridedSlice with an Identity when it selects its whole input:
// for every input dimension not covered by an (explicit or implied) ellipsis,
// begin is masked or 0, end is masked or equal to the dimension size, and the
// stride is 1.
//
// Only applies when shape information is enabled and exactly four input
// properties are recorded for the node. Nodes with a non-zero new_axis_mask or
// shrink_axis_mask, with an input dimension of negative (unknown) size, or
// with non-constant begin/end/strides are left unchanged. Returns an error
// only if one of the required mask attributes is missing.
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
                                             bool use_shape_info,
                                             GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (use_shape_info && IsStridedSlice(*node) &&
      properties.GetInputProperties(node->name()).size() == 4) {
    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
    if (node->attr().at("new_axis_mask").i() != 0 ||
        node->attr().at("shrink_axis_mask").i() != 0) {
      // Skip nodes with new/shrink axis mask, since they involve dimension
      // changes.
      return Status::OK();
    }
    const auto& input = properties.GetInputProperties(node->name())[0];
    for (int j = 0; j < input.shape().dim_size(); ++j) {
      // Skip if input shape is not fully determined.
      if (input.shape().dim(j).size() < 0) {
        return Status::OK();
      }
    }

    // Fetch the constant begin / end / strides operands (inputs 1..3); give up
    // if any of them is not available as a constant tensor.
    std::vector<Tensor> input_tensors(3);
    for (int i = 1; i < 4; ++i) {
      if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
        return Status::OK();
      }
    }

    const Tensor& begin = input_tensors[0];
    const Tensor& end = input_tensors[1];
    const Tensor& strides = input_tensors[2];

    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
    int begin_mask = node->attr().at("begin_mask").i();
    int end_mask = node->attr().at("end_mask").i();
    // Input dimensions that are covered by the ellipsis and therefore have no
    // corresponding entry in begin/end/strides.
    std::set<int> expanded_ellipsis_indices;
    int ellipsis_index = -1;
    for (int j = 0; j < input.shape().dim_size(); ++j) {
      // find the ellipsis_mask. If not found, insert one in the end if
      // necessary.
      if (node->attr().at("ellipsis_mask").i() & 1 << j ||
          (ellipsis_index == -1 && j >= strides.NumElements())) {
        ellipsis_index = j;
      }
      // insert the indices that are immediately after ellipsis_index if
      // necessary.
      if (ellipsis_index != -1 &&
          input.shape().dim_size() >
              strides.NumElements() + j - ellipsis_index) {
        expanded_ellipsis_indices.insert(j);
      }
    }

    // The node is replaceable iff unknown_rank == false &&
    // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
    //  && strides == 1) for all dimensions.
    bool replaceable = !input.shape().unknown_rank();
    for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
      if (expanded_ellipsis_indices.find(j) !=
          expanded_ellipsis_indices.end()) {
        // ellipsis_mask is effective on current dimension.
        continue;
      }
      // when we have ellipsis_mask in between, input.shape().dim_size() will
      // be greater than strides.NumElements(), since we will insert
      // as many as expanded_ellipsis_indices.size() axes during computation.
      // We need to subtract this number from j.
      int i = j;
      int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
      if (ellipsis_index != -1 &&
          j >= ellipsis_index + expanded_ellipsis_indices_size) {
        i = j - expanded_ellipsis_indices_size;
      }
      // `j` indexes the input dimension, `i` the matching entry of
      // begin/end/strides and the matching bit of the masks. Values are read
      // as int32 or int64 depending on each tensor's dtype and stored in int.
      int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
                                        : begin.vec<int64>()(i);
      int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
      int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
                                          : strides.vec<int64>()(i);
      replaceable &= (begin_mask & 1 << i || b == 0) &&
                     (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
                     s == 1;
    }
    if (replaceable) {
      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
    }
  }
  return Status::OK();
}
2348
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2349 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2350 bool use_shape_info,
2351 GraphDef* optimized_graph, NodeDef* node) {
2352 Tensor multiplies;
2353 if (use_shape_info && IsTile(*node) &&
2354 GetTensorFromConstNode(node->input(1), &multiplies)) {
2355 // The node is replaceable iff all values in multiplies are 1.
2356 bool replaceable = true;
2357 if (multiplies.dtype() == DT_INT32) {
2358 for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2359 replaceable &= multiplies.vec<int>()(j) == 1;
2360 }
2361 } else {
2362 for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
2363 replaceable &= multiplies.vec<int64>()(j) == 1;
2364 }
2365 }
2366 if (replaceable) {
2367 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2368 }
2369 }
2370 return Status::OK();
2371 }
2372
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2373 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2374 bool use_shape_info,
2375 GraphDef* optimized_graph, NodeDef* node) {
2376 if (!use_shape_info || !IsPad(*node)) return Status::OK();
2377
2378 Tensor paddings;
2379 if (GetTensorFromConstNode(node->input(1), &paddings)) {
2380 // The node is replaceable iff all values in paddings are 0.
2381 bool replaceable = true;
2382 if (paddings.dtype() == DT_INT32) {
2383 const auto flatten = paddings.flat<int32>();
2384 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2385 replaceable &= flatten(j) == 0;
2386 }
2387 } else {
2388 const auto flatten = paddings.flat<int64>();
2389 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2390 replaceable &= flatten(j) == 0;
2391 }
2392 }
2393 if (replaceable) {
2394 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2395 }
2396 }
2397 return Status::OK();
2398 }
2399
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2400 void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2401 bool use_shape_info,
2402 GraphDef* optimized_graph,
2403 NodeDef* node) {
2404 if (use_shape_info && IsSqueeze(*node) &&
2405 !properties.GetInputProperties(node->name()).empty()) {
2406 // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2407 // error to squeeze a dimension that is not 1, so we only need to check
2408 // whether the input has > 1 size for each dimension.
2409 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2410 // The node is replaceable iff
2411 // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2412 bool replaceable = !shape.unknown_rank();
2413 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2414 replaceable &= shape.dim(j).size() > 1;
2415 }
2416 if (replaceable) {
2417 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2418 }
2419 }
2420 }
2421
// Rewrites a Pack node that has exactly one non-control input into
// ExpandDims(input, axis), where `axis` is a newly created scalar int32 Const
// holding the Pack's "axis" attribute (0 if absent).
//
// Returns true if the node was rewritten. Returns false, without changing the
// graph, if the node is not such a Pack, the helper constant already exists,
// the input is a feed node, or the constant cannot be created.
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
  const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
  if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
      node_map_->NodeExists(axis_node_name)) {
    return false;
  }

  // It's unsafe to add a control dependency on the feed node, because it might
  // have been never executed otherwise.
  if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
    return false;
  }

  // Create constant axis node.
  Tensor axis_t(DT_INT32, TensorShape({}));
  const int axis =
      node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
  NodeDef new_node;
  if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
      !CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
    return false;
  }
  // From here on the graph and the node map are mutated.
  NodeDef* axis_node = optimized_graph->add_node();
  *axis_node = std::move(new_node);
  axis_node->set_name(axis_node_name);
  node_map_->AddNode(axis_node->name(), axis_node);
  // Add a control dependency to make sure axis_node is in the right frame.
  const string ctrl_dep = ConstantFolding::AddControlDependency(
      node->input(0), optimized_graph, node_map_.get());
  axis_node->add_input(ctrl_dep);
  axis_node->set_device(node->device());
  node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
  // Turn the Pack into ExpandDims: drop the Pack-only attributes and declare
  // the dtype of the new axis operand.
  node->set_op("ExpandDims");
  if (node->attr().count("axis") != 0) {
    node->mutable_attr()->erase("axis");
  }
  if (node->attr().count("N") != 0) {
    node->mutable_attr()->erase("N");
  }
  (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
  node->add_input(axis_node->name());
  node_map_->AddOutput(axis_node->name(), node->name());
  // The axis was appended after any control inputs; swap it into position 1
  // so that it directly follows the data input.
  if (node->input_size() > 2) {
    node->mutable_input()->SwapElements(1, node->input_size() - 1);
  }
  return true;
}
2469
SimplifyCase(GraphDef * optimized_graph,NodeDef * node)2470 bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
2471 if (node->op() != "Case") return false;
2472 const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
2473 if (output_idx_node == nullptr ||
2474 !CheckAttrExists(*output_idx_node, "value").ok()) {
2475 return false;
2476 }
2477 Tensor output_idx_t;
2478 if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
2479 return false;
2480 int output_idx = output_idx_t.scalar<int>()();
2481 const auto& func_list = node->attr().at("branches").list();
2482 if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
2483 NodeDef call_node = *node;
2484 call_node.set_op("PartitionedCall");
2485 call_node.clear_input();
2486 for (int i = 1; i < node->input_size(); ++i) {
2487 call_node.add_input(node->input(i));
2488 }
2489 auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
2490 *new_func = func_list.func(output_idx);
2491
2492 // Move the output shape of the branch to _output_shapes if it is known.
2493 const auto& output_shape_list =
2494 (*node->mutable_attr())["output_shapes"].list();
2495 if (output_shape_list.shape_size() > output_idx) {
2496 TensorShapeProto* new_output_shape =
2497 (*call_node.mutable_attr())["_output_shapes"]
2498 .mutable_list()
2499 ->add_shape();
2500 *new_output_shape =
2501 std::move(node->attr().at("output_shapes").list().shape(output_idx));
2502 }
2503
2504 call_node.mutable_attr()->erase("output_shapes");
2505 call_node.mutable_attr()->erase("branches");
2506
2507 *node = std::move(call_node);
2508 return true;
2509 }
2510
SimplifySelect(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2511 bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
2512 GraphDef* optimized_graph, NodeDef* node) {
2513 if (!IsSelect(*node)) return false;
2514 const std::vector<OpInfo::TensorProperties>& input_props =
2515 properties.GetInputProperties(node->name());
2516 if (input_props.size() < 3) return false;
2517 const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
2518 const bool is_all_true = IsOnes(*predicate_node);
2519 const bool is_all_false = IsZeros(*predicate_node);
2520 if (!is_all_true && !is_all_false) {
2521 return false;
2522 }
2523 const int live_input_idx = is_all_true ? 1 : 2;
2524 const int ignored_input_idx = is_all_true ? 2 : 1;
2525 const TensorShapeProto& predicate_shape = input_props[0].shape();
2526 const bool predicate_is_scalar =
2527 !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
2528 if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
2529 (ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
2530 predicate_is_scalar)) {
2531 // Replace node with Identity if no broadcasting is involved.
2532 node->set_op("Identity");
2533 *node->mutable_input(0) =
2534 AddControlDependency(node->input(0), optimized_graph, node_map_.get());
2535 *node->mutable_input(ignored_input_idx) = AddControlDependency(
2536 node->input(ignored_input_idx), optimized_graph, node_map_.get());
2537 node->mutable_input()->SwapElements(0, live_input_idx);
2538 } else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
2539 optimized_graph)) {
2540 return false;
2541 }
2542 DedupControlInputs(node);
2543 return true;
2544 }
2545
RemoveRedundantVariableUpdates(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)2546 void ConstantFolding::RemoveRedundantVariableUpdates(
2547 GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
2548 static const absl::flat_hash_set<string>* kVariableReadOps =
2549 new absl::flat_hash_set<string>{"AssignAddVariableOp",
2550 "AssignSubVariableOp",
2551 "AssignAdd",
2552 "AssignSub",
2553 "ScatterAdd",
2554 "ScatterSub",
2555 "ScatterMul",
2556 "ScatterDiv",
2557 "ScatterNdAdd",
2558 "ScatterNdSub",
2559 "ScatterNdMul",
2560 "ScatterNdDiv",
2561 "ResourceScatterAdd",
2562 "ResourceScatterSub",
2563 "ResourceScatterMul",
2564 "ResourceScatterDiv",
2565 "ResourceScatterNdAdd",
2566 "ResourceScatterNdSub",
2567 "ResourceScatterNdMul",
2568 "ResourceScatterNdDiv"};
2569 if (kVariableReadOps == nullptr ||
2570 kVariableReadOps->find(node->op()) == kVariableReadOps->end())
2571 return;
2572 const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
2573 const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
2574 if (delta_node == nullptr) return;
2575 const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
2576 absl::StrContains(node->op(), "Sub");
2577 if ((is_add_or_sub && IsZeros(*delta_node)) ||
2578 (!is_add_or_sub && IsOnes(*delta_node))) {
2579 VLOG(1) << "Removing redundant variable update: " << node->DebugString();
2580 if (absl::StrContains(node->op(), "Variable") ||
2581 absl::StrContains(node->op(), "Resource")) {
2582 ReplaceOperationWithNoOp(node, properties, optimized_graph);
2583 } else {
2584 ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
2585 optimized_graph);
2586 }
2587 }
2588 }
2589
MoveConstantsPastEnter(GraphDef * optimized_graph,NodeDef * node)2590 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
2591 NodeDef* node) {
2592 if (!IsEnter(*node) || node->input_size() == 0 ||
2593 node->attr().count("is_constant") == 0 ||
2594 !node->attr().at("is_constant").b()) {
2595 return false;
2596 }
2597 const string& node_name = node->name();
2598 const NodeDef* input = node_map_->GetNode(node->input(0));
2599 if (input == nullptr || !IsReallyConstant(*input) ||
2600 OptimizedNodeExists(*input, "_enter")) {
2601 return false;
2602 }
2603 // Find non-constant nodes that consume the output of *node.
2604 std::vector<NodeDef*> consumers;
2605 for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
2606 if (!IsConstant(*fanout)) {
2607 for (int i = 0; i < fanout->input_size(); ++i) {
2608 if (fanout->input(i) == node_name) {
2609 consumers.push_back(const_cast<NodeDef*>(fanout));
2610 break;
2611 }
2612 }
2613 }
2614 }
2615 if (consumers.empty()) {
2616 return false;
2617 }
2618 graph_modified_ = true;
2619 NodeDef* new_node = optimized_graph->add_node();
2620 *new_node = *input;
2621 new_node->set_name(OptimizedNodeName(*input, "_enter"));
2622 new_node->set_device(node->device());
2623 new_node->clear_input();
2624 new_node->add_input(AsControlDependency(node_name));
2625 node_map_->AddNode(new_node->name(), new_node);
2626 node_map_->AddOutput(node_name, new_node->name());
2627 for (NodeDef* consumer : consumers) {
2628 for (int i = 0; i < consumer->input_size(); ++i) {
2629 if (NodeName(consumer->input(i)) == node_name) {
2630 node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
2631 consumer->set_input(i, new_node->name());
2632 }
2633 }
2634 }
2635 return true;
2636 }
2637
SimplifySwitch(GraphDef * optimized_graph,NodeDef * node)2638 bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
2639 if (node->op() == "Switch" && node->input(0) == node->input(1) &&
2640 !OptimizedNodeExists(*node, "_const_false") &&
2641 !OptimizedNodeExists(*node, "_const_true")) {
2642 bool already_optimized = true;
2643 // If the optimization was already applied, the switch would have exactly
2644 // one Identity node consuming each of its outputs, each without any
2645 // non-control outputs.
2646 const auto& fanouts = node_map_->GetOutputs(node->name());
2647 if (fanouts.size() == 2) {
2648 for (const NodeDef* fanout : fanouts) {
2649 if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
2650 HasRegularOutputs(*fanout, *node_map_)) {
2651 already_optimized = false;
2652 break;
2653 }
2654 }
2655 }
2656 Tensor false_t(DT_BOOL, TensorShape({}));
2657 Tensor true_t(DT_BOOL, TensorShape({}));
2658 // Make sure we don't proceed if this switch node was already optimized.
2659 if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
2660 SetTensorValue(DT_BOOL, false, &false_t).ok()) {
2661 // Copy the set of consumers of the switch as they will be manipulated
2662 // below.
2663 std::vector<NodeDef*> consumers =
2664 node_map_->GetOutputsOrderedByNodeName(node->name());
2665 // Create constant false & true nodes.
2666 NodeDef tmp_false_node;
2667 tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
2668 if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
2669 &tmp_false_node)
2670 .ok()) {
2671 return false;
2672 }
2673 tmp_false_node.set_device(node->device());
2674 NodeDef tmp_true_node;
2675 tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
2676 if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
2677 &tmp_true_node)
2678 .ok()) {
2679 return false;
2680 }
2681 tmp_true_node.set_device(node->device());
2682
2683 // Add const nodes to graph.
2684 NodeDef* false_node = optimized_graph->add_node();
2685 false_node->Swap(&tmp_false_node);
2686 NodeDef* true_node = optimized_graph->add_node();
2687 true_node->Swap(&tmp_true_node);
2688
2689 // Add controls from the switch ports to the constants, and connect the
2690 // constants to the original switch outputs.
2691 const string false_port = node->name();
2692 const string true_port = strings::StrCat(node->name(), ":1");
2693 const string false_ctrl_dep =
2694 AddControlDependency(false_port, optimized_graph, node_map_.get());
2695 false_node->add_input(false_ctrl_dep);
2696 const string true_ctrl_dep =
2697 AddControlDependency(true_port, optimized_graph, node_map_.get());
2698 true_node->add_input(true_ctrl_dep);
2699
2700 node_map_->AddNode(false_node->name(), false_node);
2701 node_map_->AddNode(true_node->name(), true_node);
2702 node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
2703 node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
2704
2705 for (NodeDef* consumer : consumers) {
2706 for (int i = 0; i < consumer->input_size(); ++i) {
2707 const string& input = consumer->input(i);
2708 if (input == false_port) {
2709 consumer->set_input(i, false_node->name());
2710 node_map_->UpdateInput(consumer->name(), false_port,
2711 false_node->name());
2712 } else if (input == true_port) {
2713 consumer->set_input(i, true_node->name());
2714 node_map_->UpdateInput(consumer->name(), true_port,
2715 true_node->name());
2716 }
2717 }
2718 }
2719 return true;
2720 }
2721 }
2722 return false;
2723 }
2724
IsReductionWithConstantIndices(const NodeDef & node,bool * indices_is_empty) const2725 bool ConstantFolding::IsReductionWithConstantIndices(
2726 const NodeDef& node, bool* indices_is_empty) const {
2727 // Ensure its an appropriate Reduce node.
2728 if (!IsReduction(node) || node.input_size() < 2) {
2729 return false;
2730 }
2731 // Ensure that the axes to reduce by are constant.
2732 NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2733 if (!IsReallyConstant(*reductions_indices) ||
2734 !reductions_indices->attr().count("value")) {
2735 return false;
2736 }
2737 const TensorShapeProto& reduction_indices_shape =
2738 reductions_indices->attr().at("value").tensor().tensor_shape();
2739 *indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
2740 return true;
2741 }
2742
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2743 bool ConstantFolding::IsReductionCandidateForSimplification(
2744 const NodeDef& node, const GraphProperties& properties,
2745 TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2746 bool* is_single_element_op) const {
2747 // Get the properties of the input & output tensors and check if they both
2748 // contain a single element.
2749 if (!properties.HasInputProperties(node.name()) ||
2750 !properties.HasOutputProperties(node.name())) {
2751 return false;
2752 }
2753 const auto& input_props = properties.GetInputProperties(node.name())[0];
2754 const auto& output_props = properties.GetOutputProperties(node.name())[0];
2755 if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2756 !output_props.has_shape() || output_props.shape().unknown_rank()) {
2757 return false;
2758 }
2759 *input_tensor_shape = input_props.shape();
2760 *output_tensor_shape = output_props.shape();
2761 for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2762 if (input_tensor_shape->dim(i).size() < 0) {
2763 return false;
2764 }
2765 }
2766 for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2767 if (output_tensor_shape->dim(i).size() < 0) {
2768 return false;
2769 }
2770 }
2771 const int input_num_elements =
2772 TensorShape(*input_tensor_shape).num_elements();
2773 const int output_num_elements =
2774 TensorShape(*output_tensor_shape).num_elements();
2775 *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2776
2777 return true;
2778 }
2779
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2780 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2781 const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2782 const TensorVector& reduction_indices_vector) const {
2783 int output_size = reduction_indices_vector[0]->NumElements();
2784 if (output_size == 0) {
2785 return true;
2786 }
2787
2788 if (!keep_dims) {
2789 return false;
2790 }
2791 bool simplifiable = true;
2792 for (int i = 0; i < output_size; ++i) {
2793 int64 dim;
2794 if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2795 dim = reduction_indices_vector[0]->flat<int32>()(i);
2796 } else {
2797 dim = reduction_indices_vector[0]->flat<int64>()(i);
2798 }
2799 if (dim < 0) {
2800 dim += input_shape.dim_size();
2801 }
2802 if (dim < 0 || dim >= input_shape.dim_size() ||
2803 input_shape.dim(dim).size() != 1) {
2804 simplifiable = false;
2805 break;
2806 }
2807 }
2808 return simplifiable;
2809 }
2810
ReplaceReductionWithIdentity(NodeDef * node) const2811 bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
2812 // Replace the reduction node with an identity node, that can be further
2813 // optimized by other passes.
2814 DataType output_type;
2815 if (node->attr().count("T") != 0) {
2816 output_type = node->attr().at("T").type();
2817 } else if (IsAny(*node) || IsAll(*node)) {
2818 output_type = DT_BOOL;
2819 } else {
2820 return false;
2821 }
2822 node->set_op("Identity");
2823 EraseRegularNodeAttributes(node);
2824 (*node->mutable_attr())["T"].set_type(output_type);
2825 *node->mutable_input(1) = AsControlDependency(node->input(1));
2826 return true;
2827 }
2828
SimplifyReduction(GraphDef * optimized_graph,const GraphProperties & properties,NodeDef * node)2829 bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
2830 const GraphProperties& properties,
2831 NodeDef* node) {
2832 bool indices_is_empty = false;
2833 if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
2834 return false;
2835 }
2836 if (indices_is_empty) {
2837 return ReplaceReductionWithIdentity(node);
2838 }
2839 bool is_single_element_op = false;
2840 TensorShapeProto input_tensor_shape, output_tensor_shape;
2841 if (!IsReductionCandidateForSimplification(
2842 *node, properties, &input_tensor_shape, &output_tensor_shape,
2843 &is_single_element_op)) {
2844 return false;
2845 }
2846
2847 // Get the reduction indices.
2848 string reduction_indices_input = node->input(1);
2849 NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
2850 TensorVector reduction_indices_vector;
2851 auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
2852 for (const auto& out : reduction_indices_vector) {
2853 delete out.tensor;
2854 }
2855 });
2856 if (!EvaluateNode(*reduction_indices, TensorVector(),
2857 &reduction_indices_vector)
2858 .ok() ||
2859 reduction_indices_vector.size() != 1) {
2860 return false;
2861 }
2862
2863 bool keep_dims =
2864 node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
2865 bool simplifiable_to_reshape =
2866 is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
2867 bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
2868 *node, input_tensor_shape, keep_dims, reduction_indices_vector);
2869
2870 if (simplifiable_to_reshape) {
2871 // Const node to output shape.
2872 const int new_num_dimensions = output_tensor_shape.dim_size();
2873 Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
2874 for (int i = 0; i < new_num_dimensions; i++) {
2875 tensor.flat<int>()(i) = 1;
2876 }
2877 TensorValue shape_value(&tensor);
2878 NodeDef* shape_node = optimized_graph->add_node();
2879 if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
2880 shape_node)
2881 .ok()) {
2882 return false;
2883 }
2884 shape_node->set_device(node->device());
2885 node_map_->AddNode(shape_node->name(), shape_node);
2886 // Control dependency to ensure shape_node is in the correct frame.
2887 shape_node->add_input(AsControlDependency(reduction_indices_input));
2888 node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
2889 // Optimize node to Reshape.
2890 node->set_op("Reshape");
2891 node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
2892 node->set_input(1, shape_node->name());
2893 node->mutable_attr()->erase("keep_dims");
2894 node->mutable_attr()->erase("Tidx");
2895 AttrValue attr_type_indices;
2896 attr_type_indices.set_type(DT_INT32);
2897 (*node->mutable_attr())["Tshape"] = attr_type_indices;
2898 return true;
2899 } else if (simplifiable_to_identity) {
2900 return ReplaceReductionWithIdentity(node);
2901 }
2902 return false;
2903 }
2904
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2905 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2906 bool use_shape_info, NodeDef* node) {
2907 if (!use_shape_info || node->attr().count("T") == 0 ||
2908 !IsSimplifiableReshape(*node, properties)) {
2909 return false;
2910 }
2911 DataType output_type = node->attr().at("T").type();
2912 node->set_op("Identity");
2913 EraseRegularNodeAttributes(node);
2914 (*node->mutable_attr())["T"].set_type(output_type);
2915 *node->mutable_input(1) = AsControlDependency(node->input(1));
2916 return true;
2917 }
2918
SimplifyArithmeticOperations(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2919 Status ConstantFolding::SimplifyArithmeticOperations(
2920 const GraphProperties& properties, bool use_shape_info,
2921 GraphDef* optimized_graph, NodeDef* node) {
2922 const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
2923 const bool is_matmul = IsAnyMatMul(*node);
2924 const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
2925 const bool is_sub = IsSub(*node);
2926 const bool is_any_div = IsAnyDiv(*node);
2927 // Simplify arithmetic operations with ones or zeros.
2928 if (use_shape_info &&
2929 (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
2930 properties.HasInputProperties(node->name()) &&
2931 properties.HasOutputProperties(node->name())) {
2932 const NodeDef* x = node_map_->GetNode(node->input(0));
2933 const NodeDef* y = node_map_->GetNode(node->input(1));
2934 if (x == nullptr || y == nullptr) {
2935 return errors::InvalidArgument("Invalid inputs to node: ",
2936 node->DebugString());
2937 }
2938 const TensorShapeProto& output_shape =
2939 properties.GetOutputProperties(node->name())[0].shape();
2940
2941 // Simplify element-wise multiplication by ones or addition/subtraction
2942 // of zeros.
2943 const TensorShapeProto& y_shape =
2944 properties.GetInputProperties(node->name())[1].shape();
2945 const TensorShapeProto& x_shape =
2946 properties.GetInputProperties(node->name())[0].shape();
2947 const bool y_matches_output_shape =
2948 ShapesSymbolicallyEqual(output_shape, y_shape);
2949 const bool x_matches_output_shape =
2950 ShapesSymbolicallyEqual(output_shape, x_shape);
2951
2952 const bool x_is_zero = IsZeros(*x);
2953 const bool x_is_one = x_is_zero ? false : IsOnes(*x);
2954 if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
2955 // 1 * y = y or 0 + y = y.
2956 if (y_matches_output_shape) {
2957 ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
2958 } else if (x_matches_output_shape) {
2959 ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
2960 optimized_graph);
2961 }
2962 return Status::OK();
2963 }
2964
2965 if (y_matches_output_shape && (is_sub && x_is_zero)) {
2966 // Replace 0 - y with Neg(y).
2967 ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
2968 return Status::OK();
2969 }
2970
2971 // Replace 1 / y with Reciprocal op.
2972 if (y_matches_output_shape && is_any_div && x_is_one) {
2973 TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
2974 DataType type = node->attr().at("T").type();
2975 if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
2976 ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
2977 return Status::OK();
2978 }
2979 }
2980
2981 const bool y_is_zero = IsZeros(*y);
2982 const bool y_is_one = y_is_zero ? false : IsOnes(*y);
2983 if (((is_mul || is_any_div) && y_is_one) ||
2984 ((is_add || is_sub) && y_is_zero)) {
2985 // x * 1 = x or x / 1 = x or x +/- 0 = x
2986 if (x_matches_output_shape) {
2987 ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
2988 } else if (y_matches_output_shape) {
2989 ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
2990 optimized_graph);
2991 }
2992 return Status::OK();
2993 }
2994
2995 // x OR true = true OR y = true.
2996 const PartialTensorShape shp(output_shape);
2997 if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
2998 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
2999 1, properties, output_shape, node, optimized_graph));
3000 return Status::OK();
3001 }
3002
3003 // Simplify multiplication and matmul by zeros.
3004 // Also optimize zeros divided by a tensor, but only if we are in
3005 // aggressive mode, since we might get rid of divisions by zero.
3006 const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
3007 bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
3008 if ((x_is_zero || y_is_zero) &&
3009 (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
3010 if (shp.IsFullyDefined()) {
3011 bool is_quantized = IsQuantizedMatMul(*node);
3012 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
3013 0, properties, output_shape, node, optimized_graph));
3014 if (is_quantized && graph_modified_) {
3015 TF_RETURN_IF_ERROR(
3016 AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
3017 }
3018 return Status::OK();
3019 }
3020 // Even if an input shape is only partially known, we may known that it
3021 // matches the output shape and thus forward or broadcast the
3022 // corresponding zero input.
3023 if ((is_mul || is_any_div) && x_is_zero) {
3024 if (x_matches_output_shape) {
3025 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
3026 } else if (y_matches_output_shape) {
3027 ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
3028 optimized_graph);
3029 }
3030 return Status::OK();
3031 } else if (is_mul && y_is_zero) {
3032 if (y_matches_output_shape) {
3033 ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
3034 } else if (x_matches_output_shape) {
3035 ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
3036 optimized_graph);
3037 }
3038 return Status::OK();
3039 }
3040 }
3041 }
3042 return Status::OK();
3043 }
3044
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)3045 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
3046 NodeDef* node) {
3047 // Strength reduce floating point division by a constant Div(x, const) to
3048 // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
3049 // will be constant folded to Mul(x, 1.0/const).
3050 if (node->input_size() >= 2 &&
3051 (IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
3052 const string& const_input = node->input(1);
3053 const NodeDef* denom = node_map_->GetNode(const_input);
3054 CHECK(denom != nullptr);
3055 if (!IsReallyConstant(*denom)) {
3056 return false;
3057 }
3058 if (node->attr().count("T") == 0) {
3059 return false;
3060 }
3061 DataType type = node->attr().at("T").type();
3062 // Skip integer division.
3063 if (IsDiv(*node) &&
3064 !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
3065 return false;
3066 }
3067 // Insert new reciprocal op and change node from Div to Mul.
3068 NodeDef* reciprocal_node = optimized_graph->add_node();
3069 reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
3070 reciprocal_node->set_op("Reciprocal");
3071 reciprocal_node->set_device(node->device());
3072 reciprocal_node->add_input(const_input);
3073 (*reciprocal_node->mutable_attr())["T"].set_type(type);
3074
3075 // Re-wire inputs and outputs.
3076 if (IsXdivy(*node)) {
3077 node->set_op("MulNoNan");
3078 node->set_input(1, node->input(0));
3079 node->set_input(0, reciprocal_node->name());
3080 } else {
3081 node->set_op("Mul");
3082 node->set_input(1, reciprocal_node->name());
3083 }
3084 node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
3085 node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
3086
3087 return true;
3088 }
3089
3090 return false;
3091 }
3092
PrepareConstantPushDown(const NodeDef & parent,const GraphProperties & properties,bool must_have_properties,ConstantPushDownContext * ctx) const3093 bool ConstantFolding::PrepareConstantPushDown(
3094 const NodeDef& parent, const GraphProperties& properties,
3095 bool must_have_properties, ConstantPushDownContext* ctx) const {
3096 if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
3097 return false;
3098 }
3099 NodeDef* left_child = node_map_->GetNode(parent.input(0));
3100 NodeDef* right_child = node_map_->GetNode(parent.input(1));
3101 ctx->left_child_is_const = IsReallyConstant(*left_child);
3102 ctx->right_child_is_const = IsReallyConstant(*right_child);
3103 ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
3104 ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
3105
3106 // Nothing to do unless the parent has a constant child node.
3107 if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
3108 return false;
3109 }
3110
3111 // Don't move nodes across devices.
3112 if (parent.device() != ctx->op_child->device() ||
3113 parent.device() != ctx->const_child->device()) {
3114 return false;
3115 }
3116
3117 // Make sure that it is safe to change the value of the child node result.
3118 if (ctx->op_child->input_size() < 2 ||
3119 nodes_to_preserve_.find(ctx->op_child->name()) !=
3120 nodes_to_preserve_.end() ||
3121 NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
3122 return false;
3123 }
3124
3125 // Don't apply reassociation to floating point types of low precision.
3126 // The danger of significant numerical changes is too high.
3127 if (!CheckAttrExists(parent, "T").ok()) return false;
3128 DataType dtype = parent.attr().at("T").type();
3129 if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
3130 return false;
3131 }
3132
3133 // Don't rewrite the tree if it might create cycles.
3134 // TODO(rmlarsen): Add back handling of control dependency from op to C.
3135 const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
3136 if (child_output.find(ctx->const_child) != child_output.end()) {
3137 return false;
3138 }
3139
3140 // Get leaf nodes.
3141 ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
3142 ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
3143 ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
3144 ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
3145
3146 if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
3147 // Child is already foldable, leave it alone.
3148 return false;
3149 }
3150
3151 // Don't move nodes across devices.
3152 if (parent.device() != ctx->left_leaf->device() ||
3153 parent.device() != ctx->right_leaf->device()) {
3154 return false;
3155 }
3156
3157 // Get shape and type information.
3158 ctx->parent_input_props = &properties.GetInputProperties(parent.name());
3159 ctx->op_child_input_props =
3160 &properties.GetInputProperties(ctx->op_child->name());
3161 if (must_have_properties && (ctx->parent_input_props == nullptr ||
3162 ctx->parent_input_props->size() < 2 ||
3163 ctx->op_child_input_props == nullptr ||
3164 ctx->op_child_input_props->size() < 2)) {
3165 return false;
3166 }
3167
3168 VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
3169 << parent.op() << "(" << left_child->op() << ", " << right_child->op()
3170 << ")";
3171
3172 return true;
3173 }
3174
ConstantPushDownBiasAdd(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3175 bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
3176 GraphDef* optimized_graph,
3177 NodeDef* node) {
3178 // This implements constant push-down for BiasAdd. In the following "CV" is a
3179 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
3180 // vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3181 // non-constant matrix, and "BA" is BiasAdd.
3182 // For a valid input graph, the following 4 rewrites are legal:
3183 //
3184 // 1) + +
3185 // / \ / \
3186 // BA CV -- > BA V
3187 // / \ / \
3188 // M V M CV
3189 //
3190 // 2) + +
3191 // / \ / \
3192 // BA CM -- > BA M
3193 // / \ / \
3194 // M V CM V
3195 //
3196 // 3) BA BA
3197 // / \ / \
3198 // + CV -- > + V
3199 // / \ / \
3200 // M V M CV
3201 //
3202 // 4) BA BA = parent
3203 // / \ / \
3204 // BA CV -- > BA V = children
3205 // / \ / \
3206 // M V M CV = leaves
3207 //
3208 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3209
3210 const bool parent_is_bias_add = IsBiasAdd(*node);
3211 if (!parent_is_bias_add && !IsAdd(*node)) return false;
3212 ConstantPushDownContext ctx;
3213 if (!PrepareConstantPushDown(*node, *properties,
3214 /*must_have_properties=*/true, &ctx)) {
3215 return false;
3216 }
3217 // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3218 // >= 2 and the leaves must be vectors, we cannot swap them.
3219 if (ctx.left_child_is_const && parent_is_bias_add) return false;
3220 const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
3221 if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
3222
3223 // Get properties to validate rank and dtype constraints.
3224 if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
3225 (*ctx.parent_input_props)[0].shape().unknown_rank() ||
3226 (*ctx.parent_input_props)[1].shape().unknown_rank() ||
3227 (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
3228 (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
3229 return false;
3230 }
3231
3232 // Now get the ranks and types of the 3 leaf nodes.
3233 const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
3234 const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
3235 // At least one leaf must be a vector.
3236 if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
3237 const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3238 const int matrix_idx = 1 - vector_idx;
3239
3240 const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
3241 const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
3242 if (vector_rank != 1) return false; // this should never happen.
3243 const DataType vector_type = vector_prop.dtype();
3244
3245 const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
3246 const int matrix_rank = matrix_prop.shape().dim_size();
3247 const DataType matrix_type = matrix_prop.dtype();
3248
3249 const int const_idx = ctx.left_child_is_const ? 0 : 1;
3250 const auto& const_prop = (*ctx.parent_input_props)[const_idx];
3251 const int const_rank = const_prop.shape().dim_size();
3252 const DataType const_type = const_prop.dtype();
3253
3254 int input_to_swap = -1;
3255
3256 if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
3257 const_type == matrix_type) {
3258 // Case 2:
3259 input_to_swap = matrix_idx;
3260 } else if (const_rank == 1 && const_type == vector_type) {
3261 // Case 1, 3, and, 4:
3262 input_to_swap = vector_idx;
3263 }
3264 if (input_to_swap == -1) return false;
3265 const NodeDef* leaf_to_swap =
3266 node_map_->GetNode(ctx.op_child->input(input_to_swap));
3267 if (IsConstant(*leaf_to_swap)) return false;
3268
3269 node_map_->UpdateInput(node->name(), node->input(const_idx),
3270 ctx.op_child->input(input_to_swap));
3271 node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
3272 if (ctx.op_child->input(input_to_swap) !=
3273 ctx.op_child->input(1 - input_to_swap)) {
3274 node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
3275 ctx.op_child->name());
3276 }
3277 std::swap(*node->mutable_input(const_idx),
3278 *ctx.op_child->mutable_input(input_to_swap));
3279 properties->ClearInputProperties(node->name());
3280 properties->ClearInputProperties(ctx.op_child->name());
3281
3282 return true;
3283 }
3284
ConstantPushDown(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3285 bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
3286 GraphDef* optimized_graph,
3287 NodeDef* node) {
3288 // Consider the transformation
3289 //
3290 // + + = parent
3291 // / \ / \
3292 // C + -- > X + = children
3293 // / \ / \
3294 // X Y C Y = leaves
3295 //
3296 // where C is constant, X is non-constant, Y may be constant or non-constant,
3297 // and '+' denotes an associative and commutative operator like addition or
3298 // multiplication. This optimization pushes constants down in the tree to
3299 // canonicalize it. Moreover, in cases where the child node has a second
3300 // constant input Y we will create a leaf node that can be folded, e.g.
3301 //
3302 // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
3303 //
3304 // We also handle the non-commutative cases of subtraction and division
3305 // by rotating the tree locally, e.g.
3306 // Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
3307 // Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
3308
3309 // Get parent op type.
3310 const bool is_add = IsAdd(*node);
3311 const bool is_mul = IsMul(*node);
3312 const bool is_sub = IsSub(*node);
3313 const bool is_div = IsDiv(*node);
3314 if (!(is_add || is_sub || is_mul || is_div)) return false;
3315 const bool is_symmetric = is_add || is_mul;
3316
3317 ConstantPushDownContext ctx;
3318 if (!PrepareConstantPushDown(*node, *properties,
3319 /*must_have_properties=*/false, &ctx)) {
3320 return false;
3321 }
3322
3323 // Get child op type.
3324 const bool is_child_add = IsAdd(*ctx.op_child);
3325 const bool is_child_mul = IsMul(*ctx.op_child);
3326 const bool is_child_sub = IsSub(*ctx.op_child);
3327 const bool is_child_div = IsDiv(*ctx.op_child);
3328 const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
3329 const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
3330 if (!is_add_sub && !is_mul_div) {
3331 return false;
3332 }
3333 const bool is_child_symmetric = is_child_add || is_child_mul;
3334
3335 if (!CheckAttrExists(*node, "T").ok()) return false;
3336 DataType dtype = node->attr().at("T").type();
3337 if (!(is_symmetric && is_child_symmetric) &&
3338 !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
3339 return false;
3340 }
3341
3342 const NodeDef* y_node =
3343 ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
3344 if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
3345 !ctx.op_child_input_props->empty()) {
3346 // If we know the shapes of the nodes being swapped, make sure we don't push
3347 // down a larger node and create more work by broadcasting earlier in the
3348 // expressions tree.
3349 const PartialTensorShape c_shape(
3350 (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
3351 const PartialTensorShape x_shape(
3352 (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
3353
3354 if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
3355 c_shape.num_elements() > x_shape.num_elements()) {
3356 return false;
3357 } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
3358 c_shape.dims() > 0) {
3359 for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
3360 if (x_shape.dim_size(idx) >= 0 &&
3361 c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
3362 return false;
3363 }
3364 }
3365 }
3366 }
3367
3368 // Get the node names corresponding to X, Y, and C.
3369 const string input_x =
3370 ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
3371 const string input_y = input_x == ctx.op_child->input(0)
3372 ? ctx.op_child->input(1)
3373 : ctx.op_child->input(0);
3374 const string input_c =
3375 ctx.left_child_is_const ? node->input(0) : node->input(1);
3376 const string input_op =
3377 ctx.left_child_is_const ? node->input(1) : node->input(0);
3378 VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
3379
3380 // Now we have identified the nodes to swap, update the nodemap accordingly.
3381 node_map_->UpdateInput(node->name(), input_c, input_x);
3382 node_map_->AddOutput(input_c, ctx.op_child->name());
3383 if (input_x != input_y) {
3384 node_map_->RemoveOutput(input_x, ctx.op_child->name());
3385 }
3386 properties->ClearInputProperties(node->name());
3387 properties->ClearInputProperties(ctx.op_child->name());
3388
3389 if (is_symmetric && is_child_symmetric) {
3390 // Easy case (only commutative ops). We always write this as one of
3391 // +
3392 // / \
3393 // X +
3394 // / \
3395 // C Y
3396 node->set_input(0, input_x);
3397 node->set_input(1, input_op);
3398 ctx.op_child->set_input(0, input_c);
3399 ctx.op_child->set_input(1, input_y);
3400 } else {
3401 // More complicated case: When there are non-commutative operations like
3402 // subtractions or divisions involved, we may have to rotate the tree
3403 // and/or change op types. There are 6 non-trivial cases depending on
3404 // the effective generalized "sign" of each of the three terms C, Y, and X.
3405 // Here are the final trees we want to generate for those 6 cases:
3406 //
3407 // (CYX signs): ++- +-- -+- --+ +-+ -++
3408 //
3409 // - - - - + +
3410 // / \ / \ / \ / \ / \ / \
3411 // + X - X - X X + X - X -
3412 // / \ / \ / \ / \ / \ / \
3413 // C Y C Y Y C Y C C Y Y C
3414 //
3415
3416 // First, let's determine the effective sign of each term in the original
3417 // expression
3418 auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
3419 bool leaf_negated = !is_child_symmetric && is_right_leaf;
3420 bool child_negated = !is_symmetric && (ctx.left_child_is_const);
3421 return leaf_negated != child_negated;
3422 };
3423 const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
3424 const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
3425 bool neg_c = !is_symmetric && !ctx.left_child_is_const;
3426 bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
3427 bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
3428 // Rewrite the parent node.
3429 node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
3430 node->set_input(0, neg_x ? input_op : input_x);
3431 node->set_input(1, neg_x ? input_x : input_op);
3432 // Rewrite the child node.
3433 ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
3434 ctx.op_child->set_input(0, neg_c ? input_y : input_c);
3435 ctx.op_child->set_input(1, neg_c ? input_c : input_y);
3436 }
3437 return true;
3438 }
3439
MulConvPushDown(GraphDef * optimized_graph,NodeDef * node,const GraphProperties & properties)3440 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
3441 const GraphProperties& properties) {
3442 // Push down multiplication on ConvND.
3443 // * ConvND
3444 // / \ / \
3445 // ConvND C2 -- > X *
3446 // / \ / \
3447 // X C1 C1 C2
3448 //
3449 // where C1 and C2 are constants and X is non-constant.
3450 //
3451 // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
3452
3453 if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
3454
3455 NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
3456 NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
3457 // One child must be constant, and the second must be Conv op.
3458 const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
3459 const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
3460 if (!left_child_is_constant && !right_child_is_constant) {
3461 return false;
3462 }
3463 NodeDef* conv_node =
3464 left_child_is_constant ? mul_right_child : mul_left_child;
3465 if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
3466 return false;
3467 }
3468 if (node->device() != mul_left_child->device() ||
3469 node->device() != mul_right_child->device()) {
3470 return false;
3471 }
3472
3473 // Make sure that it is safe to change the value of the convolution
3474 // output.
3475 if (conv_node->input_size() < 2 ||
3476 NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
3477 nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
3478 return false;
3479 }
3480
3481 // Identify the nodes to swap.
3482 NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
3483 NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
3484 const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
3485 const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
3486 if (!conv_left_is_constant && !conv_right_is_constant) {
3487 // At least one of the convolution inputs should be constant.
3488 return false;
3489 }
3490 if (conv_left_is_constant && conv_right_is_constant) {
3491 // Leverage regular constant folding to handle this.
3492 return false;
3493 }
3494 const auto& mul_props = properties.GetOutputProperties(node->name());
3495 const auto& conv_props = properties.GetOutputProperties(conv_node->name());
3496 if (mul_props.empty() || conv_props.empty()) {
3497 return false;
3498 }
3499 const auto& mul_shape = mul_props[0].shape();
3500 const auto& conv_shape = conv_props[0].shape();
3501 if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
3502 return false;
3503 }
3504
3505 const auto& input_props = properties.GetInputProperties(conv_node->name());
3506 if (input_props.size() < 2) {
3507 return false;
3508 }
3509 const auto& filter_shape = input_props[1].shape();
3510
3511 NodeDef* const_node =
3512 left_child_is_constant ? mul_left_child : mul_right_child;
3513 const auto& const_props = properties.GetOutputProperties(const_node->name());
3514 if (const_props.empty()) {
3515 return false;
3516 }
3517 const auto& const_shape = const_props[0].shape();
3518 if (!IsValidConstShapeForMulConvPushDown(
3519 conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
3520 return false;
3521 }
3522
3523 string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
3524 if (node_map_->NodeExists(mul_new_name)) {
3525 return false;
3526 }
3527 // Make sure we don't introduce loops in the graph by removing control
3528 // dependencies from the conv2d node to c2.
3529 string conv_const_input =
3530 conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
3531 if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
3532 node_map_.get())) {
3533 // Add a control dep from c1 to c2 to ensure c2 is in the right frame
3534 MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
3535 node_map_.get());
3536 }
3537
3538 conv_node->set_name(node->name());
3539 node->set_name(mul_new_name);
3540 if (conv_left_is_constant) {
3541 node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
3542 conv_node->set_input(0, mul_new_name);
3543 } else {
3544 node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
3545 conv_node->set_input(1, mul_new_name);
3546 }
3547 NodeDef* conv_const_node =
3548 conv_left_is_constant ? conv_left_child : conv_right_child;
3549 if (left_child_is_constant) {
3550 node->set_input(1, conv_const_node->name());
3551 } else {
3552 node->set_input(0, conv_const_node->name());
3553 }
3554 node_map_->AddNode(mul_new_name, node);
3555
3556 return true;
3557 }
3558
PartialConstPropThroughIdentityN(NodeDef * node)3559 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3560 // Partial constant propagation through IdentityN.
3561 if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
3562 !HasRegularInputs(*node))
3563 return false;
3564
3565 std::vector<int> inputs_to_forward;
3566 for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3567 const string& input = node->input(input_idx);
3568 if (IsControlInput(input)) {
3569 return false;
3570 }
3571 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3572 if (input_node == nullptr) {
3573 LOG(ERROR) << "Bad input: " << input;
3574 return false;
3575 }
3576 // Forward constant inputs to outputs and add a control dependency on
3577 // the IdentityN node.
3578 if (IsReallyConstant(*input_node)) {
3579 inputs_to_forward.push_back(input_idx);
3580 }
3581 }
3582 return ForwardInputs(node, inputs_to_forward);
3583 }
3584
PartialAssocOpConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3585 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
3586 GraphProperties* properties,
3587 NodeDef* node) {
3588 // Partial constant folding for associative operators:
3589 // Split AddN/AccumulateNV2 to enable partial
3590 // folding of ops when more than one but not all inputs are constant.
3591 // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
3592 // addition is commutative.
3593 if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
3594
3595 const int num_non_control_inputs = NumNonControlInputs(*node);
3596 if (num_non_control_inputs <= 2) return false;
3597 const int num_control_inputs = node->input_size() - num_non_control_inputs;
3598 std::vector<int> const_inputs;
3599 std::vector<int> nonconst_inputs;
3600 for (int i = 0; i < node->input_size(); ++i) {
3601 const string& input = node->input(i);
3602 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3603 if (input_node == nullptr) return false;
3604 if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
3605 const_inputs.push_back(i);
3606 } else {
3607 // Non-const and control inputs.
3608 nonconst_inputs.push_back(i);
3609 }
3610 }
3611 // Promote AccumulateNV2 with all constant inputs to AddN, since it is
3612 // a fake node that cannot be constant folded by itself.
3613 int const_inputs_size = const_inputs.size();
3614 if (const_inputs_size == num_non_control_inputs &&
3615 node->op() == "AccumulateNV2") {
3616 node->set_op("AddN");
3617 node->mutable_attr()->erase("shape");
3618 return true;
3619 }
3620 const string new_node_name = OptimizedNodeName(
3621 *node, strings::StrCat("_partial_split_", const_inputs_size));
3622 if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
3623 !node_map_->NodeExists(new_node_name)) {
3624 NodeDef* added_node = optimized_graph->add_node();
3625 *added_node = *node;
3626 // Always use AddN for the constant node, since AccumulateNV2 is a fake
3627 // node that cannot be constant folded, since it does not have a kernel.
3628 added_node->set_op("AddN");
3629 added_node->mutable_attr()->erase("shape");
3630 added_node->set_name(new_node_name);
3631 node_map_->AddNode(added_node->name(), added_node);
3632 added_node->clear_input();
3633 for (int i : const_inputs) {
3634 added_node->add_input(node->input(i));
3635 node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3636 added_node->name());
3637 }
3638
3639 // Overwrite the first const input with the added node.
3640 node->set_input(const_inputs[0], added_node->name());
3641 node_map_->AddOutput(added_node->name(), node->name());
3642 nonconst_inputs.push_back(const_inputs[0]);
3643 // Compact the remaining inputs to the original node.
3644 std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
3645 int idx = 0;
3646 for (int i : nonconst_inputs) {
3647 if (idx != i) {
3648 node->set_input(idx, node->input(i));
3649 }
3650 ++idx;
3651 }
3652 node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
3653 const_inputs.size() - 1);
3654 (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
3655 properties->ClearInputProperties(node->name());
3656 (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
3657 return true;
3658 }
3659 return false;
3660 }
3661
PartialConcatConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3662 bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
3663 GraphProperties* properties,
3664 NodeDef* node) {
3665 // Partial constant folding for Concat which is not commutative, so
3666 // we have to preserve order and can only push consecutive runs of constant
3667 // inputs into sub-nodes.
3668 if (!IsConcat(*node) ||
3669 node->name().rfind("_partial_split_") != string::npos) {
3670 return false;
3671 }
3672 const int num_non_control_inputs = NumNonControlInputs(*node);
3673 if (num_non_control_inputs <= 3) return false;
3674 int axis_arg = -1;
3675 int begin = 0;
3676 int end = num_non_control_inputs;
3677 if (node->op() == "Concat") {
3678 begin = 1;
3679 axis_arg = 0;
3680 } else if (node->op() == "ConcatV2") {
3681 end = num_non_control_inputs - 1;
3682 axis_arg = num_non_control_inputs - 1;
3683 } else {
3684 return false;
3685 }
3686
3687 // We search for consecutive runs of constant inputs in the range
3688 // [begin:end[ and push then down into child nodes.
3689 std::vector<std::pair<int, int>> constant_input_runs;
3690 int first = begin;
3691 int last = begin;
3692 while (last < end) {
3693 while (first < end && !IsReallyConstant(*node_map_->GetNode(
3694 NodeName(node->input(first))))) {
3695 ++first;
3696 }
3697 // Invariant: node[first] is constant || first >= end.
3698 last = first + 1;
3699 while (last < end &&
3700 IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
3701 ++last;
3702 }
3703 // Invariant: node[last] is not constant || last >= end
3704 // Discard intervals shorter than 2 elements.
3705 if (first < end && (last - first) > 1) {
3706 constant_input_runs.emplace_back(first, last);
3707 }
3708 first = last;
3709 }
3710
3711 // Skip if all inputs are constant, and let constant folding take over.
3712 if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3713 constant_input_runs[0].first == begin &&
3714 constant_input_runs[0].second == end)) {
3715 return false;
3716 }
3717 std::set<int> inputs_to_delete;
3718 for (auto interval : constant_input_runs) {
3719 // Push the constant inputs in the interval to a child node than can be
3720 // constant folded.
3721 string new_node_name = OptimizedNodeName(*node, "_partial_split");
3722 do {
3723 new_node_name += strings::StrCat("_", interval.first);
3724 } while (node_map_->NodeExists(new_node_name));
3725
3726 NodeDef* added_node = optimized_graph->add_node();
3727 *added_node = *node;
3728 added_node->set_op("ConcatV2");
3729 added_node->set_name(new_node_name);
3730 node_map_->AddNode(added_node->name(), added_node);
3731 added_node->clear_input();
3732 for (int i = interval.first; i < interval.second; ++i) {
3733 added_node->add_input(node->input(i));
3734 node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
3735 if (i != interval.first) {
3736 inputs_to_delete.insert(i);
3737 }
3738 }
3739 added_node->add_input(node->input(axis_arg));
3740 (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
3741 node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
3742
3743 // Overwrite the first constant input with the result of the added
3744 // child node.
3745 node->set_input(interval.first, added_node->name());
3746 }
3747 if (!inputs_to_delete.empty()) {
3748 // Fix up the inputs to the original node.
3749 protobuf::RepeatedPtrField<string> tmp;
3750 tmp.Swap(node->mutable_input());
3751 for (int i = 0; i < tmp.size(); ++i) {
3752 if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
3753 node->add_input(tmp.Get(i));
3754 }
3755 }
3756 (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
3757 properties->ClearInputProperties(node->name());
3758 }
3759 return true;
3760 }
3761
GetConcatAxis(const NodeDef & node,int * axis)3762 bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
3763 if (node.op() != "ConcatV2") {
3764 return false;
3765 }
3766 int axis_idx = node.input_size() - 1;
3767 while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
3768 --axis_idx;
3769 }
3770 if (axis_idx <= 0) {
3771 return false;
3772 }
3773 Tensor axis_tensor;
3774 if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
3775 return false;
3776 }
3777 *axis = axis_tensor.dtype() == DT_INT64
3778 ? static_cast<int>(axis_tensor.scalar<int64>()())
3779 : axis_tensor.scalar<int32>()();
3780 return true;
3781 }
3782
MergeConcat(bool use_shape_info,GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3783 bool ConstantFolding::MergeConcat(bool use_shape_info,
3784 GraphProperties* properties,
3785 GraphDef* optimized_graph, NodeDef* node) {
3786 // We only optimize for ConcatV2.
3787 int axis;
3788 if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
3789 nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3790 node_map_->GetOutputs(node->name()).size() != 1) {
3791 return false;
3792 }
3793
3794 // If all inputs are constant, don't merge and let folding take case of it.
3795 const int num_regular_inputs = NumNonControlInputs(*node);
3796 bool all_inputs_are_const = true;
3797 for (int i = 0; i < num_regular_inputs - 1; ++i) {
3798 const NodeDef* input_node = node_map_->GetNode(node->input(i));
3799 if (!IsReallyConstant(*input_node)) {
3800 all_inputs_are_const = false;
3801 break;
3802 }
3803 }
3804 if (all_inputs_are_const) return false;
3805
3806 NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3807 int parent_axis;
3808 if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
3809 return false;
3810 }
3811
3812 // Make a pass over the parent inputs to see if any of them have explicit
3813 // device() fields set, and if different inputs are on different tasks. If
3814 // so, this concat of concats may have been carefully constructed to be a
3815 // two-stage concat, and we don't want to undo that here.
3816 string task, device;
3817 absl::flat_hash_set<string> unique_input_tasks;
3818 const int n_parent_inputs = NumNonControlInputs(*parent);
3819 // Iterate over the real inputs to concatenate [0..n_parent_inputs - 1). The
3820 // input at n_parent_inputs - 1 is the concat axis argument for a ConcatV2
3821 // node, which we don't want to consider here.
3822 for (int i = 0; i < n_parent_inputs - 1; ++i) {
3823 const NodeDef* input_node = node_map_->GetNode(parent->input(i));
3824 if (!input_node->device().empty() &&
3825 tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
3826 &task, &device)) {
3827 unique_input_tasks.insert(task);
3828 if (unique_input_tasks.size() >= 2) {
3829 // More than one input task represented in the device specifications
3830 // of the parent's input nodes. Don't mess with this.
3831 return false;
3832 }
3833 }
3834 }
3835
3836 protobuf::RepeatedPtrField<string> parent_inputs;
3837 parent_inputs.Swap(parent->mutable_input());
3838 // TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
3839 // collapse it into the parent multiple times? Probably not.
3840 for (const auto& input : parent_inputs) {
3841 if (IsSameInput(input, node->name())) {
3842 for (int j = 0; j < num_regular_inputs - 1; ++j) {
3843 // Add tensor inputs to first child concat tensors (except the final
3844 // axis input) to the parent's inputs.
3845 parent->add_input(node->input(j));
3846 node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
3847 }
3848 } else {
3849 parent->add_input(input);
3850 }
3851 }
3852 // Forward Add control inputs
3853 const int num_inputs = node->input_size();
3854 for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
3855 parent->add_input(node->input(i));
3856 node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
3857 node->mutable_input()->RemoveLast();
3858 }
3859 (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3860 DedupControlInputs(parent);
3861 ReplaceOperationWithNoOp(node, properties, optimized_graph);
3862
3863 return true;
3864 }
3865
AddQuantizedMatMulMinMaxOutConstNodes(NodeDef * node,GraphDef * optimized_graph)3866 Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
3867 NodeDef* node, GraphDef* optimized_graph) {
3868 auto add_quantized_out = [this, node, optimized_graph](
3869 const string& out_const_name, int index) {
3870 NodeDef* out_node = optimized_graph->add_node();
3871 graph_modified_ = true;
3872 Tensor value(DT_FLOAT, TensorShape({}));
3873 const bool is_min = index == 1;
3874 const DataType type_attr = node->attr().at("dtype").type();
3875
3876 value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
3877 : QuantizedTypeMaxAsFloat(type_attr);
3878 TF_RETURN_IF_ERROR(
3879 CreateNodeDef(out_const_name, TensorValue(&value), out_node));
3880 node_map_->AddNode(out_const_name, out_node);
3881 out_node->set_device(node->device());
3882 // Copy all inputs from node.
3883 out_node->mutable_input()->CopyFrom(node->input());
3884 for (const string& input : out_node->input()) {
3885 node_map_->AddOutput(NodeName(input), out_const_name);
3886 }
3887
3888 // Update output nodes consuming node:index to new const node.
3889 string old_input = absl::StrCat(node->name(), ":", index);
3890 int old_node_count = 0;
3891 // We make a copy since the set might change.
3892 auto outputs = node_map_->GetOutputs(node->name());
3893 for (const auto& output : outputs) {
3894 for (int i = 0; i < output->input_size(); ++i) {
3895 if (output->input(i) == old_input) {
3896 output->set_input(i, out_const_name);
3897 node_map_->AddOutput(out_const_name, output->name());
3898 } else if (NodeName(output->input(i)) == node->name()) {
3899 ++old_node_count;
3900 }
3901 }
3902 if (old_node_count == 0) {
3903 node_map_->RemoveOutput(node->name(), output->name());
3904 }
3905 }
3906
3907 return Status::OK();
3908 };
3909 const string min_out_const_name =
3910 OptimizedNodeName(*node, "-quantized_matmul_min_out");
3911 const string max_out_const_name =
3912 OptimizedNodeName(*node, "-quantized_matmul_max_out");
3913 if (node_map_->GetNode(min_out_const_name) == nullptr &&
3914 node_map_->GetNode(max_out_const_name) == nullptr) {
3915 TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
3916 TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
3917 } else {
3918 return errors::Internal(absl::Substitute(
3919 "Can't create Const for QuantizedMatMul min_out/max_out of "
3920 "node '$0' because of node name conflict",
3921 node->name()));
3922 }
3923 return Status::OK();
3924 }
3925
RunOptimizationPass(Cluster * cluster,GrapplerItem * item,GraphDef * optimized_graph)3926 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3927 GrapplerItem* item,
3928 GraphDef* optimized_graph) {
3929 graph_ = &item->graph;
3930 node_map_.reset(new NodeMap(graph_));
3931 nodes_allowlist_.clear();
3932 // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3933 // has a single fanout, it would be rewritten as a constant with the same
3934 // node name, and therefore users are still able to fetch it. This is not
3935 // the case if the node has multiple fanouts, and constant folding would
3936 // replace the node with multiple constants (each for one fanout) with
3937 // new names, and as a result users would not be able to fetch the node any
3938 // more with the original node name.
3939 for (const auto& fetch : item->fetch) {
3940 const NodeDef* fetch_node = node_map_->GetNode(fetch);
3941 if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
3942 nodes_allowlist_.insert(fetch_node->name());
3943 }
3944 }
3945
3946 GraphProperties properties(*item);
3947 // It's possible to feed a placeholder with a tensor of any shape: make sure
3948 // that the shape inference deals with this conservatively unless we're in
3949 // aggressive mode.
3950 const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3951 Status s = properties.InferStatically(assume_valid_feeds,
3952 /*aggressive_shape_inference=*/false,
3953 /*include_input_tensor_values=*/false,
3954 /*include_output_tensor_values=*/true);
3955
3956 const bool can_use_shape_info = s.ok();
3957 VLOG(1) << "can_use_shape_info = " << can_use_shape_info;
3958
3959 absl::flat_hash_set<string> nodes_to_not_simplify;
3960 if (can_use_shape_info) {
3961 TF_RETURN_IF_ERROR(MaterializeShapes(properties));
3962 TF_RETURN_IF_ERROR(MaterializeConstants(properties));
3963 TF_RETURN_IF_ERROR(
3964 FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
3965 } else {
3966 *optimized_graph = *graph_;
3967 }
3968 node_map_.reset(new NodeMap(optimized_graph));
3969 TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
3970 &properties, &nodes_to_not_simplify));
3971
3972 return Status::OK();
3973 }
3974
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3975 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
3976 GraphDef* optimized_graph) {
3977 // TensorFlow flushes denormals to zero and rounds to nearest, so we do
3978 // the same here.
3979 port::ScopedFlushDenormal flush;
3980 port::ScopedSetRound round(FE_TONEAREST);
3981 nodes_to_preserve_ = item.NodesToPreserve();
3982 for (const auto& feed : item.feed) {
3983 feed_nodes_.insert(NodeName(feed.first));
3984 }
3985
3986 if (cpu_device_ == nullptr) {
3987 owned_device_.reset(new DeviceSimple());
3988 cpu_device_ = owned_device_.get();
3989 }
3990
3991 graph_contains_assign_or_inplace_op_ = false;
3992 for (const NodeDef& node : item.graph.node()) {
3993 if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
3994 graph_contains_assign_or_inplace_op_ = true;
3995 break;
3996 }
3997 }
3998
3999 has_fetch_ = !item.fetch.empty();
4000 GrapplerItem item_to_optimize = item;
4001 *optimized_graph = GraphDef();
4002 item_to_optimize.graph.Swap(optimized_graph);
4003 int64 node_count;
4004 do {
4005 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4006 graph_modified_ = false;
4007 item_to_optimize.graph.Swap(optimized_graph);
4008 optimized_graph->Clear();
4009 node_count = item_to_optimize.graph.node_size();
4010 TF_RETURN_IF_ERROR(
4011 RunOptimizationPass(cluster, &item_to_optimize, optimized_graph));
4012 } while (graph_modified_ || optimized_graph->node_size() != node_count);
4013 *optimized_graph->mutable_library() = item.graph.library();
4014 *optimized_graph->mutable_versions() = item.graph.versions();
4015
4016 return Status::OK();
4017 }
4018
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimize_output,double result)4019 void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
4020 const GraphDef& optimize_output, double result) {
4021 // Nothing to do for ConstantFolding.
4022 }
4023
4024 } // namespace grappler
4025 } // namespace tensorflow
4026