1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
19
#include <climits>
#include <cmath>
#include <cstring>
21
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/versions.pb.h"
35 #include "tensorflow/core/grappler/clusters/cluster.h"
36 #include "tensorflow/core/grappler/costs/graph_properties.h"
37 #include "tensorflow/core/grappler/grappler_item.h"
38 #include "tensorflow/core/grappler/op_types.h"
39 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
40 #include "tensorflow/core/grappler/utils.h"
41 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
42 #include "tensorflow/core/lib/core/errors.h"
43 #include "tensorflow/core/lib/core/stringpiece.h"
44 #include "tensorflow/core/lib/gtl/cleanup.h"
45 #include "tensorflow/core/lib/gtl/inlined_vector.h"
46 #include "tensorflow/core/lib/strings/numbers.h"
47 #include "tensorflow/core/lib/strings/strcat.h"
48 #include "tensorflow/core/platform/cpu_info.h"
49 #include "tensorflow/core/platform/denormal.h"
50 #include "tensorflow/core/platform/env.h"
51 #include "tensorflow/core/platform/setround.h"
52 #include "tensorflow/core/platform/tensor_coding.h"
53 #include "tensorflow/core/public/version.h"
54 #include "tensorflow/core/util/bcast.h"
55 #include "tensorflow/core/util/saved_tensor_slice_util.h"
56
57 namespace tensorflow {
58 namespace grappler {
59 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
60
61 namespace {
62 class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface {
63 public:
EigenThreadPoolWrapper(thread::ThreadPool * pool)64 explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {}
~EigenThreadPoolWrapper()65 ~EigenThreadPoolWrapper() override {}
Schedule(std::function<void ()> fn)66 void Schedule(std::function<void()> fn) override {
67 auto wrapped = [=]() {
68 // TensorFlow flushes denormals to zero and rounds to nearest, so we do
69 // the same here.
70 port::ScopedFlushDenormal flush;
71 port::ScopedSetRound round(FE_TONEAREST);
72 fn();
73 };
74 pool_->Schedule(std::move(wrapped));
75 }
NumThreads() const76 int NumThreads() const override { return pool_->NumThreads(); }
CurrentThreadId() const77 int CurrentThreadId() const override { return pool_->CurrentThreadId(); }
78
79 private:
80 thread::ThreadPool* pool_ = nullptr;
81 };
82
83 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)84 bool AllValuesAre(const TensorProto& proto, const T& value) {
85 Tensor tensor;
86 if (!tensor.FromProto(proto)) {
87 return false;
88 }
89 auto values = tensor.flat<T>();
90 for (int i = 0; i < tensor.NumElements(); ++i) {
91 if (values(i) != value) {
92 return false;
93 }
94 }
95 return true;
96 }
97
98 // Add new_input as a control input to node if it does not already depend on it.
99 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
100 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)101 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
102 GraphDef* graph, NodeMap* node_map) {
103 bool already_exists = false;
104 for (const string& input : node->input()) {
105 if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
106 already_exists = true;
107 break;
108 }
109 }
110 if (!already_exists) {
111 const string ctrl_dep =
112 ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
113 node->add_input(ctrl_dep);
114 node_map->AddOutput(NodeName(ctrl_input), node->name());
115 }
116 return !already_exists;
117 }
118
119 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)120 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
121 GraphDef* graph, NodeMap* node_map) {
122 bool removed_input = false;
123 bool update_node_map = true;
124 const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
125 for (int i = 0; i < node->input_size(); ++i) {
126 const string& input = node->input(i);
127 if (old_input_ctrl_dep == input) {
128 if (IsControlInput(input)) {
129 node->mutable_input()->SwapElements(i, node->input_size() - 1);
130 node->mutable_input()->RemoveLast();
131 removed_input = true;
132 } else {
133 // There is a non-control input from the same node.
134 // Don't remove the output from the NodeMap.
135 update_node_map = false;
136 }
137 }
138 }
139 if (update_node_map) {
140 node_map->RemoveOutput(NodeName(old_input), node->name());
141 }
142 return removed_input;
143 }
144
GetConcatAxis(const GraphProperties & properties,NodeDef * node,int * axis)145 bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
146 int* axis) {
147 if (node->op() != "ConcatV2" ||
148 properties.GetInputProperties(node->name()).empty()) {
149 return false;
150 }
151 const auto& axis_input = properties.GetInputProperties(node->name()).back();
152 if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
153 return false;
154 }
155
156 Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
157 if (!axis_tensor.FromProto(axis_input.value())) {
158 return false;
159 }
160 *axis = axis_input.dtype() == DT_INT64
161 ? static_cast<int>(axis_tensor.scalar<int64>()())
162 : axis_tensor.scalar<int32>()();
163 return true;
164 }
165
HasTPUAttributes(const NodeDef & node)166 bool HasTPUAttributes(const NodeDef& node) {
167 AttrSlice attrs(node);
168 for (auto attr : attrs) {
169 if (attr.first.find("_tpu_") != attr.first.npos) {
170 return true;
171 }
172 }
173 return false;
174 }
175
// Returns true if `a` and `b` would be packed differently in a TensorProto.
// For floating point types this compares bit patterns rather than values, so
// +0.0 and -0.0 are "not equal" while a NaN is "equal" to an identically
// encoded NaN.
template <typename T>
bool PackedValuesNotEqual(T a, T b) {
  return a != b;
}

template <>
bool PackedValuesNotEqual(float a, float b) {
  // Compare the raw bits via memcpy: casting a float lvalue to int32_t& (as
  // the previous implementation did) violates strict aliasing and is
  // undefined behavior. memcpy of the object representation is well defined
  // and compiles to the same code.
  int32_t a_bits, b_bits;
  memcpy(&a_bits, &a, sizeof(a));
  memcpy(&b_bits, &b, sizeof(b));
  return a_bits != b_bits;
}

template <>
bool PackedValuesNotEqual(double a, double b) {
  int64_t a_bits, b_bits;
  memcpy(&a_bits, &a, sizeof(a));
  memcpy(&b_bits, &b, sizeof(b));
  return a_bits != b_bits;
}
190
QuantizedTypeMinAsFloat(DataType data_type)191 float QuantizedTypeMinAsFloat(DataType data_type) {
192 switch (data_type) {
193 case DT_QINT8:
194 return Eigen::NumTraits<qint8>::lowest();
195 case DT_QUINT8:
196 return Eigen::NumTraits<quint8>::lowest();
197 case DT_QINT16:
198 return Eigen::NumTraits<qint16>::lowest();
199 case DT_QUINT16:
200 return Eigen::NumTraits<quint16>::lowest();
201 case DT_QINT32:
202 return Eigen::NumTraits<qint32>::lowest();
203 default:
204 return 0.0f;
205 }
206 }
207
QuantizedTypeMaxAsFloat(DataType data_type)208 float QuantizedTypeMaxAsFloat(DataType data_type) {
209 switch (data_type) {
210 case DT_QINT8:
211 return Eigen::NumTraits<qint8>::highest();
212 case DT_QUINT8:
213 return Eigen::NumTraits<quint8>::highest();
214 case DT_QINT16:
215 return Eigen::NumTraits<qint16>::highest();
216 case DT_QUINT16:
217 return Eigen::NumTraits<quint16>::highest();
218 case DT_QINT32:
219 return Eigen::NumTraits<qint32>::highest();
220 default:
221 return 0.0f;
222 }
223 }
224
225 } // namespace
226
// Constructs the optimizer. `cpu_device` is the host device used to actually
// evaluate (fold) nodes; `opt_level` controls how aggressive the folding is.
// A fresh ResourceMgr is created for kernels evaluated during folding.
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
                                 DeviceBase* cpu_device)
    : opt_level_(opt_level), cpu_device_(cpu_device) {
  resource_mgr_.reset(new ResourceMgr());
}
232
// Convenience constructor: defaults the optimization level to RewriterConfig::ON.
ConstantFolding::ConstantFolding(DeviceBase* cpu_device)
    : ConstantFolding(RewriterConfig::ON, cpu_device) {}
235
// static
// Returns a control-dependency form of `input_name` ("^node"). Most nodes can
// be used directly; Switch nodes cannot, so the dependency is anchored on an
// Identity node attached to the relevant switch output — reusing an existing
// one if found, otherwise adding a new one to `graph`/`node_map`.
string ConstantFolding::AddControlDependency(const string& input_name,
                                             GraphDef* graph,
                                             NodeMap* node_map) {
  // Already in control-dependency form: return as-is.
  if (IsControlInput(input_name)) {
    return input_name;
  }
  const NodeDef* node = node_map->GetNode(input_name);
  if (!IsSwitch(*node)) {
    return AsControlDependency(*node);
  } else {
    // We can't anchor control dependencies directly on the switch node: unlike
    // other nodes only one of the outputs of the switch node will be generated
    // when the switch node is executed, and we need to make sure the control
    // dependency is only triggered when the corresponding output is triggered.
    // We start by looking for an identity node connected to the output of the
    // switch node, and use it to anchor the control dependency.
    auto outputs = node_map->GetOutputs(node->name());
    for (const NodeDef* output : outputs) {
      if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
        // NOTE(review): this compares the switch node's own first input with
        // input_name rather than the identity consumer's input
        // (output->input(0)); that looks suspicious — confirm the intent
        // before relying on this branch.
        if (IsSameInput(node->input(0), input_name)) {
          return AsControlDependency(*output);
        }
      }
    }
    // We haven't found an existing node where we can anchor the control
    // dependency: add a new identity node.
    int port = 0;
    string ctrl_dep_name = ParseNodeName(input_name, &port);
    strings::StrAppend(&ctrl_dep_name, "_", port);
    ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
    // The anchoring Identity must carry the same element type as the switch.
    const DataType output_type = node->attr().at("T").type();

    // Reuse the anchor node if a previous call already created it.
    NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
    if (added_node == nullptr) {
      added_node = graph->add_node();
      added_node->set_name(ctrl_dep_name);
      added_node->set_op("Identity");
      added_node->set_device(node->device());

      (*added_node->mutable_attr())["T"].set_type(output_type);
      *added_node->add_input() = input_name;
      node_map->AddNode(added_node->name(), added_node);
      node_map->AddOutput(node->name(), added_node->name());
    }
    return AsControlDependency(*added_node);
  }
}
284
285 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64 value,const DataType & type,const int index,Tensor * tensor)286 static Status PutValueIntoTensor(const int64 value, const DataType& type,
287 const int index, Tensor* tensor) {
288 if (type == DT_INT32) {
289 if (value >= INT_MAX) {
290 return Status(error::INVALID_ARGUMENT, "int32 overflow");
291 }
292 tensor->flat<int32>()(index) = static_cast<int32>(value);
293 } else {
294 tensor->flat<int64>()(index) = value;
295 }
296 return Status::OK();
297 }
298
299 // Writes the given tensor shape into the given tensor.
300 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)301 static Status ConvertShapeToConstant(const string& op, const DataType& type,
302 const PartialTensorShape& shp,
303 Tensor* tensor) {
304 if (op == "Shape" || op == "ShapeN") {
305 *tensor = Tensor(type, TensorShape({shp.dims()}));
306 for (int i = 0; i < shp.dims(); ++i) {
307 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
308 }
309 } else if (op == "Size") {
310 int64 size = 1;
311 for (int i = 0; i < shp.dims(); ++i) {
312 size *= shp.dim_size(i);
313 }
314 *tensor = Tensor(type, TensorShape({}));
315 TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
316 } else {
317 CHECK_EQ(op, "Rank");
318 *tensor = Tensor(type, TensorShape({}));
319 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
320 }
321 return Status::OK();
322 }
323
// TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
// Returns true if an optimized node derived from `node` with the given suffix
// already exists in the graph.
bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
                                          StringPiece suffix) const {
  return node_map_->NodeExists(OptimizedNodeName(node, suffix));
}
329
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const330 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
331 StringPiece suffix) const {
332 return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
333 kConstantFoldingConst);
334 }
335
IsReallyConstant(const NodeDef & node) const336 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
337 if (!IsConstant(node)) {
338 return false;
339 }
340 // If the node is fed it's not constant anymore.
341 return feed_nodes_.find(node.name()) == feed_nodes_.end();
342 }
343
// Materialize the shapes using constants whenever possible.
// Rewrites Shape/Size/Rank/ShapeN/TensorArraySizeV3 nodes whose result is
// statically known into Const nodes, preserving execution ordering through
// control dependencies.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
  // We may add some nodes to the graph to encode control dependencies and hold
  // the materialized shapes: there is no need to process these added nodes, so
  // only iterate over the nodes of the input graph.
  const int node_count = graph_->node_size();
  for (int node_idx = 0; node_idx < node_count; ++node_idx) {
    NodeDef* node = graph_->mutable_node(node_idx);
    const string op = node->op();
    if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
        op != "TensorArraySizeV3") {
      continue;
    }

    const std::vector<OpInfo::TensorProperties>& output =
        properties.GetOutputProperties(node->name());
    const std::vector<OpInfo::TensorProperties>& input =
        properties.GetInputProperties(node->name());
    if (input.empty() || output.empty()) {
      continue;
    }

    if (op == "Shape" || op == "Size" || op == "Rank") {
      // These three ops have exactly one input and one output.
      CHECK_EQ(1, output.size());
      CHECK_EQ(1, input.size());

      const DataType type = output[0].dtype();
      CHECK(type == DT_INT32 || type == DT_INT64);
      const PartialTensorShape shape(input[0].shape());

      // Shape/Size need a fully defined shape; Rank only needs a known rank.
      if ((op != "Rank" && !shape.IsFullyDefined()) ||
          (op == "Rank" && shape.unknown_rank())) {
        continue;
      }

      Tensor constant_value(type);
      if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
        continue;
      }

      // Repurpose the existing node to be the constant.
      // Device placement is preserved.
      node->set_op("Const");
      node->clear_attr();
      (*node->mutable_attr())["dtype"].set_type(type);
      constant_value.AsProtoTensorContent(
          (*node->mutable_attr())["value"].mutable_tensor());

      // Turn the data input into a control dependency: this is needed to
      // ensure that the constant value will only be run in the
      // cases where the shape/rank/size would have been run in
      // the original graph.
      string ctrl_dep =
          AddControlDependency(node->input(0), graph_, node_map_.get());
      node->set_input(0, ctrl_dep);
      node_map_->AddOutput(NodeName(ctrl_dep), node->name());

      // Done with the Shape/Size/Rank node, move to the next node.
      continue;
    }

    if (op == "TensorArraySizeV3") {
      // The size of a non-dynamic TensorArray is fixed at creation: fold it
      // when the creating node's size input is a constant.
      const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
      if (array->input_size() == 0 ||
          (array->attr().count("dynamic_size") != 0 &&
           array->attr().at("dynamic_size").b())) {
        continue;
      }
      const NodeDef* array_size =
          CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
      if (IsReallyConstant(*array_size)) {
        // Don't materialize 0 sizes to avoid triggering incorrect static
        // checks. A 0 sized array that can't grow isn't useful anyway.
        if (array_size->attr().count("value") == 0) {
          continue;
        }
        const TensorProto& raw_val = array_size->attr().at("value").tensor();
        if (raw_val.dtype() != DT_INT32) {
          continue;
        }
        Tensor value(raw_val.dtype(), raw_val.tensor_shape());
        if (!value.FromProto(raw_val)) {
          continue;
        }
        if (value.flat<int32>()(0) == 0) {
          continue;
        }

        // Repurpose this node as a copy of the size constant; the original
        // inputs become control dependencies to preserve ordering.
        node->set_op("Const");
        *node->mutable_attr() = array_size->attr();
        node->set_input(0, AsControlDependency(NodeName(node->input(0))));
        node->set_input(1, AddControlDependency(NodeName(node->input(1)),
                                                graph_, node_map_.get()));
      }
      continue;
    }

    // Handle ShapeN materialization case.
    // It's possible that not all input tensors have known shapes.
    CHECK_EQ(op, "ShapeN");
    CHECK_EQ(input.size(), output.size());
    const NodeDef* const shape_n_node = node;
    for (int port_idx = 0; port_idx < output.size(); ++port_idx) {
      const DataType type = output[port_idx].dtype();
      CHECK(type == DT_INT32 || type == DT_INT64);
      const PartialTensorShape shape(input[port_idx].shape());
      if (!shape.IsFullyDefined()) {
        continue;
      }
      Tensor constant_value(type);
      auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
      if (!status.ok()) {
        continue;
      }

      // Find all nodes consuming this shape and connect them through the new
      // constant node instead.
      auto outputs = node_map_->GetOutputs(shape_n_node->name());
      for (NodeDef* output : outputs) {
        // Track whether there are any direct edges left between shape_n_node
        // and this output node after the transformation.
        bool direct_edges_exist = false;
        for (int k = 0; k < output->input_size(); ++k) {
          int port;
          const string node_name = ParseNodeName(output->input(k), &port);
          if (node_name == shape_n_node->name() && port == port_idx) {
            // Create a const node as ShapeN's output if not already.
            const string const_name = OptimizedNodeName(
                *shape_n_node, strings::StrCat("-matshapes-", port_idx));
            if (node_map_->GetNode(const_name) == nullptr) {
              NodeDef* added_node = graph_->add_node();
              added_node->set_name(const_name);
              added_node->set_op("Const");
              added_node->set_device(shape_n_node->device());
              node_map_->AddNode(added_node->name(), added_node);
              (*added_node->mutable_attr())["dtype"].set_type(type);
              constant_value.AsProtoTensorContent(
                  (*added_node->mutable_attr())["value"].mutable_tensor());
              // We add a control dependency to the original ShapeN node,
              // so that the node will only be run if all inputs of the
              // original ShapeN node are run.
              string ctrl_dep = AddControlDependency(shape_n_node->name(),
                                                     graph_, node_map_.get());
              *added_node->add_input() = ctrl_dep;
              node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
            }
            *output->mutable_input(k) = const_name;
            node_map_->AddOutput(const_name, output->name());
          }
          if (node_name == shape_n_node->name() && port != port_idx) {
            direct_edges_exist = true;
          }
        }
        if (!direct_edges_exist) {
          node_map_->RemoveOutput(node->name(), output->name());
        }
      }
    }
  }

  return Status::OK();
}
506
507 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64 * min_id)508 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
509 BCast::Vec* shape, int64* min_id) {
510 if (shape_node.op() == "Shape") {
511 const std::vector<OpInfo::TensorProperties>& prop1 =
512 properties.GetInputProperties(shape_node.name());
513 if (prop1.size() != 1) {
514 return false;
515 }
516 const TensorShapeProto& shp = prop1[0].shape();
517 if (shp.unknown_rank()) {
518 return false;
519 }
520 for (const auto& dim : shp.dim()) {
521 shape->push_back(dim.size());
522 *min_id = std::min<int64>(*min_id, dim.size());
523 }
524 } else {
525 if (shape_node.attr().count("value") == 0) {
526 return false;
527 }
528 const TensorProto& raw_val = shape_node.attr().at("value").tensor();
529 if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
530 return false;
531 }
532 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
533 if (!value.FromProto(raw_val)) {
534 return false;
535 }
536 for (int j = 0; j < value.NumElements(); ++j) {
537 if (raw_val.dtype() == DT_INT64) {
538 shape->push_back(value.vec<int64>()(j));
539 } else {
540 shape->push_back(value.vec<int>()(j));
541 }
542 }
543 }
544 return true;
545 }
546 } // namespace
547
// Replaces the outputs of a BroadcastGradientArgs node with constants holding
// the gradient reduction indices, when both input shapes are statically known
// (possibly symbolically). Bails out (returning OK) whenever the reduction
// indices cannot be proven correct.
Status ConstantFolding::MaterializeBroadcastGradientArgs(
    const NodeDef& node, const GraphProperties& properties) {
  // Both inputs must be Shape ops or real constants.
  const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
  const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
  if (shape_node1 == nullptr ||
      (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
      shape_node2 == nullptr ||
      (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
    return Status::OK();
  }

  // Don't optimize this again if it was already optimized and folded.
  if (OptimizedNodeExists(node, "-folded-1") ||
      OptimizedNodeExists(node, "-folded-2")) {
    return Status::OK();
  }
  int64 min_id = 0;
  BCast::Vec shape1;
  if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
    return Status::OK();
  }
  BCast::Vec shape2;
  if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
    return Status::OK();
  }
  // A value of -1 means we don't known anything about the dimension. Replace
  // the -1 values with unique dimension ids since we don't want two '-1'
  // dimensions to be considered equal.
  for (auto& id : shape1) {
    if (id == -1) {
      id = --min_id;
    }
  }
  for (auto& id : shape2) {
    if (id == -1) {
      id = --min_id;
    }
  }

  // Beware: the reduction dimensions computed by the BCast class are valid iff
  // we assume that two distinct symbolic dimensions can't be equal and a
  // symbolic dimension can't be equal to 1. This is often but not always true,
  // so to make this optimization safe we filter out these cases.
  const int common_dims = std::min(shape1.size(), shape2.size());
  for (int i = 0; i < common_dims; ++i) {
    if (shape1[i] >= 0 && shape2[i] >= 0) {
      continue;
    }
    if (shape1[i] != shape2[i]) {
      // We're either dealing with 2 different symbolic dimensions or a symbolic
      // and a know dimensions. We can't be sure whether both are equal or not,
      // so we can't be sure whether we'll be broadcasting or not.
      return Status::OK();
    }
  }
  // These extra dims could be equal to 1, in which case there is no
  // broadcasting. It could also be greater than 1, in which case there would
  // be broadcasting. Since we don't know, we'll just punt.
  for (int i = common_dims; i < shape1.size(); ++i) {
    if (shape1[i] < 0) {
      return Status::OK();
    }
  }
  for (int i = common_dims; i < shape2.size(); ++i) {
    if (shape2[i] < 0) {
      return Status::OK();
    }
  }

  BCast bcast(shape1, shape2);
  if (!bcast.IsValid()) {
    return Status::OK();
  }

  // reduce_dims[j] holds the reduction indices for input j's gradient.
  BCast::Vec reduce_dims[2];
  reduce_dims[0] = bcast.grad_x_reduce_idx();
  reduce_dims[1] = bcast.grad_y_reduce_idx();

  TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
  const DataType type = node.attr().at("T").type();
  NodeDef* out[2];
  for (int j = 0; j < 2; ++j) {
    // Materialize each output as a constant vector of reduction indices,
    // anchored on the original node via a control dependency so it runs only
    // when the original would have.
    int reduction_indices = reduce_dims[j].size();
    Tensor value(type, TensorShape({reduction_indices}));
    for (int i = 0; i < reduction_indices; ++i) {
      if (type == DT_INT32) {
        value.vec<int32>()(i) = reduce_dims[j][i];
      } else {
        value.vec<int64>()(i) = reduce_dims[j][i];
      }
    }
    string const_name =
        OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
    out[j] = node_map_->GetNode(const_name);
    if (out[j] == nullptr) {
      out[j] = graph_->add_node();
      TF_RETURN_IF_ERROR(
          CreateNodeDef(const_name, TensorValue(&value), out[j]));
      out[j]->set_device(node.device());
      node_map_->AddNode(const_name, out[j]);
      string ctrl_dep =
          AddControlDependency(node.name(), graph_, node_map_.get());
      *out[j]->add_input() = ctrl_dep;
      node_map_->AddOutput(NodeName(ctrl_dep), const_name);
    }
  }

  // Rewire every consumer of output port 0 or 1 to the matching constant.
  const std::set<NodeDef*> outputs = node_map_->GetOutputs(node.name());
  for (NodeDef* output : outputs) {
    for (int k = 0; k < output->input_size(); ++k) {
      int port;
      string node_name = ParseNodeName(output->input(k), &port);
      if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
        *output->mutable_input(k) = out[port]->name();
        node_map_->UpdateInput(output->name(), node_name, out[port]->name());
      }
    }
  }

  return Status::OK();
}
669
// If `node` is a reduction whose non-constant indices provably cover every
// input dimension (a full reduction), replaces the indices input with a
// constant [0, 1, ..., rank-1] vector so the reduction can later be folded.
Status ConstantFolding::MaterializeReductionIndices(
    NodeDef* node, const GraphProperties& properties) {
  if (node->input_size() < 2) {
    return Status::OK();
  }
  const NodeDef* indices = node_map_->GetNode(node->input(1));
  if (!indices || IsReallyConstant(*indices)) {
    // The reduction indices are already constant, there's nothing to do.
    return Status::OK();
  }

  const std::vector<OpInfo::TensorProperties>& input_props =
      properties.GetInputProperties(node->name());
  if (input_props.size() != 2) {
    return Status::OK();
  }
  const OpInfo::TensorProperties& input_prop = input_props[0];
  if (input_prop.shape().unknown_rank()) {
    // We can't do anything if we don't know the rank of the input.
    return Status::OK();
  }
  const int input_rank = input_prop.shape().dim_size();
  if (input_rank < 1) {
    // Unexpected graph, don't try to change it.
    return Status::OK();
  }
  const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
  DataType dtype = reduction_indices_prop.dtype();
  if (dtype != DT_INT32 && dtype != DT_INT64) {
    return Status::OK();
  }
  PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
  const int num_reduction_indices = reduction_indices_shape.num_elements();

  const std::vector<OpInfo::TensorProperties>& output_props =
      properties.GetOutputProperties(node->name());
  if (output_props.size() != 1) {
    return Status::OK();
  }
  const OpInfo::TensorProperties& output_prop = output_props[0];
  const int output_rank =
      output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();

  // It's a full reduction if the output is a scalar or if every input
  // dimension is reduced.
  bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
  if (!full_reduction) {
    // A full reduction will generate a tensor of one of the shapes
    // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
    // elements in the output of the reduction, we may deduce it from reshape
    // nodes following it.
    for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
      full_reduction = false;
      if (!IsReshape(*fanout)) {
        return Status::OK();
      }
      const std::vector<OpInfo::TensorProperties>& reshape_props =
          properties.GetOutputProperties(fanout->name());
      if (reshape_props.size() != 1) {
        return Status::OK();
      }
      const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
      PartialTensorShape shape(reshape_prop.shape());
      if (shape.num_elements() != 1) {
        return Status::OK();
      } else {
        full_reduction = true;
      }
    }
    if (!full_reduction) {
      return Status::OK();
    }
  }

  // We know it's a full reduction. We can generate the full set of indices to
  // reduce as a constant node.
  string const_name = OptimizedNodeName(*node, "-reduction_indices");
  if (node_map_->GetNode(const_name)) {
    return Status::OK();
  }
  NodeDef* reduction_indices = graph_->add_node();
  Tensor value(dtype, TensorShape({input_rank}));
  for (int i = 0; i < input_rank; ++i) {
    if (dtype == DT_INT32) {
      value.vec<int32>()(i) = i;
    } else {
      value.vec<int64>()(i) = i;
    }
  }
  TF_RETURN_IF_ERROR(
      CreateNodeDef(const_name, TensorValue(&value), reduction_indices));

  reduction_indices->set_device(node->device());
  // Anchor the constant on the original indices input so it runs only when
  // the original would have.
  string ctrl_dep =
      AddControlDependency(node->input(1), graph_, node_map_.get());
  *reduction_indices->add_input() = ctrl_dep;
  node_map_->AddNode(const_name, reduction_indices);
  node_map_->AddOutput(NodeName(ctrl_dep), const_name);

  // Swap the reduction's indices input over to the new constant.
  node->set_input(1, reduction_indices->name());
  node_map_->UpdateInput(node->name(), indices->name(),
                         reduction_indices->name());

  return Status::OK();
}
773
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)774 Status ConstantFolding::MaterializeConstantValuedNode(
775 NodeDef* node, const GraphProperties& properties) {
776 // Nodes that generate constant-valued outputs can be represented compactly in
777 // compressed format, regardless of their shape.
778 const std::vector<OpInfo::TensorProperties>& output_props =
779 properties.GetOutputProperties(node->name());
780 if (output_props.size() != 1) return Status::OK();
781 const auto& output_shape = output_props[0].shape();
782 if (!PartialTensorShape(output_shape).IsFullyDefined()) {
783 return Status::OK();
784 }
785 if (IsFill(*node)) {
786 const auto output_dtype = output_props[0].dtype();
787 NodeDef* input_node = nullptr;
788 for (int i = 0; i < 2; ++i) {
789 input_node = node_map_->GetNode(NodeName(node->input(i)));
790 if (input_node == nullptr || !IsReallyConstant(*input_node)) {
791 return Status::OK();
792 }
793 }
794 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
795
796 // Copy the input tensor to the fill node, set the output shape and data
797 // type, and change the node type to Const.
798 TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
799 const TensorProto& input_tensor = input_node->attr().at("value").tensor();
800 if (!input_tensor.tensor_content().empty()) {
801 // Convert the value to repeated field format, so we can use the
802 // decompression mechanism to store only a single value in the constant
803 // node, even if the shape specified in the original Fill is large.
804 Tensor t;
805 if (!t.FromProto(input_tensor)) {
806 return errors::InvalidArgument(
807 "Could not construct Tensor form TensorProto in node: ",
808 input_node->name());
809 }
810 tensor->clear_tensor_content();
811 t.AsProtoField(tensor);
812 } else {
813 *tensor = input_tensor;
814 }
815 *(tensor->mutable_tensor_shape()) = output_shape;
816 (*node->mutable_attr())["dtype"].set_type(output_dtype);
817 node->mutable_attr()->erase("T");
818 node->mutable_attr()->erase("index_type");
819 node->set_op("Const");
820 for (int i = 0; i < 2; i++) {
821 // Change inputs to a control inputs.
822 const string ctrl_dep = AsControlDependency(node->input(i));
823 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
824 node->set_input(i, ctrl_dep);
825 }
826 graph_modified_ = true;
827 } else {
828 double value =
829 (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
830 bool success = false;
831 if (value >= 0) {
832 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
833 value, properties, output_shape, node, graph_, &success));
834 }
835 }
836 return Status::OK();
837 }
838
MaterializeConstants(const GraphProperties & properties)839 Status ConstantFolding::MaterializeConstants(
840 const GraphProperties& properties) {
841 const int node_count = graph_->node_size();
842 for (int i = 0; i < node_count; ++i) {
843 NodeDef& node = *graph_->mutable_node(i);
844 const string& op = node.op();
845 if (op == "BroadcastGradientArgs") {
846 TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
847 } else if (IsReduction(node)) {
848 TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
849 } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
850 TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
851 }
852 }
853 return Status::OK();
854 }
855
IsFoldable(const NodeDef & node) const856 bool ConstantFolding::IsFoldable(const NodeDef& node) const {
857 // Folding not applicable to ops with no inputs.
858 if (node.input().empty()) {
859 return false;
860 }
861 // Skips nodes that must be preserved except whitelisted nodes.
862 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
863 nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
864 return false;
865 }
866 // `FakeParam` op is used as a placeholder in If branch function. It doesn't
867 // have a valid output when executed.
868 if (IsFakeParam(node)) {
869 return false;
870 }
871
872 // Skip control flow nodes, they can't be folded.
873 if (ModifiesFrameInfo(node)) {
874 return false;
875 }
876
877 // Removing LoopCond nodes can screw up the partitioner.
878 if (node.op() == "LoopCond") {
879 return false;
880 }
881
882 // Skip constants, they're already folded
883 if (IsConstant(node)) {
884 return false;
885 }
886
887 // Don't fold stateful ops such as TruncatedNormal.
888 if (!IsFreeOfSideEffect(node)) {
889 return false;
890 }
891
892 // Skips ops that don't benefit from folding.
893 if (IsPlaceholder(node)) {
894 return false;
895 }
896 const string& op = node.op();
897 if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
898 op.find("Reader") != string::npos) {
899 return false;
900 }
901 if (op.find("Quantized") != string::npos || op.find("Sparse") == 0) {
902 return false;
903 }
904
905 // Don't fold nodes that contain TPU attributes.
906 // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
907 // properly forward custom attributes, b/119051778.
908 if (HasTPUAttributes(node)) {
909 return false;
910 }
911
912 const OpDef* op_def = nullptr;
913 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
914 if (!status.ok()) {
915 return false;
916 }
917 // Don't fold ops without outputs.
918 if (op_def->output_arg_size() == 0) {
919 return false;
920 }
921
922 // No need to (and don't) fold nodes that have no outgoing edges except
923 // whitelisted nodes. Such nodes could be introduced by an earlier constant
924 // folding pass and are preserved in case users want to fetch their values;
925 // re-processing them would lead to an error of adding a duplicated node
926 // to graph.
927 auto outputs = node_map_->GetOutputs(node.name());
928 if (outputs.empty() &&
929 nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
930 return false;
931 }
932
933 // We can only fold nodes if all their inputs are known statically, except in
934 // the case of a merge node that propagate the first inputs that becomes
935 // available, and therefore only requires a single constant input to be
936 // foldable.
937 bool merge_has_constant_input = false;
938 const bool is_merge = IsMerge(node);
939 for (const auto& input : node.input()) {
940 if (IsControlInput(input)) {
941 continue;
942 }
943 const NodeDef* input_node = node_map_->GetNode(input);
944 if (!input_node) {
945 return false;
946 }
947 bool is_const = IsReallyConstant(*input_node);
948 if (is_const) {
949 // Don't fold strings constants for now since this causes problems with
950 // checkpointing.
951 if (input_node->attr().count("dtype") == 0 ||
952 input_node->attr().at("dtype").type() == DT_STRING) {
953 return false;
954 }
955 // Special case: If a Merge node has at least one constant input that
956 // does not depend on a control input, we can fold it.
957 merge_has_constant_input |= !HasControlInputs(*input_node);
958 } else if (!is_merge) {
959 return false;
960 }
961 }
962 return !is_merge || merge_has_constant_input;
963 }
964
965 namespace {
966
967 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
968 case DTYPE: \
969 t->add_##NAME##_val(static_cast<TYPE>(value)); \
970 break;
971
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)972 Status CreateConstantTensorAttrValue(DataType type, double value,
973 const TensorShapeProto& shape,
974 AttrValue* attr_tensor) {
975 TensorProto* t = attr_tensor->mutable_tensor();
976 t->set_dtype(type);
977 *t->mutable_tensor_shape() = shape;
978 switch (type) {
979 case DT_HALF:
980 t->add_half_val(static_cast<Eigen::half>(value).x);
981 break;
982 case DT_BFLOAT16:
983 t->add_half_val(static_cast<bfloat16>(value).value);
984 break;
985 SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
986 SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
987 SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
988 SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64);
989 SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
990 SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
991 SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
992 SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
993 SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
994 SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
995 SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
996 SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
997 SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
998 SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
999 SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1000 SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1001 default:
1002 return errors::InvalidArgument("Unsupported type: ", type);
1003 }
1004 return Status::OK();
1005 }
1006
1007 #undef SET_TENSOR_CAL_CASE
1008
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1009 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1010 const GraphProperties& properties) {
1011 DataType dtype = DT_INVALID;
1012 if (node.attr().count("T") == 1) {
1013 dtype = node.attr().at("T").type();
1014 } else if (node.attr().count("dtype") == 1) {
1015 dtype = node.attr().at("dtype").type();
1016 } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1017 dtype = DT_BOOL;
1018 } else {
1019 auto output_props = properties.GetOutputProperties(node.name());
1020 if (!output_props.empty()) {
1021 dtype = output_props[0].dtype();
1022 }
1023 }
1024 return dtype;
1025 }
1026
1027 // Checks whether the shape of the const input of the Mul op is valid to perform
1028 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1029 bool IsValidConstShapeForMulConvPushDown(
1030 const string& data_format, const TensorShapeProto& filter_shape,
1031 const TensorShapeProto& mul_const_input_shape) {
1032 // If the const is a scalar, or it has fewer or same number of dimensions
1033 // than the filter and it only has single element, the optimization should
1034 // work.
1035 if (mul_const_input_shape.dim_size() <= data_format.size() &&
1036 TensorShape(mul_const_input_shape).num_elements() == 1) {
1037 return true;
1038 }
1039
1040 // Otherwise, check the eligibility according to data format.
1041 if (data_format == "NHWC" || data_format == "NDHWC") {
1042 TensorShapeProto new_filter_shape;
1043 if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1044 &new_filter_shape)) {
1045 return false;
1046 }
1047 if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1048 return false;
1049 }
1050 // Only the last dimension could be larger than one, since broadcasting over
1051 // the last dimension (the output channel) will result in invalid filter.
1052 for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1053 if (mul_const_input_shape.dim(i).size() > 1) return false;
1054 }
1055 return true;
1056 } else if (data_format == "NCHW" || data_format == "NCDHW") {
1057 // TODO(laigd): support NCHW and NCDHW (b/111214513).
1058 return false;
1059 }
1060 return false;
1061 }
1062
1063 } // namespace
1064
// static
// Populates `node` as a Const op named `name` holding `tensor`. Prefers a
// compact repeated-field encoding (truncating a trailing run of equal
// values) over the dense tensor_content encoding, and fails if the encoded
// constant would both exceed `original_size` (the total size of the folded
// inputs) and be at least 10MB.
Status ConstantFolding::CreateNodeDef(const string& name,
                                      const TensorValue& tensor, NodeDef* node,
                                      size_t original_size) {
  node->set_name(name);
  node->set_op("Const");

  AttrValue attr_type;
  attr_type.set_type(tensor->dtype());
  node->mutable_attr()->insert({"dtype", attr_type});

  AttrValue attr_tensor;
  TensorProto* t = attr_tensor.mutable_tensor();
  bool optimized = false;
  size_t encoded_size;
  // Use the packed representation whenever possible to avoid generating large
  // graphdefs. Moreover, avoid repeating the last values if they're equal.
  if (tensor->NumElements() > 4) {
// Copies `tensor` into the repeated NAME_val field of `t`, keeping only the
// prefix up to (and including) the last value that differs from its
// predecessor; presumably the trailing values are reconstituted by the
// proto decompression mechanism (TODO confirm). `encoded_size` uses
// sizeof(NAME), i.e. the size of the proto field's element type, not the
// tensor's element type. The trailing `break` (outside the braces) exits
// the enclosing switch case.
#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME)                      \
  {                                                                       \
    const auto* val_ptr = tensor->flat<TYPE>().data();                    \
    auto last = *val_ptr;                                                 \
    int64 last_index = 0;                                                 \
    for (int64 i = 0; i < tensor->NumElements(); ++i) {                   \
      TYPE cur = *val_ptr++;                                              \
      if (PackedValuesNotEqual(cur, last)) {                              \
        last = cur;                                                       \
        last_index = i;                                                   \
      }                                                                   \
    }                                                                     \
    if (last_index < kint32max) {                                         \
      optimized = true;                                                   \
      encoded_size = (last_index + 1) * sizeof(NAME);                     \
      t->mutable_##NAME##_val()->Reserve(last_index + 1);                 \
      const auto* src_ptr = tensor->flat<TYPE>().data();                  \
      auto* dst_ptr =                                                     \
          t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \
      std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr);              \
    }                                                                     \
  }                                                                       \
  break

    switch (tensor->dtype()) {
      case DT_FLOAT:
        POPULATE_TENSOR_PROTO(tensor, t, float, float);
      case DT_DOUBLE:
        POPULATE_TENSOR_PROTO(tensor, t, double, double);
      case DT_INT64:
        POPULATE_TENSOR_PROTO(tensor, t, int64, int64);
      case DT_UINT64:
        POPULATE_TENSOR_PROTO(tensor, t, uint64, int64);
      case DT_INT32:
        POPULATE_TENSOR_PROTO(tensor, t, int32, int);
      case DT_UINT32:
        POPULATE_TENSOR_PROTO(tensor, t, uint32, int);
      case DT_INT16:
        POPULATE_TENSOR_PROTO(tensor, t, int16, int);
      case DT_UINT16:
        POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
      case DT_INT8:
        POPULATE_TENSOR_PROTO(tensor, t, int8, int);
      case DT_UINT8:
        POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
      case DT_BOOL:
        POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
      default:
        /* Do nothing. */
        break;
    }
  }
  if (optimized) {
    // Also specify type and shape.
    t->set_dtype(tensor->dtype());
    tensor->shape().AsProto(t->mutable_tensor_shape());
  } else {
    // Fall back to the dense encoding for everything the macro doesn't
    // handle:
    // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
    // DT_QUINT8
    tensor->AsProtoTensorContent(t);
    encoded_size = t->tensor_content().size();
  }
  node->mutable_attr()->insert({"value", attr_tensor});

  // Refuse to fold when the constant is both larger than what it replaces
  // and at least 10MB, to keep graphdefs from ballooning.
  if (encoded_size > original_size && encoded_size >= 10 * 1024 * 1024) {
    return errors::InvalidArgument(
        strings::StrCat("Can't fold ", name, ", its size would be too large (",
                        encoded_size, " >= ", 10 * 1024 * 1024, " bytes)"));
  }
  return Status::OK();
}
1154
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1155 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1156 const TensorVector& inputs,
1157 TensorVector* output) const {
1158 return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1159 resource_mgr_.get(), output);
1160 }
1161
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1162 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1163 std::vector<NodeDef>* outputs,
1164 bool* result_too_large) {
1165 TensorVector inputs;
1166 TensorVector output_tensors;
1167 auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1168 for (const auto& input : inputs) {
1169 delete input.tensor;
1170 }
1171 for (const auto& output : output_tensors) {
1172 if (output.tensor) {
1173 delete output.tensor;
1174 }
1175 }
1176 });
1177
1178 size_t total_inputs_size = 0;
1179 for (const auto& input : node.input()) {
1180 const TensorId input_tensor = ParseTensorName(input);
1181 if (input_tensor.index() < 0) {
1182 // Control dependency
1183 break;
1184 }
1185 const NodeDef* input_node = node_map_->GetNode(input);
1186 if (!IsReallyConstant(*input_node)) {
1187 return Status(error::INVALID_ARGUMENT,
1188 strings::StrCat("Can't fold ", node.name(), ", its ", input,
1189 " isn't constant"));
1190 }
1191 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1192 const TensorProto& raw_val = input_node->attr().at("value").tensor();
1193 Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1194 CHECK(value->FromProto(raw_val));
1195 inputs.emplace_back(value);
1196 total_inputs_size += value->TotalBytes();
1197 }
1198
1199 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1200 if (output_tensors.empty()) {
1201 return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1202 }
1203
1204 outputs->resize(output_tensors.size());
1205 for (size_t i = 0; i < output_tensors.size(); i++) {
1206 string node_name = OptimizedNodeName(node, "-folded");
1207 if (output_tensors.size() > 1) {
1208 node_name = strings::StrCat(node_name, "-", i);
1209 }
1210 if (output_tensors[i].tensor) {
1211 Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1212 total_inputs_size);
1213 if (!s.ok()) {
1214 *result_too_large = true;
1215 return s;
1216 }
1217 } else {
1218 // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1219 // switch that's not selected by the switch predicate).
1220 outputs->at(i) = NodeDef();
1221 }
1222 }
1223 return Status::OK();
1224 }
1225
// Folds a Merge node in place in `output_graph` by materializing constants
// for its first eligible constant input. See the inline comment below for
// the exact rewrite.
Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
  // Merge nodes are special, in the sense that they execute as soon as one of
  // their input is ready. We can therefore fold a merge node iff it has at
  // least one constant input without control dependency.
  // We still need to ensure that the nodes in the fanin of the merge node are
  // scheduled. We'll therefore add a control dependency from the merge node
  // to the folded constant. We end up with:
  //  * the merge node and its inputs are preserved as is
  //  * a new constant node C1, driven by the merge node through a control
  //    dependency, initialized to the value of the folded input
  //  * a new constant node C2, driven by the merge node through a control
  //    dependency, initialized to the index of the folded input
  //  * the fanout of the merge nodes is rewired to be driven by either C1 or
  //    C2.
  for (int input_index = 0; input_index < node->input_size(); ++input_index) {
    const auto& input = node->input(input_index);
    if (IsControlInput(input)) {
      // Try the next input.
      continue;
    }
    // NOTE(review): GetNode may return nullptr for a dangling input name —
    // verify callers guarantee a well-formed graph before this runs.
    NodeDef* input_node = node_map_->GetNode(input);
    if (!IsReallyConstant(*input_node)) {
      continue;
    }
    // Reject constant inputs that themselves carry control dependencies:
    // their value is not unconditionally available.
    bool valid_input = true;
    for (const string& fanin_of_input : input_node->input()) {
      if (IsControlInput(fanin_of_input)) {
        valid_input = false;
        break;
      }
    }
    if (!valid_input) {
      // Try the next input
      continue;
    }

    string const_out_name = OptimizedNodeName(*node, "_const");
    string const_index_name = OptimizedNodeName(*node, "_index");
    if (node_map_->GetNode(const_out_name) ||
        node_map_->GetNode(const_index_name)) {
      // Intended name already exists.
      return errors::AlreadyExists(
          strings::StrCat(const_out_name, " or ", const_index_name,
                          " already present in the graph"));
    }

    // C1: a copy of the folded constant input, driven by the merge node
    // through a control dependency (see the rewrite description above).
    NodeDef* const_out = output_graph->add_node();
    *const_out = *input_node;
    const_out->set_name(const_out_name);
    const_out->set_device(node->device());
    *const_out->add_input() = AsControlDependency(*node);
    node_map_->AddNode(const_out->name(), const_out);
    node_map_->AddOutput(node->name(), const_out->name());

    // C2: a scalar int32 constant holding the index of the folded input,
    // mirroring the Merge op's second output (value_index).
    NodeDef* const_index = output_graph->add_node();
    const_index->set_op("Const");
    Tensor index(DT_INT32, TensorShape({}));
    index.flat<int32>()(0) = input_index;
    (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
    index.AsProtoTensorContent(
        (*const_index->mutable_attr())["value"].mutable_tensor());
    const_index->set_name(const_index_name);
    const_index->set_device(node->device());
    *const_index->add_input() = AsControlDependency(*node);
    node_map_->AddNode(const_index->name(), const_index);
    node_map_->AddOutput(node->name(), const_index->name());

    // Rewire the merge node's fanout: port 0 consumers read C1, port 1
    // consumers read C2; control edges are left untouched.
    auto outputs = node_map_->GetOutputs(node->name());
    for (NodeDef* output : outputs) {
      for (int i = 0; i < output->input_size(); i++) {
        int port;
        string node_name = ParseNodeName(output->input(i), &port);
        if (node_name == node->name()) {
          if (port == 0) {
            *output->mutable_input(i) = const_out->name();
            node_map_->AddOutput(const_out->name(), output->name());
          } else if (port == 1) {
            *output->mutable_input(i) = const_index->name();
            node_map_->AddOutput(const_index->name(), output->name());
          } else {
            // This is a control dependency (or an invalid edge since the
            // merge node has only 2 inputs): preserve them.
          }
        }
      }
    }
    // Only the first eligible input is folded; we are done.
    return Status::OK();
  }
  // No foldable input found: leave the merge node untouched.
  return Status::OK();
}
1316
// Folds `node` by evaluating it and replacing it (single output) or
// augmenting the graph with new Const nodes (multiple outputs), rewiring the
// fanout accordingly. Sets *result_too_large when evaluation produced a
// constant too large to encode.
Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
                                 bool* result_too_large) {
  *result_too_large = false;
  // Merge nodes need the special single-constant-input treatment.
  if (IsMerge(*node)) {
    return FoldMergeNode(node, output_graph);
  }

  std::vector<NodeDef> const_nodes;
  TF_RETURN_IF_ERROR(
      EvaluateOneFoldable(*node, &const_nodes, result_too_large));
  VLOG(1) << "Folded node:\n" << node->DebugString();

  NodeDef* constant_output = nullptr;
  for (int i = 0; i < const_nodes.size(); i++) {
    NodeDef* const_node = &const_nodes[i];
    VLOG(1) << "Generated constant node:\n" << const_node->DebugString();
    if (const_node->name().empty()) {
      // Dead output: we can't create a constant to encode its value, so we'll
      // just skip it. We'll preserve the edges that originate from that
      // output below to preserve the overall behavior of the graph wrt dead
      // edges.
      continue;
    }

    // Forward control dependencies: both the node's own control inputs and
    // the control inputs of its (constant) regular inputs, deduplicated.
    for (const auto& input : node->input()) {
      if (IsControlInput(input) &&
          std::find(const_node->input().begin(), const_node->input().end(),
                    input) == const_node->input().end()) {
        *const_node->add_input() = input;
      } else {
        NodeDef* input_node = node_map_->GetNode(input);
        for (const auto& fanin_of_input : input_node->input()) {
          if (IsControlInput(fanin_of_input) &&
              std::find(const_node->input().begin(), const_node->input().end(),
                        fanin_of_input) == const_node->input().end()) {
            *const_node->add_input() = fanin_of_input;
          }
        }
      }
    }

    // We rewrite the existing node if it only has a single output, and
    // create new nodes otherwise.
    if (const_nodes.size() == 1) {
      node->set_op("Const");
      // Note we need to clear the inputs in NodeMap before we clear the inputs
      // in the node, otherwise NodeMap would see empty inputs and effectively
      // does nothing.
      node_map_->RemoveInputs(node->name());
      node->clear_input();
      *node->mutable_input() = const_node->input();
      for (const auto& input : node->input()) {
        node_map_->AddOutput(NodeName(input), node->name());
      }
      *node->mutable_attr() = const_node->attr();
      break;
    } else {
      if (node_map_->GetNode(const_node->name())) {
        // Intended name already exists.
        return errors::AlreadyExists(strings::StrCat(
            const_node->name(), " already present in the graph"));
      }
      NodeDef* added_node = output_graph->add_node();
      *added_node = *const_node;
      added_node->set_device(node->device());
      node_map_->AddNode(added_node->name(), added_node);
      for (const auto& input : added_node->input()) {
        node_map_->AddOutput(NodeName(input), added_node->name());
      }
      // All the constant nodes encoding output values have the same control
      // dependencies (since these are the control dependencies of the node
      // we're trying to fold). Record one such constant node.
      constant_output = added_node;
    }
  }

  // Multi-output case: repoint each consumer edge at the Const node that
  // encodes the corresponding output port.
  if (const_nodes.size() > 1) {
    auto outputs = node_map_->GetOutputs(node->name());
    for (NodeDef* output : outputs) {
      for (int i = 0; i < output->input_size(); i++) {
        int port;
        string node_name = ParseNodeName(output->input(i), &port);
        if (node_name == node->name()) {
          if (port < 0) {
            // Propagate control dependencies if possible. If not, we'll just
            // preserve the existing control dependencies.
            if (constant_output != nullptr) {
              node_map_->UpdateInput(node_name, NodeName(output->input(i)),
                                     constant_output->name());
              *output->mutable_input(i) = AsControlDependency(*constant_output);
            }
          } else if (port < const_nodes.size() &&
                     !const_nodes[port].name().empty()) {
            // Replace alive outputs with the corresponding constant.
            node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
                                   const_nodes[port].name());
            *output->mutable_input(i) = const_nodes[port].name();
          } else {
            // Leave this edge alone.
            VLOG(1) << "Preserving edge from " << node->name() << ":" << port
                    << "[" << node->op() << "] to " << output->name() << ":"
                    << i << "[" << output->op() << "]";
          }
        }
      }
    }
    // If the folded node now feeds nothing and doesn't need to be kept for
    // fetching, disconnect it so it can be pruned later.
    outputs = node_map_->GetOutputs(node->name());
    if (outputs.empty() && has_fetch_ &&
        nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
      node_map_->RemoveInputs(node->name());
      node->clear_input();
    }
  }
  return Status::OK();
}
1433
// Repeatedly folds every foldable node (worklist algorithm: folding a node
// can make its fanout foldable), then writes the surviving nodes of the
// original graph into `output`. Nodes whose folding failed because the
// result was too large are recorded in `nodes_to_not_simplify`.
Status ConstantFolding::FoldGraph(
    GraphDef* output, absl::flat_hash_set<string>* nodes_to_not_simplify) {
  std::unordered_set<string> processed_nodes;
  std::deque<NodeDef*> queue;
  // Seed the worklist with every currently-foldable node.
  for (int i = 0; i < graph_->node_size(); i++) {
    if (IsFoldable(graph_->node(i))) {
      queue.push_back(graph_->mutable_node(i));
    }
  }
  while (!queue.empty()) {
    NodeDef* node = queue.front();
    queue.pop_front();
    if (processed_nodes.count(node->name())) {
      continue;
    }
    // We need to record a copy of output nodes before FoldNode() modifies it.
    // We also need to ensure that the fanout is sorted deterministically.
    const std::set<NodeDef*>& outputs = node_map_->GetOutputs(node->name());
    std::vector<NodeDef*> fanout(outputs.begin(), outputs.end());
    std::sort(fanout.begin(), fanout.end(),
              [](const NodeDef* n1, const NodeDef* n2) {
                return n1->name() < n2->name();
              });

    bool result_too_large = false;
    Status s = FoldNode(node, output, &result_too_large);
    processed_nodes.insert(node->name());
    if (!s.ok()) {
      // A failed fold is not fatal for the pass; log and move on.
      VLOG(1) << "Failed to fold node " << node->DebugString()
              << "\nError message: " << s;
      if (result_too_large) {
        nodes_to_not_simplify->emplace(node->name());
      }
    } else {
      // Folding this node may have made its consumers foldable; requeue them.
      for (auto& output : fanout) {
        if (IsFoldable(*output)) {
          queue.push_back(output);
        }
      }
    }
  }

  // Delete the newly created nodes that don't feed anything.
  std::vector<int> nodes_to_delete;
  for (int i = 0; i < output->node_size(); i++) {
    auto fanout = node_map_->GetOutputs(output->node(i).name());
    if (fanout.empty()) nodes_to_delete.push_back(i);
  }
  EraseNodesFromGraph(std::move(nodes_to_delete), output);

  // Copy the surviving original nodes into the output graph.
  for (const auto& node : graph_->node()) {
    // If no fetch nodes is provided, we conservatively
    // keep all nodes in the original graph in case users need to fetch
    // their values.
    auto fanout = node_map_->GetOutputs(node.name());
    if (!fanout.empty() || !has_fetch_ ||
        nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
      auto added_node = output->add_node();
      *added_node = node;
    }
  }
  return Status::OK();
}
1497
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1498 bool ConstantFolding::IsSimplifiableReshape(
1499 const NodeDef& node, const GraphProperties& properties) const {
1500 if (!IsReshape(node)) {
1501 return false;
1502 }
1503 CHECK_LE(2, node.input_size());
1504 const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1505 if (!IsReallyConstant(*new_shape)) {
1506 return false;
1507 }
1508 TensorVector outputs;
1509 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1510 for (const auto& output : outputs) {
1511 delete output.tensor;
1512 }
1513 });
1514
1515 Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1516 if (!s.ok()) {
1517 return false;
1518 }
1519 CHECK_EQ(1, outputs.size());
1520
1521 const std::vector<OpInfo::TensorProperties>& props =
1522 properties.GetInputProperties(node.name());
1523 if (props.empty()) {
1524 return false;
1525 }
1526 const OpInfo::TensorProperties& prop = props[0];
1527 if (prop.dtype() == DT_INVALID) {
1528 return false;
1529 }
1530 const PartialTensorShape shape(prop.shape());
1531 if (!shape.IsFullyDefined()) {
1532 return false;
1533 }
1534
1535 PartialTensorShape new_dims;
1536 if (outputs[0]->dtype() == DT_INT32) {
1537 std::vector<int32> shp;
1538 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1539 int32 dim = outputs[0]->flat<int32>()(i);
1540 shp.push_back(dim);
1541 }
1542 TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1543 } else {
1544 std::vector<int64> shp;
1545 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1546 int64 dim = outputs[0]->flat<int64>()(i);
1547 shp.push_back(dim);
1548 }
1549 TF_CHECK_OK(TensorShapeUtils::MakeShape(shp, &new_dims));
1550 }
1551
1552 return shape.IsCompatibleWith(new_dims);
1553 }
1554
1555 #define IS_VALUE_CASE(DTYPE, VALUE) \
1556 case DTYPE: \
1557 return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
1558 node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))
1559
1560 #define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
1561 #define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
1562
IsOnes(const NodeDef & node) const1563 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1564 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1565 return false;
1566 }
1567 if (IsOnesLike(node)) return true;
1568 if (IsZerosLike(node)) return false;
1569 if (node.op() == "Fill") {
1570 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1571 return values != nullptr && IsOnes(*values);
1572 }
1573 if (node.op() != "Const") return false;
1574 if (node.attr().count("dtype") == 0) return false;
1575 const auto dtype = node.attr().at("dtype").type();
1576 switch (dtype) {
1577 IS_ONES_CASE(DT_BOOL);
1578 IS_ONES_CASE(DT_HALF);
1579 IS_ONES_CASE(DT_BFLOAT16);
1580 IS_ONES_CASE(DT_FLOAT);
1581 IS_ONES_CASE(DT_DOUBLE);
1582 IS_ONES_CASE(DT_COMPLEX64);
1583 IS_ONES_CASE(DT_COMPLEX128);
1584 IS_ONES_CASE(DT_UINT8);
1585 IS_ONES_CASE(DT_INT8);
1586 IS_ONES_CASE(DT_UINT16);
1587 IS_ONES_CASE(DT_INT16);
1588 IS_ONES_CASE(DT_INT32);
1589 IS_ONES_CASE(DT_INT64);
1590 IS_ONES_CASE(DT_QINT32);
1591 IS_ONES_CASE(DT_QINT16);
1592 IS_ONES_CASE(DT_QUINT16);
1593 IS_ONES_CASE(DT_QINT8);
1594 IS_ONES_CASE(DT_QUINT8);
1595 default:
1596 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1597 return false;
1598 }
1599 return false;
1600 }
1601
// Returns true if `node` is known to produce a tensor of all zeros. Feed
// nodes are excluded since their runtime value is unknown.
bool ConstantFolding::IsZeros(const NodeDef& node) const {
  if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
    return false;
  }
  if (IsOnesLike(node)) return false;
  if (IsZerosLike(node)) return true;
  // Fill(dims, v) is all zeros iff v is.
  if (node.op() == "Fill") {
    NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
    return values != nullptr && IsZeros(*values);
  }
  if (!IsConstant(node)) return false;
  if (node.attr().count("dtype") == 0) return false;
  const auto dtype = node.attr().at("dtype").type();
  // Dispatch on dtype: each case checks that all elements of the "value"
  // tensor equal zero (see IS_ZEROS_CASE/IS_VALUE_CASE above).
  switch (dtype) {
    IS_ZEROS_CASE(DT_BOOL);
    IS_ZEROS_CASE(DT_HALF);
    IS_ZEROS_CASE(DT_BFLOAT16);
    IS_ZEROS_CASE(DT_FLOAT);
    IS_ZEROS_CASE(DT_DOUBLE);
    IS_ZEROS_CASE(DT_COMPLEX64);
    IS_ZEROS_CASE(DT_COMPLEX128);
    IS_ZEROS_CASE(DT_UINT8);
    IS_ZEROS_CASE(DT_INT8);
    IS_ZEROS_CASE(DT_UINT16);
    IS_ZEROS_CASE(DT_INT16);
    IS_ZEROS_CASE(DT_INT32);
    IS_ZEROS_CASE(DT_INT64);
    IS_ZEROS_CASE(DT_QINT32);
    IS_ZEROS_CASE(DT_QINT16);
    IS_ZEROS_CASE(DT_QUINT16);
    IS_ZEROS_CASE(DT_QINT8);
    IS_ZEROS_CASE(DT_QUINT8);
    default:
      VLOG(1) << "Unsupported type " << DataTypeString(dtype);
      return false;
  }
  return false;
}
1640
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1641 void ConstantFolding::ReplaceOperationWithIdentity(
1642 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1643 GraphDef* graph) {
1644 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1645 if (dtype == DT_INVALID) return;
1646
1647 node->set_op("Identity");
1648 node->clear_attr();
1649 (*node->mutable_attr())["T"].set_type(dtype);
1650 // Propagate the designated input through the identity.
1651 node->mutable_input()->SwapElements(0, input_to_forward);
1652 // Add all other inputs as control dependencies.
1653 for (int i = 1; i < node->input_size(); ++i) {
1654 if (IsControlInput(node->input(i))) {
1655 break;
1656 }
1657 const string ctrl_dep =
1658 AddControlDependency(node->input(i), graph, node_map_.get());
1659 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1660 node->set_input(i, ctrl_dep);
1661 }
1662 graph_modified_ = true;
1663 }
1664
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1665 void ConstantFolding::ReplaceOperationWithSnapshot(
1666 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1667 GraphDef* graph) {
1668 // If the graph contains no ops that mutate their inputs, we can
1669 // use Identity insted of Snapshot.
1670 if (!graph_contains_assign_or_inplace_op_) {
1671 ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1672 return;
1673 }
1674
1675 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1676 if (dtype == DT_INVALID) return;
1677
1678 node->set_op("Snapshot");
1679 node->clear_attr();
1680 (*node->mutable_attr())["T"].set_type(dtype);
1681 // Propagate the designated input through the Snapshot.
1682 node->mutable_input()->SwapElements(0, input_to_forward);
1683 // Add all other inputs as control dependencies.
1684 for (int i = 1; i < node->input_size(); ++i) {
1685 if (IsControlInput(node->input(i))) {
1686 break;
1687 }
1688 const string ctrl_dep =
1689 AddControlDependency(node->input(i), graph, node_map_.get());
1690 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1691 node->set_input(i, ctrl_dep);
1692 }
1693 graph_modified_ = true;
1694 }
1695
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)1696 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
1697 GraphDef* graph) {
1698 node->set_op("Reciprocal");
1699 node->mutable_input()->SwapElements(0, 1);
1700 const string ctrl_dep =
1701 AddControlDependency(node->input(1), graph, node_map_.get());
1702 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1703 node->set_input(1, ctrl_dep);
1704 graph_modified_ = true;
1705 }
1706
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)1707 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
1708 GraphDef* graph) {
1709 node->set_op("Neg");
1710 node->mutable_input()->SwapElements(0, 1);
1711 const string ctrl_dep =
1712 AddControlDependency(node->input(1), graph, node_map_.get());
1713 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
1714 node->set_input(1, ctrl_dep);
1715 graph_modified_ = true;
1716 }
1717
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph,bool * success)1718 Status ConstantFolding::ReplaceOperationWithConstant(
1719 double value, const GraphProperties& properties,
1720 const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
1721 bool* success) {
1722 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1723 if (dtype == DT_INVALID) {
1724 *success = false;
1725 return Status::OK();
1726 }
1727
1728 AttrValue tensor_attr;
1729 TF_RETURN_IF_ERROR(
1730 CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr));
1731 node->set_op("Const");
1732 node->clear_attr();
1733 (*node->mutable_attr())["dtype"].set_type(dtype);
1734 node->mutable_attr()->insert({"value", tensor_attr});
1735 // Convert all inputs to control dependencies.
1736 for (int i = 0; i < node->input_size(); ++i) {
1737 if (IsControlInput(node->input(i))) {
1738 break;
1739 }
1740 const string ctrl_dep =
1741 AddControlDependency(node->input(i), graph, node_map_.get());
1742 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1743 node->set_input(i, ctrl_dep);
1744 }
1745 *success = true;
1746 graph_modified_ = true;
1747 return Status::OK();
1748 }
1749
SimplifyGraph(bool use_shape_info,GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)1750 Status ConstantFolding::SimplifyGraph(
1751 bool use_shape_info, GraphDef* optimized_graph, GraphProperties* properties,
1752 absl::flat_hash_set<string>* nodes_to_not_simplify) {
1753 for (int i = 0; i < optimized_graph->node_size(); ++i) {
1754 NodeDef* node = optimized_graph->mutable_node(i);
1755 // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
1756 // generalize to only restrict certain simplifications.
1757 if (nodes_to_not_simplify->find(node->name()) ==
1758 nodes_to_not_simplify->end()) {
1759 if (HasTPUAttributes(optimized_graph->node(i))) {
1760 nodes_to_not_simplify->insert(node->name());
1761 continue;
1762 }
1763 TF_RETURN_IF_ERROR(
1764 SimplifyNode(use_shape_info, node, optimized_graph, properties));
1765 }
1766 }
1767 return Status::OK();
1768 }
1769
SimplifyNode(bool use_shape_info,NodeDef * node,GraphDef * optimized_graph,GraphProperties * properties)1770 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
1771 GraphDef* optimized_graph,
1772 GraphProperties* properties) {
1773 if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
1774 return Status::OK();
1775 }
1776
1777 bool remove_shuffle_transpose_successful = false;
1778 Status remove_shuffle_transpose_status =
1779 RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
1780 node, &remove_shuffle_transpose_successful);
1781 if (!remove_shuffle_transpose_status.ok()) {
1782 return remove_shuffle_transpose_status;
1783 } else if (remove_shuffle_transpose_successful) {
1784 return Status::OK();
1785 }
1786
1787 if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
1788 return Status::OK();
1789 }
1790
1791 bool remove_reverse_successful = false;
1792 Status remove_reverse_status =
1793 RemoveReverse(*properties, use_shape_info, optimized_graph, node,
1794 &remove_reverse_successful);
1795 if (!remove_reverse_status.ok()) {
1796 return remove_reverse_status;
1797 } else if (remove_reverse_successful) {
1798 return Status::OK();
1799 }
1800
1801 bool simplify_slice_successful = false;
1802 Status simplify_slice_status =
1803 SimplifySlice(*properties, use_shape_info, optimized_graph, node,
1804 &simplify_slice_successful);
1805 if (!simplify_slice_status.ok()) {
1806 return simplify_slice_status;
1807 } else if (simplify_slice_successful) {
1808 return Status::OK();
1809 }
1810
1811 bool simplify_strided_slice_successful = false;
1812 Status simplify_strided_slice_status =
1813 SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
1814 &simplify_strided_slice_successful);
1815 if (!simplify_strided_slice_status.ok()) {
1816 return simplify_strided_slice_status;
1817 } else if (simplify_strided_slice_successful) {
1818 return Status::OK();
1819 }
1820
1821 bool simplify_tile_successful = false;
1822 Status simplify_tile_status =
1823 SimplifyTile(*properties, use_shape_info, optimized_graph, node,
1824 &simplify_tile_successful);
1825 if (!simplify_tile_status.ok()) {
1826 return simplify_tile_status;
1827 } else if (simplify_tile_successful) {
1828 return Status::OK();
1829 }
1830
1831 bool simplify_pad_successful = false;
1832 Status simplify_pad_status =
1833 SimplifyPad(*properties, use_shape_info, optimized_graph, node,
1834 &simplify_pad_successful);
1835 if (!simplify_pad_status.ok()) {
1836 return simplify_pad_status;
1837 } else if (simplify_pad_successful) {
1838 return Status::OK();
1839 }
1840
1841 if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
1842 return Status::OK();
1843 }
1844
1845 if (SimplifyPack(optimized_graph, node)) {
1846 graph_modified_ = true;
1847 return Status::OK();
1848 }
1849
1850 if (MoveConstantsPastEnter(optimized_graph, node)) {
1851 graph_modified_ = true;
1852 return Status::OK();
1853 }
1854
1855 if (SimplifySwitch(optimized_graph, node)) {
1856 graph_modified_ = true;
1857 return Status::OK();
1858 }
1859
1860 if (SimplifyReduction(optimized_graph, *properties, node)) {
1861 graph_modified_ = true;
1862 return Status::OK();
1863 }
1864
1865 if (SimplifyReshape(*properties, use_shape_info, node)) {
1866 graph_modified_ = true;
1867 return Status::OK();
1868 }
1869
1870 bool arithmetic_simplification_succeed = false;
1871 Status simplify_arithmetic_status =
1872 SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
1873 node, &arithmetic_simplification_succeed);
1874 if (!simplify_arithmetic_status.ok()) {
1875 return simplify_arithmetic_status;
1876 } else if (arithmetic_simplification_succeed) {
1877 graph_modified_ = true;
1878 return Status::OK();
1879 }
1880
1881 if (ReduceDivToReciprocalMul(optimized_graph, node)) {
1882 graph_modified_ = true;
1883 return Status::OK();
1884 }
1885
1886 if (ConstantPushDown(optimized_graph, node)) {
1887 graph_modified_ = true;
1888 return Status::OK();
1889 }
1890
1891 if (MulConvPushDown(optimized_graph, node, *properties)) {
1892 graph_modified_ = true;
1893 return Status::OK();
1894 }
1895
1896 if (PartialConstPropThroughIdentityN(node)) {
1897 graph_modified_ = true;
1898 return Status::OK();
1899 }
1900
1901 if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
1902 graph_modified_ = true;
1903 return Status::OK();
1904 }
1905
1906 if (PartialConcatConstFolding(optimized_graph, properties, node)) {
1907 graph_modified_ = true;
1908 return Status::OK();
1909 }
1910
1911 if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
1912 graph_modified_ = true;
1913 return Status::OK();
1914 }
1915
1916 return Status::OK();
1917 }
1918
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)1919 bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
1920 GraphDef* optimized_graph,
1921 NodeDef* node) {
1922 if (node->attr().count("num_split") == 0) return false;
1923 if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
1924 ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
1925 return true;
1926 }
1927 if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
1928 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
1929 return true;
1930 }
1931 return false;
1932 }
1933
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)1934 Status ConstantFolding::RemoveShuffleOrTranspose(
1935 const GraphProperties& properties, bool use_shape_info,
1936 GraphDef* optimized_graph, NodeDef* node, bool* success) {
1937 if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
1938 properties.GetInputProperties(node->name()).size() >= 2) {
1939 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
1940 if (shape.unknown_rank()) {
1941 // Not optimizable.
1942 return Status::OK();
1943 }
1944 const auto& p = properties.GetInputProperties(node->name())[1];
1945 if (TensorShape::IsValid(p.shape()) && p.has_value()) {
1946 Tensor perm(p.dtype(), p.shape());
1947 if (!perm.FromProto(p.value())) {
1948 return errors::InvalidArgument("Cannot parse tensor from proto: ",
1949 p.value().DebugString());
1950 }
1951 std::vector<int> permutation;
1952 for (int j = 0; j < perm.NumElements(); ++j) {
1953 if (perm.dtype() == DT_INT64) {
1954 permutation.push_back(perm.vec<int64>()(j));
1955 } else {
1956 permutation.push_back(perm.vec<int>()(j));
1957 }
1958 }
1959 if (permutation.size() != shape.dim_size()) {
1960 // Number of elements in perm should be same as dim_size. Skip if not.
1961 return Status::OK();
1962 }
1963 // The node is replaceable iff
1964 // dim_size == 0 || all dims have size 1 ||
1965 // all dims with > 1 size are not permuted.
1966 bool replaceable = true;
1967 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
1968 replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
1969 }
1970 if (replaceable) {
1971 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
1972 *success = true;
1973 return Status::OK();
1974 }
1975 }
1976 }
1977 *success = false;
1978 return Status::OK();
1979 }
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)1980 bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
1981 bool use_shape_info,
1982 GraphDef* optimized_graph,
1983 NodeDef* node) {
1984 if (use_shape_info && IsRandomShuffle(*node) &&
1985 !properties.GetInputProperties(node->name()).empty()) {
1986 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
1987 // The node is replaceable iff
1988 // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
1989 if (!shape.unknown_rank() &&
1990 (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
1991 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
1992 return true;
1993 }
1994 }
1995 return false;
1996 }
1997
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)1998 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
1999 bool use_shape_info,
2000 GraphDef* optimized_graph, NodeDef* node,
2001 bool* success) {
2002 if (use_shape_info && node->op() == "ReverseV2" &&
2003 properties.GetInputProperties(node->name()).size() >= 2) {
2004 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2005 if (shape.unknown_rank()) {
2006 // Not optimizable.
2007 return Status::OK();
2008 }
2009 const auto& a = properties.GetInputProperties(node->name())[1];
2010 if (TensorShape::IsValid(a.shape()) && a.has_value()) {
2011 Tensor axis(a.dtype(), a.shape());
2012 if (!axis.FromProto(a.value())) {
2013 return errors::InvalidArgument("Cannot parse tensor from proto: ",
2014 a.value().DebugString());
2015 }
2016 std::set<int> target_axes;
2017 for (int j = 0; j < axis.NumElements(); ++j) {
2018 // value of axis can be negative.
2019 if (axis.dtype() == DT_INT64) {
2020 target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
2021 shape.dim_size());
2022 } else {
2023 target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2024 shape.dim_size());
2025 }
2026 }
2027
2028 // The node is replaceable iff
2029 // unknown_rank == false &&
2030 // (dim_size == 0 || all dims have size 1 ||
2031 // all dims with > 1 size are not in target_axes)
2032 bool replaceable = !shape.unknown_rank();
2033 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2034 replaceable &= shape.dim(j).size() == 1 ||
2035 target_axes.find(j) == target_axes.end();
2036 }
2037 if (replaceable) {
2038 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2039 *success = true;
2040 return Status::OK();
2041 }
2042 }
2043 }
2044 *success = false;
2045 return Status::OK();
2046 }
2047
// Replaces a Slice with an Identity when the slice provably keeps the whole
// input: begin == 0 and size == -1 (or the full dimension extent) in every
// dimension. Requires constant begin/size inputs and a known input rank.
// Sets *success to true iff the node was rewritten.
Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node,
                                      bool* success) {
  // Slice has exactly three inputs: (input, begin, size).
  if (use_shape_info && IsSlice(*node) &&
      properties.GetInputProperties(node->name()).size() == 3) {
    const auto& input = properties.GetInputProperties(node->name())[0];
    const auto& b = properties.GetInputProperties(node->name())[1];
    const auto& s = properties.GetInputProperties(node->name())[2];
    // Both begin and size must be statically known constants.
    if (TensorShape::IsValid(b.shape()) && b.has_value() &&
        TensorShape::IsValid(s.shape()) && s.has_value()) {
      Tensor begin(b.dtype(), b.shape());
      if (!begin.FromProto(b.value())) {
        return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                       b.value().DebugString());
      }
      Tensor size(s.dtype(), s.shape());
      if (!size.FromProto(s.value())) {
        return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                       s.value().DebugString());
      }
      // The node is replaceable iff unknown_rank == false &&
      // begin == 0 && (size == -1 || size == input_shape) for all dimensions
      bool replaceable = !input.shape().unknown_rank();
      for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
        // begin and size may each independently be int32 or int64.
        if (begin.dtype() == DT_INT32) {
          replaceable &= begin.vec<int>()(j) == 0;
        } else {
          replaceable &= begin.vec<int64>()(j) == 0;
        }
        if (size.dtype() == DT_INT32) {
          replaceable &= (size.vec<int>()(j) == -1 ||
                          size.vec<int>()(j) == input.shape().dim(j).size());
        } else {
          replaceable &= (size.vec<int64>()(j) == -1 ||
                          size.vec<int64>()(j) == input.shape().dim(j).size());
        }
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}
2096
SimplifyStridedSlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2097 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
2098 bool use_shape_info,
2099 GraphDef* optimized_graph,
2100 NodeDef* node, bool* success) {
2101 if (use_shape_info && IsStridedSlice(*node) &&
2102 properties.GetInputProperties(node->name()).size() == 4) {
2103 TF_RETURN_IF_ERROR(
2104 CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
2105 if (node->attr().at("new_axis_mask").i() != 0 ||
2106 node->attr().at("shrink_axis_mask").i() != 0) {
2107 // Skip nodes with new/shrink axis mask, since they involve dimension
2108 // changes.
2109 return Status::OK();
2110 }
2111 const auto& input = properties.GetInputProperties(node->name())[0];
2112 for (int j = 0; j < input.shape().dim_size(); ++j) {
2113 // Skip if input shape is not fully determined.
2114 if (input.shape().dim(j).size() < 0) {
2115 return Status::OK();
2116 }
2117 }
2118 const auto& b = properties.GetInputProperties(node->name())[1];
2119 const auto& e = properties.GetInputProperties(node->name())[2];
2120 const auto& s = properties.GetInputProperties(node->name())[3];
2121 if (TensorShape::IsValid(b.shape()) && b.has_value() &&
2122 TensorShape::IsValid(e.shape()) && e.has_value() &&
2123 TensorShape::IsValid(s.shape()) && s.has_value()) {
2124 Tensor begin(b.dtype(), b.shape());
2125 if (!begin.FromProto(b.value())) {
2126 return errors::InvalidArgument("Cannot parse tensor from proto: ",
2127 b.value().DebugString());
2128 }
2129 Tensor end(e.dtype(), e.shape());
2130 if (!end.FromProto(e.value())) {
2131 return errors::InvalidArgument("Cannot parse tensor from proto: ",
2132 e.value().DebugString());
2133 }
2134 Tensor strides(s.dtype(), s.shape());
2135 if (!strides.FromProto(s.value())) {
2136 return errors::InvalidArgument("Cannot parse tensor from proto: ",
2137 s.value().DebugString());
2138 }
2139 TF_RETURN_IF_ERROR(
2140 CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
2141 int begin_mask = node->attr().at("begin_mask").i();
2142 int end_mask = node->attr().at("end_mask").i();
2143 std::set<int> expanded_ellipsis_indices;
2144 int ellipsis_index = -1;
2145 for (int j = 0; j < input.shape().dim_size(); ++j) {
2146 // find the ellipsis_mask. If not found, insert one in the end if
2147 // necessary.
2148 if (node->attr().at("ellipsis_mask").i() & 1 << j ||
2149 (ellipsis_index == -1 && j >= strides.NumElements())) {
2150 ellipsis_index = j;
2151 }
2152 // insert the indices that are immediately after ellipsis_index if
2153 // necessary.
2154 if (ellipsis_index != -1 &&
2155 input.shape().dim_size() >
2156 strides.NumElements() + j - ellipsis_index) {
2157 expanded_ellipsis_indices.insert(j);
2158 }
2159 }
2160
2161 // The node is replaceable iff unknown_rank == false &&
2162 // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
2163 // && strides == 1) for all dimensions.
2164 bool replaceable = !input.shape().unknown_rank();
2165 for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2166 if (expanded_ellipsis_indices.find(j) !=
2167 expanded_ellipsis_indices.end()) {
2168 // ellipsis_mask is effective on current dimension.
2169 continue;
2170 }
2171 // when we have ellipsis_mask in between, input.shape().dim_size() will
2172 // be greater than strides.NumElements(), since we will insert
2173 // as many as expanded_ellipsis_indices.size() axes during computation.
2174 // We need to subtract this number from j.
2175 int i = j;
2176 if (ellipsis_index != -1 &&
2177 j >= ellipsis_index + expanded_ellipsis_indices.size()) {
2178 i = j - expanded_ellipsis_indices.size();
2179 }
2180 int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
2181 : begin.vec<int64>()(i);
2182 int e =
2183 end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
2184 int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
2185 : strides.vec<int64>()(i);
2186 replaceable &=
2187 (begin_mask & 1 << i || b == 0) &&
2188 (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
2189 }
2190 if (replaceable) {
2191 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2192 *success = true;
2193 return Status::OK();
2194 }
2195 }
2196 }
2197 *success = false;
2198 return Status::OK();
2199 }
2200
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2201 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2202 bool use_shape_info,
2203 GraphDef* optimized_graph, NodeDef* node,
2204 bool* success) {
2205 if (use_shape_info && IsTile(*node) &&
2206 properties.GetInputProperties(node->name()).size() == 2) {
2207 const auto& m = properties.GetInputProperties(node->name())[1];
2208 if (TensorShape::IsValid(m.shape()) && m.has_value()) {
2209 Tensor multiplies(m.dtype(), m.shape());
2210 if (!multiplies.FromProto(m.value())) {
2211 return errors::InvalidArgument("Cannot parse tensor from proto: ",
2212 m.value().DebugString());
2213 }
2214 // The node is replaceable iff all values in multiplies are 1.
2215 bool replaceable = true;
2216 if (multiplies.dtype() == DT_INT32) {
2217 for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2218 replaceable &= multiplies.vec<int>()(j) == 1;
2219 }
2220 } else {
2221 for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
2222 ++j) {
2223 replaceable &= multiplies.vec<int64>()(j) == 1;
2224 }
2225 }
2226 if (replaceable) {
2227 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2228 *success = true;
2229 return Status::OK();
2230 }
2231 }
2232 }
2233 *success = false;
2234 return Status::OK();
2235 }
2236
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node,bool * success)2237 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2238 bool use_shape_info,
2239 GraphDef* optimized_graph, NodeDef* node,
2240 bool* success) {
2241 if (use_shape_info && IsPad(*node) &&
2242 properties.GetInputProperties(node->name()).size() >= 2) {
2243 const auto& p = properties.GetInputProperties(node->name())[1];
2244 if (TensorShape::IsValid(p.shape()) && p.has_value()) {
2245 Tensor paddings(p.dtype(), p.shape());
2246 if (!paddings.FromProto(p.value())) {
2247 return errors::InvalidArgument("Cannot parse tensor from proto: ",
2248 p.value().DebugString());
2249 }
2250 // The node is replaceable iff all values in paddings are 0.
2251 bool replaceable = true;
2252 // The operation requires it to be int32 value so we don't check for
2253 // 1nt64.
2254 const auto flatten = paddings.flat<int32>();
2255 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2256 replaceable &= flatten(j) == 0;
2257 }
2258 if (replaceable) {
2259 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2260 *success = true;
2261 return Status::OK();
2262 }
2263 }
2264 }
2265 *success = false;
2266 return Status::OK();
2267 }
2268
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2269 bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2270 bool use_shape_info,
2271 GraphDef* optimized_graph,
2272 NodeDef* node) {
2273 if (use_shape_info && IsSqueeze(*node) &&
2274 !properties.GetInputProperties(node->name()).empty()) {
2275 // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2276 // error to squeeze a dimension that is not 1, so we only need to check
2277 // whether the input has > 1 size for each dimension.
2278 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2279 // The node is replaceable iff
2280 // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2281 bool replaceable = !shape.unknown_rank();
2282 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2283 replaceable &= shape.dim(j).size() > 1;
2284 }
2285 if (replaceable) {
2286 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2287 return true;
2288 }
2289 }
2290 return false;
2291 }
2292
// Rewrites a Pack with a single non-control input into an equivalent
// ExpandDims: stacking one tensor merely inserts a new axis. A constant
// axis node is created to feed ExpandDims' second input. Returns true iff
// the node was rewritten.
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
  if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
      !OptimizedNodeExists(*node, "_const_axis")) {
    // Create constant axis node.
    Tensor axis_t(DT_INT32, TensorShape({}));
    NodeDef* axis_node = optimized_graph->add_node();
    axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
    // Pack's axis attribute is treated as 0 when absent.
    const int axis =
        node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
    if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
        !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
             .ok()) {
      return false;
    }
    // Add a control dependency to make sure axis_node is in the right frame.
    const string ctrl_dep = ConstantFolding::AddControlDependency(
        node->input(0), optimized_graph, node_map_.get());
    axis_node->add_input(ctrl_dep);
    axis_node->set_device(node->device());
    node->set_op("ExpandDims");
    // Drop Pack-only attributes before switching the op type.
    if (node->attr().count("axis") != 0) {
      node->mutable_attr()->erase("axis");
    }
    if (node->attr().count("N") != 0) {
      node->mutable_attr()->erase("N");
    }
    (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
    node->add_input(axis_node->name());
    // The axis input must sit in slot 1; swap any control input that was
    // there to the end.
    if (node->input_size() > 2) {
      node->mutable_input()->SwapElements(1, node->input_size() - 1);
    }
    return true;
  }
  return false;
}
2328
// For an Enter node marked is_constant whose input is a constant, clones
// that constant into the Enter's frame (anchored by a control edge from the
// Enter) and redirects the Enter's non-constant consumers to the clone, so
// the constant can participate in folding inside the frame. Returns true
// iff the graph was changed.
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (IsEnter(*node) && node->input_size() > 0) {
    // Only Enter nodes explicitly tagged as carrying a loop constant apply.
    if (node->attr().count("is_constant") == 0 ||
        !node->attr().at("is_constant").b()) {
      return false;
    }
    const string& node_name = node->name();
    const NodeDef* input = node_map_->GetNode(node->input(0));
    if (input != nullptr && IsReallyConstant(*input) &&
        !OptimizedNodeExists(*input, "_enter")) {
      auto fanouts = node_map_->GetOutputs(node_name);
      // Find non-constant nodes that consume the output of *node.
      std::vector<NodeDef*> consumers;
      for (NodeDef* fanout : fanouts) {
        if (!IsConstant(*fanout)) {
          for (int i = 0; i < fanout->input_size(); ++i) {
            if (fanout->input(i) == node_name) {
              consumers.push_back(fanout);
              break;
            }
          }
        }
      }
      if (!consumers.empty()) {
        // Clone the constant; the control edge from the Enter keeps the
        // clone inside the Enter's frame.
        NodeDef* new_node = optimized_graph->add_node();
        *new_node = *input;
        new_node->set_name(OptimizedNodeName(*input, "_enter"));
        new_node->set_device(node->device());
        new_node->clear_input();
        new_node->add_input(AsControlDependency(node_name));
        node_map_->AddNode(new_node->name(), new_node);
        node_map_->AddOutput(node_name, new_node->name());
        // Point every consumer edge that referenced the Enter at the clone.
        for (NodeDef* consumer : consumers) {
          for (int i = 0; i < consumer->input_size(); ++i) {
            if (NodeName(consumer->input(i)) == node_name) {
              node_map_->UpdateInput(consumer->name(), node_name,
                                     new_node->name());
              consumer->set_input(i, new_node->name());
            }
          }
        }
        return true;
      }
    }
  }
  return false;
}
2377
// Rewrites a Switch whose predicate is its own data input (Switch(x, x)).
// Whenever the false port produces a value that value must be false, and
// likewise the true port's value must be true, so consumers of each port can
// read a constant instead. The constants are anchored to their ports via
// control edges, which preserves which branch is actually executed. Returns
// true iff the graph was changed.
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
  if (node->op() == "Switch" && node->input(0) == node->input(1) &&
      !OptimizedNodeExists(*node, "_const_false") &&
      !OptimizedNodeExists(*node, "_const_true")) {
    bool already_optimized = true;
    // If the optimization was already applied, the switch would have exactly
    // one Identity node consuming each of its outputs, each without any
    // non-control outputs. Note: any fanout count other than 2 is treated as
    // already optimized, i.e. conservatively skipped.
    auto fanouts = node_map_->GetOutputs(node->name());
    if (fanouts.size() == 2) {
      for (NodeDef* fanout : fanouts) {
        if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
            NumNonControlOutputs(*fanout, *node_map_) > 0) {
          already_optimized = false;
          break;
        }
      }
    }
    Tensor false_t(DT_BOOL, TensorShape({}));
    Tensor true_t(DT_BOOL, TensorShape({}));
    // Make sure we don't proceed if this switch node was already optimized.
    if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
        SetTensorValue(DT_BOOL, false, &false_t).ok()) {
      // Copy the set of consumers of the switch as they will be manipulated
      // below.
      const std::set<NodeDef*>& consumer_set =
          node_map_->GetOutputs(node->name());
      std::vector<NodeDef*> consumers(consumer_set.begin(), consumer_set.end());
      // Sort by name so the rewrite order (and output graph) is
      // deterministic.
      std::sort(consumers.begin(), consumers.end(),
                [](const NodeDef* n1, const NodeDef* n2) {
                  return n1->name() < n2->name();
                });
      // Create constant false & true nodes.
      NodeDef* false_node = optimized_graph->add_node();
      false_node->set_name(OptimizedNodeName(*node, "_const_false"));
      if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
               .ok()) {
        return false;
      }
      false_node->set_device(node->device());

      NodeDef* true_node = optimized_graph->add_node();
      true_node->set_name(OptimizedNodeName(*node, "_const_true"));
      if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
               .ok()) {
        return false;
      }
      true_node->set_device(node->device());

      // Add controls from the switch ports to the constants, and connect the
      // constants to the original switch outputs.
      const string false_port = node->name();
      const string true_port = strings::StrCat(node->name(), ":1");
      const string false_ctrl_dep =
          AddControlDependency(false_port, optimized_graph, node_map_.get());
      false_node->add_input(false_ctrl_dep);
      const string true_ctrl_dep =
          AddControlDependency(true_port, optimized_graph, node_map_.get());
      true_node->add_input(true_ctrl_dep);

      node_map_->AddNode(false_node->name(), false_node);
      node_map_->AddNode(true_node->name(), true_node);
      node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
      node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());

      // Redirect consumers of each switch port to the matching constant.
      for (NodeDef* consumer : consumers) {
        for (int i = 0; i < consumer->input_size(); ++i) {
          const string& input = consumer->input(i);
          if (input == false_port) {
            consumer->set_input(i, false_node->name());
            node_map_->UpdateInput(consumer->name(), false_port,
                                   false_node->name());
          } else if (input == true_port) {
            consumer->set_input(i, true_node->name());
            node_map_->UpdateInput(consumer->name(), true_port,
                                   true_node->name());
          }
        }
      }
      return true;
    }
  }
  return false;
}
2462
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2463 bool ConstantFolding::IsReductionCandidateForSimplification(
2464 const NodeDef& node, const GraphProperties& properties,
2465 TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2466 bool* is_single_element_op) const {
2467 // Ensure its an appropriate Reduce node.
2468 if (!IsReduction(node) || node.input_size() < 2) {
2469 return false;
2470 }
2471 // Ensure that the axes to reduce by are constant.
2472 NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2473 if (!IsReallyConstant(*reductions_indices)) {
2474 return false;
2475 }
2476
2477 // Get the properties of the input & output tensors and check if they both
2478 // contain a single element.
2479 if (!properties.HasInputProperties(node.name()) ||
2480 !properties.HasOutputProperties(node.name())) {
2481 return false;
2482 }
2483 const auto& input_props = properties.GetInputProperties(node.name())[0];
2484 const auto& output_props = properties.GetOutputProperties(node.name())[0];
2485 if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2486 !output_props.has_shape() || output_props.shape().unknown_rank()) {
2487 return false;
2488 }
2489 *input_tensor_shape = input_props.shape();
2490 *output_tensor_shape = output_props.shape();
2491 for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2492 if (input_tensor_shape->dim(i).size() < 0) {
2493 return false;
2494 }
2495 }
2496 for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2497 if (output_tensor_shape->dim(i).size() < 0) {
2498 return false;
2499 }
2500 }
2501 const int input_num_elements =
2502 TensorShape(*input_tensor_shape).num_elements();
2503 const int output_num_elements =
2504 TensorShape(*output_tensor_shape).num_elements();
2505 *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2506
2507 return true;
2508 }
2509
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2510 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2511 const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2512 const TensorVector& reduction_indices_vector) const {
2513 int output_size = reduction_indices_vector[0]->NumElements();
2514 if (output_size == 0) {
2515 return true;
2516 }
2517
2518 if (!keep_dims) {
2519 return false;
2520 }
2521 bool simplifiable = true;
2522 for (int i = 0; i < output_size; ++i) {
2523 int64 dim;
2524 if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2525 dim = reduction_indices_vector[0]->flat<int32>()(i);
2526 } else {
2527 dim = reduction_indices_vector[0]->flat<int64>()(i);
2528 }
2529 if (dim < 0) {
2530 dim += input_shape.dim_size();
2531 }
2532 if (dim < 0 || dim >= input_shape.dim_size() ||
2533 input_shape.dim(dim).size() != 1) {
2534 simplifiable = false;
2535 break;
2536 }
2537 }
2538 return simplifiable;
2539 }
2540
bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
                                        const GraphProperties& properties,
                                        NodeDef* node) {
  // Rewrites provably trivial reductions:
  //  * a single-element -> single-element reduction without keep_dims becomes
  //    a Reshape to the all-ones output shape;
  //  * a reduction over axes that all have size 1, with keep_dims, becomes an
  //    Identity.
  // Returns true iff `node` was rewritten in place.
  bool is_single_element_op = false;
  TensorShapeProto input_tensor_shape, output_tensor_shape;
  if (!IsReductionCandidateForSimplification(
          *node, properties, &input_tensor_shape, &output_tensor_shape,
          &is_single_element_op)) {
    return false;
  }

  // Get the reduction indices.
  string reduction_indices_input = node->input(1);
  NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
  TensorVector reduction_indices_vector;
  // EvaluateNode heap-allocates output tensors; release them on every path.
  auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
    for (const auto& out : reduction_indices_vector) {
      delete out.tensor;
    }
  });
  if (!EvaluateNode(*reduction_indices, TensorVector(),
                    &reduction_indices_vector)
           .ok() ||
      reduction_indices_vector.size() != 1) {
    return false;
  }

  bool keep_dims =
      node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
  // Reshape needs a "T" attr; boolean ('any'/'all') reductions lack one, so
  // they are excluded from the Reshape rewrite.
  bool simplifiable_to_reshape =
      is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
  bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
      *node, input_tensor_shape, keep_dims, reduction_indices_vector);

  if (simplifiable_to_reshape) {
    // Const node to output shape.
    const int new_num_dimensions = output_tensor_shape.dim_size();
    Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
    for (int i = 0; i < new_num_dimensions; i++) {
      tensor.flat<int>()(i) = 1;
    }
    TensorValue shape_value(&tensor);
    NodeDef* shape_node = optimized_graph->add_node();
    if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
                       shape_node)
             .ok()) {
      return false;
    }
    shape_node->set_device(node->device());
    node_map_->AddNode(shape_node->name(), shape_node);
    // Control dependency to ensure shape_node is in the correct frame.
    shape_node->add_input(AsControlDependency(reduction_indices_input));
    node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
    // Optimize node to Reshape, swapping the reduction-indices input for the
    // new shape constant and fixing up the attrs to match Reshape's op def.
    node->set_op("Reshape");
    node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
    node->set_input(1, shape_node->name());
    node->mutable_attr()->erase("keep_dims");
    node->mutable_attr()->erase("Tidx");
    AttrValue attr_type_indices;
    attr_type_indices.set_type(DT_INT32);
    (*node->mutable_attr())["Tshape"] = attr_type_indices;
    return true;
  } else if (simplifiable_to_identity) {
    // Replace the reduction node with an identity node, that can be further
    // optimized by the model pruner.
    DataType output_type;
    if (node->attr().count("T") != 0) {
      output_type = node->attr().at("T").type();
    } else {
      // This is an 'any' or 'all' reduction. The output is always boolean.
      output_type = DT_BOOL;
    }
    node->set_op("Identity");
    node->clear_attr();
    (*node->mutable_attr())["T"].set_type(output_type);
    // Keep the reduction-indices edge as a control dependency to preserve
    // frame/execution-order constraints.
    *node->mutable_input(1) = AsControlDependency(node->input(1));
    return true;
  }
  return false;
}
2622
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2623 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2624 bool use_shape_info, NodeDef* node) {
2625 if (!use_shape_info || node->attr().count("T") == 0 ||
2626 !IsSimplifiableReshape(*node, properties)) {
2627 return false;
2628 }
2629 DataType output_type = node->attr().at("T").type();
2630 node->set_op("Identity");
2631 node->clear_attr();
2632 (*node->mutable_attr())["T"].set_type(output_type);
2633 *node->mutable_input(1) = AsControlDependency(node->input(1));
2634 return true;
2635 }
2636
Status ConstantFolding::SimplifyArithmeticOperations(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node, bool* success) {
  // Simplifies binary arithmetic/logical ops with an all-zeros or all-ones
  // operand, e.g. 1*y -> y, 0+y -> y, 0-y -> Neg(y), 1/y -> Reciprocal(y),
  // x*1 -> x, x+/-0 -> x, x*0 -> 0, x OR true -> true.
  // Sets *success to true iff the node was rewritten.
  *success = false;
  const bool is_mul = IsMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsMatMul(*node);
  const bool is_quantized_matmul = IsQuantizedMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  const bool is_any_div = IsAnyDiv(*node);
  // Simplify arithmetic operations with ones or zeros.
  if (use_shape_info &&
      (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
      properties.HasInputProperties(node->name()) &&
      properties.HasOutputProperties(node->name())) {
    const NodeDef* x = node_map_->GetNode(node->input(0));
    const NodeDef* y = node_map_->GetNode(node->input(1));
    if (x == nullptr || y == nullptr) {
      return errors::InvalidArgument("Invalid inputs to node: ",
                                     node->DebugString());
    }
    const TensorShapeProto& output_shape =
        properties.GetOutputProperties(node->name())[0].shape();

    // Simplify element-wise multiplication by ones or addition/subtraction
    // of zeros.
    const TensorShapeProto& y_shape =
        properties.GetInputProperties(node->name())[1].shape();
    const bool x_is_zero = IsZeros(*x);
    // Short-circuit: a tensor cannot be all zeros and all ones.
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    // The rewrite forwards y, so y's shape must match the output shape
    // exactly (no broadcasting).
    const bool y_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, y_shape);
    if (y_matches_output_shape &&
        ((is_mul && x_is_one) || (is_add && x_is_zero))) {
      // 1 * y = y or 0 + y = y.
      ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      *success = true;
      return Status::OK();
    }

    if (y_matches_output_shape && (is_sub && x_is_zero)) {
      // Replace 0 - y with Neg(y).
      ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
      *success = true;
      return Status::OK();
    }

    // Replace 1 / y with Reciprocal op.
    if (y_matches_output_shape && is_any_div && x_is_one) {
      TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
      DataType type = node->attr().at("T").type();
      // Integer division is not reciprocal multiplication; only rewrite for
      // floating-point and complex types.
      if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
        ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }

    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    const bool y_is_zero = IsZeros(*y);
    const bool y_is_one = y_is_zero ? false : IsOnes(*y);
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);
    if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                   ((is_add || is_sub) && y_is_zero))) {
      // x * 1 = x or x / 1 = x or x +/- 0 = x
      ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      *success = true;
      return Status::OK();
    }

    // x OR true = true OR y = true.
    bool updated_graph = false;
    const PartialTensorShape shp(output_shape);
    if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
      bool replace_succeed = false;
      Status replace_op_status = ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph, &replace_succeed);
      if (!replace_op_status.ok()) {
        return replace_op_status;
      } else if (replace_succeed) {
        updated_graph = true;
      }
    }

    // Simplify multiplication and matmul by zeros.
    // Also optimize zeros divided by a tensor, but only if we are in
    // aggressive mode, since we might get rid of divisions by zero.
    const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
    bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
    if ((x_is_zero || y_is_zero) &&
        (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
      if (shp.IsFullyDefined()) {
        bool replace_succeed = false;
        Status replace_op_status =
            ReplaceOperationWithConstant(0, properties, output_shape, node,
                                         optimized_graph, &replace_succeed);
        if (!replace_op_status.ok()) {
          return replace_op_status;
        } else if (replace_succeed) {
          // QuantizedMatMul has extra min/max outputs that must be
          // materialized as constants alongside the zero output.
          if (is_quantized_matmul) {
            TF_RETURN_IF_ERROR(
                AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
          }
          *success = true;
          return Status::OK();
        }
      }
      // Even if an input shape is only partially known, we may known that it
      // matches the output shape and thus forward the corresponding zero
      // input.
      if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      } else if (is_mul && y_is_zero && y_matches_output_shape) {
        ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
    if (updated_graph) {
      *success = true;
      return Status::OK();
    }
  }
  *success = false;
  return Status::OK();
}
2767
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)2768 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
2769 NodeDef* node) {
2770 // Strength reduce floating point division by a constant Div(x, const) to
2771 // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
2772 // will be constant folded to Mul(x, 1.0/const).
2773 if (node->input_size() >= 2 && (IsRealDiv(*node) || IsDiv(*node))) {
2774 const string& const_input = node->input(1);
2775 const NodeDef* denom = node_map_->GetNode(const_input);
2776 CHECK(denom != nullptr);
2777 if (!IsReallyConstant(*denom)) {
2778 return false;
2779 }
2780 if (node->attr().count("T") == 0) {
2781 return false;
2782 }
2783 DataType type = node->attr().at("T").type();
2784 if (IsDiv(*node) &&
2785 !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
2786 return false;
2787 }
2788 // Insert new reciprocal op and change node from Div to Mul.
2789 NodeDef* reciprocal_node = optimized_graph->add_node();
2790 reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
2791 reciprocal_node->set_op("Reciprocal");
2792 reciprocal_node->set_device(node->device());
2793 node->set_op("Mul");
2794 // Re-wire inputs and outputs.
2795 reciprocal_node->add_input(const_input);
2796 (*reciprocal_node->mutable_attr())["T"].set_type(type);
2797 node->set_input(1, reciprocal_node->name());
2798 node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
2799 node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
2800 return true;
2801 }
2802
2803 return false;
2804 }
2805
bool ConstantFolding::ConstantPushDown(GraphDef* optimized_graph,
                                       NodeDef* node) {
  // Consider the transformation
  //
  //                      +                +       = parent
  //                     / \              / \
  //                    C   +    -- >    X   +     = children
  //                       / \              / \
  //                      X   Y            C   Y   = leaves
  //
  // where C is constant and X is non-constant, and '+' denotes an
  // associative and commutative operator like addition or multiplication.
  // This optimization pushes constants down in the tree to canonicalize it.
  // Moreoever, in cases where the child node has a second constant input Y
  // we will create a leaf node that can be folded, e.g.
  //
  //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
  //
  // TODO(rmlarsen): Handle non-associative/non-commutative operators like
  // subtraction and division, as well as mixed subtraction/addition,
  // division/multiplication.
  // Don't touch BiasAdd since they can't handle vectors as their first
  // inputs.
  if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) &&
      NumNonControlInputs(*node) == 2) {
    NodeDef* left_child = node_map_->GetNode(node->input(0));
    NodeDef* right_child = node_map_->GetNode(node->input(1));
    // One child must be constant, and the other the same op as the parent.
    if (node->op() != left_child->op() && node->op() != right_child->op()) {
      return false;
    }
    const bool left_child_is_constant = IsReallyConstant(*left_child);
    const bool right_child_is_constant = IsReallyConstant(*right_child);
    if (!left_child_is_constant && !right_child_is_constant) {
      return false;
    }
    // Bail if the rewrite would move a tensor across devices.
    if (node->device() != left_child->device() ||
        node->device() != right_child->device()) {
      return false;
    }
    NodeDef* op_child_node = left_child_is_constant ? right_child : left_child;
    NodeDef* const_child_node =
        left_child_is_constant ? left_child : right_child;
    // Make sure that it is safe to change the value of the child node: it
    // must not be fetched, and the parent must be its only consumer.
    if (op_child_node->input_size() < 2 ||
        nodes_to_preserve_.find(op_child_node->name()) !=
            nodes_to_preserve_.end() ||
        NumNonControlOutputs(*op_child_node, *node_map_) > 1) {
      return false;
    }

    // Identify the nodes to swap.
    NodeDef* left_leaf = node_map_->GetNode(op_child_node->input(0));
    NodeDef* right_leaf = node_map_->GetNode(op_child_node->input(1));
    const bool left_leaf_is_constant = IsReallyConstant(*left_leaf);
    const bool right_leaf_is_constant = IsReallyConstant(*right_leaf);
    if (left_leaf_is_constant && right_leaf_is_constant) {
      // Child is already foldable, leave it alone.
      return false;
    }
    const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0;
    const int parent_const_input = left_child_is_constant ? 0 : 1;
    const auto& child_output = node_map_->GetOutputs(op_child_node->name());
    if (child_output.find(const_child_node) != child_output.end()) {
      // If there is a control edge from the child op to C, the transformation
      // would create a cycle in the graph. We know that it must be a control
      // edge. We can replace such a control edge with a control edge from A
      // to C.
      CHECK(MaybeRemoveControlInput(op_child_node->name(), const_child_node,
                                    optimized_graph, node_map_.get()));
      string other_leaf_input = left_leaf_is_constant ? op_child_node->input(0)
                                                      : op_child_node->input(1);
      MaybeAddControlInput(other_leaf_input, const_child_node, optimized_graph,
                           node_map_.get());
    }

    // Swap the constant child with a non-constant leaf node. Update the
    // node map first, while both nodes still hold their original inputs,
    // then swap the input strings themselves.
    node_map_->UpdateInput(node->name(), node->input(parent_const_input),
                           op_child_node->input(non_const_leaf_input));
    node_map_->UpdateInput(op_child_node->name(),
                           op_child_node->input(non_const_leaf_input),
                           node->input(parent_const_input));
    std::swap(*node->mutable_input(parent_const_input),
              *op_child_node->mutable_input(non_const_leaf_input));
    return true;
  }
  return false;
}
2894
bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
                                      const GraphProperties& properties) {
  // Push down multiplication on ConvND.
  //                       *                  ConvND
  //                     /   \                /    \
  //                 ConvND  C2    -- >      X      *
  //                  / \                          / \
  //                 X  C1                        C1  C2
  //
  // where C1 and C2 are constants and X is non-constant.
  if (!IsMul(*node) || NumNonControlInputs(*node) != 2) return false;

  NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
  NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
  // One child must be constant, and the second must be Conv op.
  const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
  const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
  if (!left_child_is_constant && !right_child_is_constant) {
    return false;
  }
  NodeDef* conv_node =
      left_child_is_constant ? mul_right_child : mul_left_child;
  if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
    return false;
  }
  // Bail if the rewrite would move a tensor across devices.
  if (node->device() != mul_left_child->device() ||
      node->device() != mul_right_child->device()) {
    return false;
  }

  // Make sure that it is safe to change the value of the convolution
  // output.
  if (conv_node->input_size() < 2 ||
      NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
      nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
    return false;
  }

  // Identify the nodes to swap.
  NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
  NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
  const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
  const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
  if (!conv_left_is_constant && !conv_right_is_constant) {
    // At least one of the convolution inputs should be constant.
    return false;
  }
  if (conv_left_is_constant && conv_right_is_constant) {
    // Leverage regular constant folding to handle this.
    return false;
  }
  // The multiplication can only be fused if it does not broadcast: the Mul
  // output shape must match the convolution output shape.
  const auto& mul_props = properties.GetOutputProperties(node->name());
  const auto& conv_props = properties.GetOutputProperties(conv_node->name());
  if (mul_props.empty() || conv_props.empty()) {
    return false;
  }
  const auto& mul_shape = mul_props[0].shape();
  const auto& conv_shape = conv_props[0].shape();
  if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
    return false;
  }

  const auto& input_props = properties.GetInputProperties(conv_node->name());
  if (input_props.size() < 2) {
    return false;
  }
  const auto& filter_shape = input_props[1].shape();

  NodeDef* const_node =
      left_child_is_constant ? mul_left_child : mul_right_child;
  const auto& const_props = properties.GetOutputProperties(const_node->name());
  if (const_props.empty()) {
    return false;
  }
  const auto& const_shape = const_props[0].shape();
  // NOTE(review): assumes conv_node always carries a "data_format" attr;
  // at() fails if it is absent — confirm attr defaults are materialized
  // before this pass runs.
  if (!IsValidConstShapeForMulConvPushDown(
          conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
    return false;
  }

  string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
  if (node_map_->NodeExists(mul_new_name)) {
    return false;
  }
  // Make sure we don't introduce loops in the graph by removing control
  // dependencies from the conv2d node to c2.
  string conv_const_input =
      conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
  if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
                              node_map_.get())) {
    // Add a control dep from c1 to c2 to ensure c2 is in the right frame
    MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
                         node_map_.get());
  }

  // Swap names so the convolution takes the Mul's place in the graph, and the
  // Mul (now combining the two constants) becomes the conv's constant input.
  conv_node->set_name(node->name());
  node->set_name(mul_new_name);
  if (conv_left_is_constant) {
    node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
    conv_node->set_input(0, mul_new_name);
  } else {
    node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
    conv_node->set_input(1, mul_new_name);
  }
  NodeDef* conv_const_node =
      conv_left_is_constant ? conv_left_child : conv_right_child;
  // Replace the Mul's non-constant operand with the conv's constant operand,
  // so the Mul's inputs are the two constants C1 and C2.
  if (left_child_is_constant) {
    node->set_input(1, conv_const_node->name());
  } else {
    node->set_input(0, conv_const_node->name());
  }
  node_map_->AddNode(mul_new_name, node);

  return true;
}
3010
bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
  // Partial constant propagation through IdentityN: consumers that read a
  // constant slot of the IdentityN are rewired to read the constant directly,
  // keeping a control dependency on the IdentityN. Returns true if the graph
  // was changed.
  if ((IsIdentityN(*node) || IsIdentityNSingleInput(*node)) &&
      NumNonControlInputs(*node) > 0) {
    // Snapshot the consumer set; we mutate consumers (and thus the node map)
    // while iterating.
    const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
    const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
    bool updated_graph = false;
    for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
      const string& input = node->input(input_idx);
      // Control inputs come after all data inputs, so we can stop here.
      if (IsControlInput(input)) {
        break;
      }
      const NodeDef* input_node = node_map_->GetNode(NodeName(input));
      if (input_node == nullptr) {
        LOG(ERROR) << "Bad input: " << input;
        break;
      }
      // Forward constant inputs to outputs and add a control dependency on
      // the IdentityN node.
      if (IsReallyConstant(*input_node)) {
        // Update each consumer.
        for (NodeDef* consumer : consumers) {
          bool add_dep = false;
          for (int consumer_input_idx = 0;
               consumer_input_idx < consumer->input_size();
               ++consumer_input_idx) {
            const string& consumer_input = consumer->input(consumer_input_idx);
            if (IsControlInput(consumer_input)) {
              break;
            }
            // Match inputs of the form "identity_n:input_idx".
            int output_idx;
            const string input_node_name =
                ParseNodeName(consumer_input, &output_idx);
            if (input_node_name == node->name() && output_idx == input_idx) {
              consumer->set_input(consumer_input_idx, input);
              // We will keep the input from IdentityN through a control
              // dependency, so we only need to add the consumer as an output
              // for the constant input node.
              node_map_->AddOutput(NodeName(input), consumer->name());
              add_dep = true;
            }
          }
          if (add_dep) {
            consumer->add_input(AsControlDependency(node->name()));
            updated_graph = true;
          }
        }
      }
    }

    if (updated_graph) {
      // Rewiring several slots may have added duplicate control deps.
      for (NodeDef* consumer : consumers) {
        DedupControlInputs(consumer);
      }
      return true;
    }
  }
  return false;
}
3070
bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
                                                 GraphProperties* properties,
                                                 NodeDef* node) {
  // Partial constant folding for associative operators:
  // Split AddN/AccumulateNV2 to enable partial
  // folding of ops when more than one but not all inputs are constant.
  // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
  // addition is commutative.
  // Returns true if the graph was changed.
  const int num_non_control_inputs = NumNonControlInputs(*node);
  if (IsAggregate(*node) && IsCommutative(*node) &&
      num_non_control_inputs > 2) {
    const int num_control_inputs = node->input_size() - num_non_control_inputs;
    // Partition input positions into constant and non-constant (the latter
    // also collects control inputs).
    std::vector<int> const_inputs;
    std::vector<int> nonconst_inputs;
    for (int i = 0; i < node->input_size(); ++i) {
      const string& input = node->input(i);
      const NodeDef* input_node = node_map_->GetNode(NodeName(input));
      CHECK(input_node != nullptr) << input;
      if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
        const_inputs.push_back(i);
      } else {
        // Non-const and control inputs.
        nonconst_inputs.push_back(i);
      }
    }
    // Promote AccumulateNV2 with all constant inputs to AddN, since it is
    // a fake node that cannot be constant folded by itself.
    if (const_inputs.size() == num_non_control_inputs &&
        node->op() == "AccumulateNV2") {
      node->set_op("AddN");
      node->mutable_attr()->erase("shape");
      return true;
    }
    const string new_node_name = OptimizedNodeName(
        *node, strings::StrCat("_partial_split_", const_inputs.size()));
    // Only split when at least two (but not all) inputs are constant, and
    // only once per node (the name check makes the rewrite idempotent).
    if (1 < const_inputs.size() &&
        const_inputs.size() < num_non_control_inputs &&
        !node_map_->NodeExists(new_node_name)) {
      NodeDef* added_node = optimized_graph->add_node();
      *added_node = *node;
      // Always use AddN for the constant node, since AccumulateNV2 is a fake
      // node that cannot be constant folded, since it does not have a kernel.
      added_node->set_op("AddN");
      added_node->mutable_attr()->erase("shape");
      added_node->set_name(new_node_name);
      node_map_->AddNode(added_node->name(), added_node);
      // Move all constant inputs onto the new AddN node.
      added_node->clear_input();
      for (int i : const_inputs) {
        added_node->add_input(node->input(i));
        node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
                                added_node->name());
      }

      // Overwrite the first const input with the added node.
      node->set_input(const_inputs[0], added_node->name());
      node_map_->AddOutput(added_node->name(), node->name());
      nonconst_inputs.push_back(const_inputs[0]);
      // Compact the remaining inputs to the original node.
      std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
      int idx = 0;
      for (int i : nonconst_inputs) {
        if (idx != i) {
          node->set_input(idx, node->input(i));
        }
        ++idx;
      }
      // Drop the now-unused trailing input slots and fix up the "N" attrs on
      // both nodes to reflect their new input counts.
      node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
                                            const_inputs.size() - 1);
      (*node->mutable_attr())["N"].set_i(node->input_size() -
                                         num_control_inputs);
      properties->ClearInputProperties(node->name());
      (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
      return true;
    }
  }
  return false;
}
3148
bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
                                                GraphProperties* properties,
                                                NodeDef* node) {
  // Partial constant folding for Concat which is not commutative, so
  // we have to preserve order and can only push consecutive runs of constant
  // inputs into sub-nodes.
  // Returns true if the graph was changed.
  const int num_non_control_inputs = NumNonControlInputs(*node);
  if (IsConcat(*node) && num_non_control_inputs > 3 &&
      node->name().rfind("_partial_split_") == string::npos) {
    // Locate the axis argument and the [begin, end) range of data inputs,
    // which differ between Concat (axis first) and ConcatV2 (axis last).
    int axis_arg = -1;
    int begin = 0;
    int end = num_non_control_inputs;
    if (node->op() == "Concat") {
      begin = 1;
      axis_arg = 0;
    } else if (node->op() == "ConcatV2") {
      end = num_non_control_inputs - 1;
      axis_arg = num_non_control_inputs - 1;
    } else {
      return false;
    }

    const NodeDef* axis_arg_node =
        node_map_->GetNode(NodeName(node->input(axis_arg)));
    if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
      // We cannot constant fold Concat unless we the axis argument is
      // constant. Skip node.
      return false;
    }

    // We search for consecutive runs of constant inputs in the range
    // [begin:end[ and push then down into child nodes.
    std::vector<std::pair<int, int>> constant_input_runs;
    int first = begin;
    int last = begin;
    while (last < end) {
      while (first < end && !IsReallyConstant(*node_map_->GetNode(
                                NodeName(node->input(first))))) {
        ++first;
      }
      // Invariant: node[first] is constant || first >= end.
      last = first + 1;
      while (last < end && IsReallyConstant(*node_map_->GetNode(
                               NodeName(node->input(last))))) {
        ++last;
      }
      // Invariant: node[last] is not constant || last >= end
      // Discard intervals shorter than 2 elements.
      if (first < end && (last - first) > 1) {
        constant_input_runs.emplace_back(first, last);
      }
      first = last;
    }

    // Skip if all inputs are constant, and let constant folding take over.
    if (constant_input_runs.size() == 1 &&
        constant_input_runs[0].first == begin &&
        constant_input_runs[0].second == end) {
      return false;
    }
    std::set<int> inputs_to_delete;
    for (auto interval : constant_input_runs) {
      // Push the constant inputs in the interval to a child node than can be
      // constant folded.
      const string new_node_name = OptimizedNodeName(
          *node, strings::StrCat("_partial_split_", interval.first));
      if (node_map_->NodeExists(new_node_name)) {
        break;
      }
      NodeDef* added_node = optimized_graph->add_node();
      *added_node = *node;
      added_node->set_name(new_node_name);
      node_map_->AddNode(added_node->name(), added_node);
      added_node->clear_input();
      for (int i = interval.first; i < interval.second; ++i) {
        added_node->add_input(node->input(i));
        node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
                                added_node->name());
        if (i != interval.first) {
          inputs_to_delete.insert(i);
        }
      }
      // The child concat reuses the same (constant) axis argument.
      added_node->add_input(node->input(axis_arg));
      (*added_node->mutable_attr())["N"].set_i(interval.second -
                                               interval.first);
      node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());

      // Overwrite the first constant input with the result of the added
      // child node.
      node->set_input(interval.first, added_node->name());
      node_map_->AddOutput(added_node->name(), node->name());
    }
    if (!constant_input_runs.empty()) {
      if (!inputs_to_delete.empty()) {
        // Fix up the inputs to the original node.
        std::vector<string> tmp(node->input().begin(), node->input().end());
        node->clear_input();
        for (int i = 0; i < tmp.size(); ++i) {
          if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
            node->add_input(tmp[i]);
          }
        }
        // "N" excludes the axis argument, hence the -1.
        (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
        properties->ClearInputProperties(node->name());
      }
      return true;
    }
  }
  return false;
}
3259
bool ConstantFolding::MergeConcat(const GraphProperties& properties,
                                  bool use_shape_info,
                                  GraphDef* optimized_graph, NodeDef* node) {
  // Merges a ConcatV2 node into its single consumer when that consumer is
  // also a ConcatV2 along the same axis:
  //   ConcatV2(ConcatV2(a, b, axis), c, axis) -> ConcatV2(a, b, c, axis).
  // Returns true if the nodes were merged.
  // We only optimize for ConcatV2.
  int axis;
  if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
      nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
      node_map_->GetOutputs(node->name()).size() != 1) {
    return false;
  }

  NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
  int parent_axis;
  if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
    return false;
  }

  // `index` is the position of this node's axis input; inputs before it are
  // data tensors, inputs after it are control dependencies.
  const int index = NumNonControlInputs(*node) - 1;
  // Rebuild the parent's input list, splicing this node's data inputs in
  // place of the reference to this node.
  auto inputs = parent->input();
  parent->clear_input();
  for (int i = 0; i < inputs.size(); ++i) {
    if (IsSameInput(inputs.Get(i), node->name())) {
      for (int j = 0; j < node->input_size(); ++j) {
        if (j < index) {
          // Input tensors (non axis), add to input list of parent.
          parent->add_input(node->input(j));
          node_map_->RemoveOutput(node->input(j), node->name());
          node_map_->AddOutput(node->input(j), parent->name());
        }
        // Skip j == index, which means axis tensor.
        if (j > index) {
          // Control Dependencies, push back to inputs so they can be forwarded
          // to parent.
          *inputs.Add() = node->input(j);
        }
      }
    } else {
      parent->add_input(inputs.Get(i));
    }
  }
  // Neutralize the now-merged child and fix the parent's "N" attr (which
  // excludes the axis input).
  node->clear_input();
  node->set_op("NoOp");
  node->clear_attr();
  node_map_->RemoveNode(node->name());
  (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);

  return true;
}
3308
// Replaces the min/max outputs (ports 1 and 2) of a QuantizedMatMul node
// with Const nodes holding the static min/max representable values of the
// node's quantized dtype, and rewires all consumers of those ports to the
// new constants. Fails if the generated constant names already exist.
Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
    NodeDef* node, GraphDef* optimized_graph) {
  // Creates one scalar float Const carrying either the quantized-type min
  // (index == 1) or max (index == 2), then redirects consumers of
  // "<node>:<index>" to it.
  auto add_quantized_out = [this, node, optimized_graph](
                               const string& out_const_name, int index) {
    NodeDef* out_node = optimized_graph->add_node();
    Tensor value(DT_FLOAT, TensorShape({}));
    const bool is_min = index == 1;
    const DataType type_attr = node->attr().at("dtype").type();

    value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
                                    : QuantizedTypeMaxAsFloat(type_attr);
    TF_RETURN_IF_ERROR(
        CreateNodeDef(out_const_name, TensorValue(&value), out_node));
    node_map_->AddNode(out_const_name, out_node);
    out_node->set_device(node->device());

    // Copy all inputs from node. The const inherits the matmul's inputs so
    // it keeps the same execution-order constraints as the original output.
    out_node->mutable_input()->CopyFrom(node->input());
    for (const string& input : out_node->input()) {
      node_map_->AddOutput(NodeName(input), out_const_name);
    }

    // Update output nodes consuming node:index to new const node.
    string old_input = absl::StrCat(node->name(), ":", index);
    int old_node_count = 0;
    // Copy the consumer set: set_input/AddOutput below mutate the node map.
    auto outputs = node_map_->GetOutputs(node->name());
    for (const auto& output : outputs) {
      for (int i = 0; i < output->input_size(); ++i) {
        if (output->input(i) == old_input) {
          output->set_input(i, out_const_name);
          node_map_->AddOutput(out_const_name, output->name());
        } else if (NodeName(output->input(i)) == node->name()) {
          // Consumer still reads another port of `node`.
          ++old_node_count;
        }
      }
      // NOTE(review): `old_node_count` accumulates across all consumers, so
      // once any earlier consumer retains a reference to `node`, later
      // consumers are never removed from the node map even if they no longer
      // reference it. Presumably a conservative over-approximation — confirm
      // whether a per-consumer reset was intended.
      if (old_node_count == 0) {
        node_map_->RemoveOutput(node->name(), output->name());
      }
    }

    return Status::OK();
  };
  const string min_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_min_out");
  const string max_out_const_name =
      OptimizedNodeName(*node, "-quantized_matmul_max_out");
  // Only proceed when neither constant exists yet; a partial pair would
  // indicate a name collision from an earlier pass.
  if (node_map_->GetNode(min_out_const_name) == nullptr &&
      node_map_->GetNode(max_out_const_name) == nullptr) {
    TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
    TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
  } else {
    return errors::Internal(absl::Substitute(
        "Can't create Const for QuantizedMatMul min_out/max_out of "
        "node '$0' because of node name conflict",
        node->name()));
  }
  return Status::OK();
}
3367
RunOptimizationPass(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)3368 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3369 const GrapplerItem& item,
3370 GraphDef* optimized_graph) {
3371 node_map_.reset(new NodeMap(graph_));
3372 nodes_whitelist_.clear();
3373 // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3374 // has a single fanout, it would be rewritten as a constant with the same
3375 // node name, and therefore users are still able to fetch it. This is not
3376 // the case if the node has multiple fanouts, and constant folding would
3377 // replace the node with multiple constants (each for one fanout) with
3378 // new names, and as a result users would not be able to fetch the node any
3379 // more with the original node name.
3380 for (const auto& fetch : item.fetch) {
3381 const NodeDef* fetch_node = node_map_->GetNode(fetch);
3382 if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
3383 nodes_whitelist_.insert(fetch_node->name());
3384 }
3385 }
3386
3387 GraphProperties properties(item);
3388 // It's possible to feed a placeholder with a tensor of any shape: make sure
3389 // that the shape inference deals with this conservatively unless we're in
3390 // aggressive mode.
3391 const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
3392 Status s = properties.InferStatically(assume_valid_feeds);
3393 const bool can_use_shape_info = s.ok();
3394
3395 if (can_use_shape_info) {
3396 TF_RETURN_IF_ERROR(MaterializeShapes(properties));
3397 TF_RETURN_IF_ERROR(MaterializeConstants(properties));
3398 }
3399 absl::flat_hash_set<string> nodes_to_not_simplify;
3400 TF_RETURN_IF_ERROR(FoldGraph(optimized_graph, &nodes_to_not_simplify));
3401 node_map_.reset(new NodeMap(optimized_graph));
3402 TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
3403 &properties, &nodes_to_not_simplify));
3404
3405 return Status::OK();
3406 }
3407
3408 namespace {
CompressConstants(GraphDef * graph)3409 Status CompressConstants(GraphDef* graph) {
3410 for (int i = 0; i < graph->node_size(); ++i) {
3411 NodeDef* node = graph->mutable_node(i);
3412 if ((IsConstant(*node) || IsHostConstant(*node)) &&
3413 HasNodeAttr(*node, "value")) {
3414 AttrValue& attr_val = (*node->mutable_attr())["value"];
3415 tensor::CompressTensorProtoInPlace(attr_val.mutable_tensor());
3416 }
3417 }
3418 return Status::OK();
3419 }
3420 } // namespace
3421
// Entry point of the optimizer: repeatedly runs constant-folding passes
// until the graph reaches a fixed point (no modifications and stable node
// count), then compresses the resulting constants.
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
                                 GraphDef* optimized_graph) {
  // TensorFlow flushes denormals to zero and rounds to nearest, so we do
  // the same here. This keeps host-evaluated constants bit-identical to what
  // the runtime would have produced.
  port::ScopedFlushDenormal flush;
  port::ScopedSetRound round(FE_TONEAREST);
  nodes_to_preserve_ = item.NodesToPreserve();
  for (const auto& feed : item.feed) {
    feed_nodes_.insert(NodeName(feed.first));
  }

  // Fall back to a simple host device when the caller supplied none.
  if (cpu_device_ == nullptr) {
    owned_device_.reset(new DeviceSimple());
    cpu_device_ = owned_device_.get();
  }

  // Detect in-place/ref-typed ops up front; their presence constrains which
  // folds are safe later on.
  graph_contains_assign_or_inplace_op_ = false;
  for (const NodeDef& node : item.graph.node()) {
    if (ModifiesInputsInPlace(node) || MaybeHasRefInput(node)) {
      graph_contains_assign_or_inplace_op_ = true;
      break;
    }
  }

  has_fetch_ = !item.fetch.empty();
  GrapplerItem item_to_optimize = item;
  *optimized_graph = item.graph;
  int64 node_count;
  // Fixed-point loop: each iteration swaps the previous pass's output into
  // `item_to_optimize` to serve as the next pass's input, then produces a
  // fresh `optimized_graph`. Terminates when a pass neither flags a
  // modification nor changes the node count.
  do {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    graph_modified_ = false;
    item_to_optimize.graph.Swap(optimized_graph);
    graph_ = &item_to_optimize.graph;
    *optimized_graph = GraphDef();
    node_count = graph_->node_size();
    TF_RETURN_IF_ERROR(
        RunOptimizationPass(cluster, item_to_optimize, optimized_graph));
  } while (graph_modified_ || optimized_graph->node_size() != node_count);
  TF_RETURN_IF_ERROR(CompressConstants(optimized_graph));
  // The passes above only rewrite nodes; restore the original function
  // library and version info on the final graph.
  *optimized_graph->mutable_library() = item.graph.library();
  *optimized_graph->mutable_versions() = item.graph.versions();

  return Status::OK();
}
3466
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimize_output,double result)3467 void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item,
3468 const GraphDef& optimize_output, double result) {
3469 // Nothing to do for ConstantFolding.
3470 }
3471
3472 } // namespace grappler
3473 } // namespace tensorflow
3474