/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/strided_slice_op.h"

using tensorflow::str_util::StringReplace;
using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {
namespace {

// Mark nodes created or optimized by a stage with a tag.
constexpr char kAddOpsRewriteTag[] =
    "_grappler:ArithmeticOptimizer:AddOpsRewriteStage";
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler:ArithmeticOptimizer:MinimizeBroadcasts";

// Extracts values from a Const op into `values`. Returns true on success.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
  if (node.op() != "Const") {
    return false;
  }

  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
    return false;
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = node.attr().at("value").tensor();
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, theoretically the representation of the data
    // could be compressed. So, before copying values to the returned vector,
    // make sure no compression happens.
    const TensorShapeProto& shape = tensor.tensor_shape();
    if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
      values->insert(values->end(), tensor_values->begin(),
                     tensor_values->end());
      return true;
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  if (tensor_content_size > 0) {
    CHECK_EQ(0, tensor_content_size % sizeof(T))
        << "tensor_content_size (" << tensor_content_size
        << ") is not a multiple of " << sizeof(T);
    values->resize(tensor_content_size / sizeof(T));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
    return true;
  }

  return false;
}
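
// Example usage (a sketch): for a Const node `perm` holding the int32 values
// [0, 2, 1],
//   std::vector<int> values;
//   ValuesFromConstNode<int>(perm, &values);  // returns true, values == {0, 2, 1}
// A non-Const node, or a dtype that does not match T, yields false.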

bool MaybeAddControlInput(const string& new_input, NodeDef* node,
                          GraphDef* graph, NodeMap* node_map) {
  bool already_exists = false;
  for (const string& input : node->input()) {
    if (input == new_input || AsControlDependency(input) == new_input) {
      already_exists = true;
      break;
    }
  }
  if (!already_exists) {
    const string ctrl_dep =
        ConstantFolding::AddControlDependency(new_input, graph, node_map);
    node->add_input(ctrl_dep);
    node_map->AddOutput(NodeName(new_input), node->name());
  }
  return !already_exists;
}
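
// Example (a sketch): if `node` has inputs {"a", "b"}, then
// MaybeAddControlInput("c", node, graph, node_map) typically appends the
// control dependency "^c" and returns true; calling it again with "c" is a
// no-op that returns false.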

void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
  (*node->mutable_attr())[attr_name].set_type(dtype);
}

NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_value_preserving_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_value_preserving_non_branching);
}

NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_idempotent_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_idempotent_non_branching);
}

// GetElementUnexhaustive tries to get the value of an element in a tensor and
// turn it into complex128 type. It only checks for a limited number of data
// types, so it is not exhaustive.
bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
                            complex128* element) {
  if (dtypes.find(t.dtype()) == dtypes.end()) return false;
  switch (t.dtype()) {
    case DT_BFLOAT16:
      *element = complex128(t.flat<bfloat16>()(i));
      return true;
    case DT_HALF:
      *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
      return true;
    case DT_INT32:
      *element = complex128(t.flat<int32>()(i));
      return true;
    case DT_INT64:
      *element = complex128(t.flat<int64>()(i));
      return true;
    case DT_FLOAT:
      *element = complex128(t.flat<float>()(i));
      return true;
    case DT_DOUBLE:
      *element = complex128(t.flat<double>()(i));
      return true;
    case DT_COMPLEX64:
      *element = complex128(t.flat<complex64>()(i));
      return true;
    case DT_COMPLEX128:
      *element = t.flat<complex128>()(i);
      return true;
    default:
      return false;
  }
}
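
// Example (a sketch): for a DT_FLOAT tensor t with flat values [2.0, 3.0],
//   complex128 e;
//   GetElementUnexhaustive(t, 1, {DT_FLOAT}, &e);  // returns true, e == (3.0, 0)
// A dtype outside `dtypes` (e.g. DT_STRING) yields false.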

// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  SetVector<NodeDef*>* nodes_to_simplify;
};

// Base class for a single arithmetic optimization: e.g. the Bitcast
// optimization, the AddOps optimization, etc.
class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 public:
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
  ~ArithmeticOptimizerStage() override = default;

 protected:
  // A simplification graph rewrite can create additional nodes that are
  // inputs to the final simplified node; they can also be added to the
  // arithmetic optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
    ctx_ext_.nodes_to_simplify->PushBack(node);
  }

  // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
  // optimizations have been migrated to stages
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
    for (const auto& src : src_nodes) {
      for (int i = src->input_size() - 1; i >= 0; --i) {
        if (IsControlInput(src->input(i))) {
          *target_node->add_input() = src->input(i);
          ctx().node_map->AddOutput(NodeName(src->input(i)),
                                    target_node->name());
        } else {
          break;
        }
      }
    }
    DedupControlInputs(target_node);
  }
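
  // Example (a sketch) for ForwardControlDependencies: a src node with inputs
  // {"a", "^b", "^c"} contributes its trailing control inputs "^b" and "^c"
  // to target_node (control inputs always come last in the input list); the
  // result is then deduplicated.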

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool IsDrivenByControlDependency(const NodeDef& node) const {
    return std::any_of(
        node.input().begin(), node.input().end(),
        [](const string& input) { return IsControlInput(input); });
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool DrivesControlDependency(const NodeDef& node) const {
    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
      for (int i = 0; i < output->input_size(); ++i) {
        const TensorId tensor = ParseTensorName(output->input(i));
        if (tensor.node() == node.name() && tensor.index() < 0) {
          return true;
        }
      }
    }
    return false;
  }

 private:
  // Extended context required for ArithmeticOptimizer.
  const ArithmeticOptimizerContext ctx_ext_;
};

// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
// group of nodes from the optimized graph.
//
// * AddOpsRewrite:
//   Rewrites a group of Add/AddN nodes into a compact Add/AddN tree
//
// * MinimizeBroadcasts:
//   Rewrites a group of binary associative ops, reordering inputs
//   to minimize the cost of broadcasts
class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
  ~ArithmeticNodesGroupOptimizerStage() override = default;

  // Input name with a statically inferred shape from GraphProperties
  struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
    string input;
    TensorShapeProto shape;
  };

  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). The
  // subgraph is obtained by graph traversal, starting from a root node.
  struct OptimizedNodesGroup {
    NodeDef* root_node;
    TensorShapeProto root_shape;
    // Optimized nodes that will be updated or removed by the rewrite
    std::vector<NodeDef*> optimized_nodes;
    // Inputs to the optimized nodes
    std::vector<InputAndShape> inputs;
  };

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    OptimizedNodesGroup group;
    TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));

    if (!group.optimized_nodes.empty()) {
      *simplified_node_name = RewriteOptimizedNodesGroup(group);
    }

    return Status::OK();
  }

 protected:
  // Modifies the optimized graph after the nodes group was successfully
  // identified
  virtual string RewriteOptimizedNodesGroup(
      const OptimizedNodesGroup& group) = 0;

  // Checks if an input can become a part of the current optimized nodes group.
  virtual bool IsAbsorbableByOptimizedNodesGroup(
      const OptimizedNodesGroup& group, const NodeDef& node) const = 0;

  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
    std::deque<const string*> input_tensors;
    input_tensors.push_front(&input);

    while (!input_tensors.empty()) {
      const string* input_tensor = input_tensors.front();
      input_tensors.pop_front();

      // Get a node for the input tensor.
      NodeDef* input_node;
      TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));

      if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
        group->optimized_nodes.push_back(input_node);
        for (int i = input_node->input_size() - 1; i >= 0; --i) {
          const string& absorbed_node_input = input_node->input(i);
          // TODO(ezhulenev): support control inputs
          if (IsControlInput(absorbed_node_input)) continue;
          input_tensors.push_front(&absorbed_node_input);
        }
      } else {
        // If input node can't be absorbed, add it to OptimizedNodesGroup input.
        OpInfo::TensorProperties properties;
        TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
        group->inputs.emplace_back(*input_tensor, properties.shape());
      }
    }

    return Status::OK();
  }

  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
    OpInfo::TensorProperties root_node_output_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(root_node->name(), &root_node_output_properties));

    group->root_node = root_node;
    group->root_shape = root_node_output_properties.shape();

    group->optimized_nodes.reserve(root_node->input_size());
    for (int i = 0; i < root_node->input_size(); ++i) {
      const string& input_i = root_node->input(i);
      // TODO(ezhulenev): add support for control inputs
      if (IsControlInput(input_i)) continue;
      TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
    }

    return Status::OK();
  }

  // Check if all inputs can be broadcasted to the same shape
  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
    auto is_broadcastable = [this, &properties](const string& input) {
      OpInfo::TensorProperties input_props;
      Status has_input_properties = GetTensorProperties(input, &input_props);
      return has_input_properties.ok() &&
             ShapesBroadcastable(properties, input_props);
    };
    return std::all_of(node.input().begin(), node.input().end(),
                       is_broadcastable);
  }

  string ShapeSignature(const TensorShapeProto& shape) const {
    string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
    for (int i = 0; i < shape.dim_size(); ++i)
      strings::StrAppend(&signature, ":", shape.dim(i).size());
    return signature;
  }
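
  // Example (a sketch): a shape with dims [2, -1, 4] (-1 = unknown) yields the
  // signature "rank:3:dim:2:-1:4", so inputs are grouped together only when
  // their symbolic shapes are identical.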

  void MarkWithTag(const StringPiece tag, NodeDef* node) {
    AddNodeAttr(tag, true, node);
  }

  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
    AddNodeAttr(tag, true, group.root_node);
    for (NodeDef* optimized_node : group.optimized_nodes) {
      AddNodeAttr(tag, true, optimized_node);
    }
  }

  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
    return group.root_node->device() == node.device();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
    return HasNodeAttr(node, tag);
  }

  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
    return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
  }
};

// Rewrite a tree of Add/AddN ops with a single AddN operation, consuming all
// the original inputs of the absorbed nodes.
//
// 1) All nodes must have the same device placement.
//
// 2) If all nodes in an Add/AddN subgraph have symbolically equal shapes, the
//    tree is optimized to a single AddN node.
//
//                AddN_1
//              /   |   \
//          Add_1   z   Add_2          -> AddN(x, y, z, w, q, e)
//          /  \        /  \
//         x    y      w    Add_3
//                          /  \
//                         q    e
//
// 3) If some nodes have a different shape (it must be broadcastable to the
//    shape of the "root"), the tree is optimized to AddN ops for symbolically
//    equal shapes, plus a tree of Add ops that minimizes broadcasts.
//
//                AddN_1                           Add
//              /   |   \                         /   \
//          Add_1   z   Add_2          ->      Add     w
//          /  \        /  \                  /   \
//         x    y      w    Add_3    AddN(x, y, q, e)  z
//                          /  \
//                         q    e
class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
  ~AddOpsRewriteStage() override = default;

  // Check if a node can become a root of AddOpsGroup
  bool IsSupported(const NodeDef* node) const override {
    if (!CanOptimize(*node)) return false;

    // shape must be symbolically defined and all inputs compatible with it
    OpInfo::TensorProperties properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
           HasAllInputsBroadcastableToShape(*node, properties);
  }

 protected:
  // Check if a node can be absorbed by current OptimizedNodesGroup
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!CanOptimize(node)) return false;

    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Must have a single output data consumer (presumably, if we reached this
    // node from a previously absorbed node or the root node, it means that it
    // is not used as an input to any other op outside of the group).
    if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape
    OpInfo::TensorProperties properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, properties);
  }

  // Node requirements both for a root node and an absorbed node
  bool CanOptimize(const NodeDef& node) const {
    // TODO(ezhulenev): check if AccumulateNV2 can be supported too
    if (!IsAdd(node) && !IsAddN(node)) {
      return false;
    }
    if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
      return false;
    }
    // TODO(ezhulenev): relax this condition for root node
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  // Rewrite a group of add ops into a single AddN if all input shapes are
  // symbolically equal. If not, create AddN for equal shapes first, and then
  // build an Add tree, minimizing the cost of broadcasts.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size()
            << " num_inputs=" << group.inputs.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kAddOpsRewriteTag);

    // All new nodes will be placed under the scope of a root node.
    auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());

    // Find what shapes are present in the inputs of absorbed nodes.
    std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
    for (const auto& input : group.inputs) {
      shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
    }

    using SigKV = decltype(shape_sig_to_inputs)::value_type;
    VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
            << " unique shapes: "
            << str_util::Join(shape_sig_to_inputs, ", ",
                              [](string* out, SigKV p) {
                                strings::StrAppend(out, p.first);
                              });

    // Collect all the shapes from representative elements.
    std::vector<TensorShapeProto> shapes;
    shapes.reserve(shape_sig_to_inputs.size());
    for (const auto& el : shape_sig_to_inputs)
      shapes.push_back(el.second[0].shape);

    // If all inputs have the same shape, rewrite whole group with a single AddN
    if (shapes.size() == 1) {
      string node_name = UniqueOptimizedNodeName(root_scope_and_name);
      AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
                                        group.inputs);
      return node_name;
    }

    // For inputs of different shapes:
    // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
    // 2. Build a tree of Add nodes, minimizing cost of broadcast
    std::sort(shapes.begin(), shapes.end(),
              [](const TensorShapeProto& left, const TensorShapeProto& right) {
                return CompareSymbolicallyShapedTensorSizes(left, right);
              });

    // optimized name for leaf AddN nodes
    auto leaf_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Leaf_", i));
    };
    // optimized name for internal nodes of a tree built up from AddN leaves
    auto internal_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Internal_", i));
    };

    // Add/AddN nodes that must be added to the tree
    std::deque<InputAndShape> add_ops;

    // Prepare leaf AddN nodes for inputs of equal shape
    for (int i = 0; i < shapes.size(); ++i) {
      const auto node_name = leaf_node_name(i);
      const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
      add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
                                                          node_name, inputs));
    }

    // Build up a tree of Add ops
    int internal_nodes = 0;
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();
      string name = add_ops.empty()
                        ? UniqueOptimizedNodeName(root_scope_and_name)
                        : internal_node_name(internal_nodes++);
      InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
      add_ops.push_front(add);
    } while (add_ops.size() > 1);

    InputAndShape optimized_root_node = add_ops.front();
    return optimized_root_node.input;
  }

  // Add 'AddN' node to aggregate inputs of symbolically equal shape
  InputAndShape AddInputsOfSymbolicallyEqualShape(
      const NodeDef& root_node, const string& node_name,
      const std::vector<InputAndShape>& inputs) {
    CHECK(!inputs.empty()) << "Inputs must be non-empty";

    // Do not create redundant AddN nodes
    if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
      return inputs[0];
    }

    // get shape from representative element
    auto shape = inputs[0].shape;

    // copy attributes from a root node
    DataType dtype = root_node.attr().at("T").type();

    // add new AddN node
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op("AddN");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["N"].set_i(inputs.size());

    for (const auto& inputAndShape : inputs) {
      ctx().node_map->AddOutput(inputAndShape.input, node_name);
      node->add_input(inputAndShape.input);
    }

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(node_name, shape);
  }

  // Add a single 'Add' node to sum two inputs
  InputAndShape AddAggregatedInputs(const NodeDef& root_node,
                                    const string& node_name,
                                    const InputAndShape& left,
                                    const InputAndShape& right) {
    // copy attributes from a root node
    DataType dtype = root_node.attr().at("T").type();

    // add new Add node
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op("Add");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    node->add_input(left.input);
    node->add_input(right.input);

    ctx().node_map->AddOutput(left.input, node_name);
    ctx().node_map->AddOutput(right.input, node_name);

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(
        node_name, TensorShapeProto());  // shape is not important at this point
  }
};

// Use the distributive property of multiplication and division over addition,
// along with commutativity of the former, to hoist common factors/denominators
// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
// This pattern occurs frequently in regularization terms for the gradients
// during training.
//
// For example, we can rewrite an expression of the form:
//   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
//   Mul(x, AddN(y1, y2, y3, ... yn))
// For division, we can rewrite
//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
// to:
//   Div(AddN(y1, y2, y3, ... yn), x)
class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 public:
  explicit HoistCommonFactorOutOfAggregation(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
  ~HoistCommonFactorOutOfAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
           !IsRewritten(node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    bool common_factor_is_denominator = false;
    std::set<string> common_factors;
    std::vector<string> ctrl_deps;
    TF_RETURN_IF_ERROR(GetCommonFactors(
        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));

    if (common_factors.size() == 1) {
      const string& common_factor = *common_factors.begin();

      // Gather up the non-shared factors
      bool shapes_match = true;
      std::vector<string> unique_factors;
      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
                                          common_factor_is_denominator,
                                          &shapes_match, &unique_factors));

      if (shapes_match) {
        NodeDef* input_0;
        TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));

        // Use a copy of the first node for the outer multiplication/division.
        NodeDef* new_outer_node = AddCopyNode(
            OuterNodeName(node, common_factor_is_denominator), input_0);
        // And a copy of aggregation node as one of the inner operands
        NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);

        new_outer_node->set_device(node->device());
        if (common_factor_is_denominator) {
          new_outer_node->set_input(0, new_add_node->name());
          new_outer_node->set_input(1, common_factor);
        } else {
          new_outer_node->set_input(0, common_factor);
          new_outer_node->set_input(1, new_add_node->name());
        }

        ctx().node_map->AddOutput(common_factor, new_outer_node->name());
        ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());

        // Hoist non-shared factors up into the new AddN node.
        for (int i = 0; i < unique_factors.size(); ++i) {
          const string& unique_factor_i = unique_factors[i];
          new_add_node->set_input(i, unique_factor_i);
          ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
        }

        // Add control deps on add node
        for (const string& ctrl_dep : ctrl_deps) {
          *new_add_node->add_input() = ctrl_dep;
          ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
        }

        // optimize new inner aggregation node
        AddToOptimizationQueue(new_add_node);
        // do not optimize the same node twice
        rewritten_nodes_.insert(node->name());
        *simplified_node_name = new_outer_node->name();
      }
    }
    return Status::OK();
  }

 private:
  // Get a name for the new outer node
  string OuterNodeName(const NodeDef* node, bool is_div) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return is_div ? OptimizedNodeName(scope_and_name, "Div")
                  : OptimizedNodeName(scope_and_name, "Mul");
  }

  // Get a name for the new inner Add node
  string InnerAddNodeName(const NodeDef* node) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return OptimizedNodeName(scope_and_name, "Add");
  }

  // Determine the set of common factors if the input nodes are all Mul or
  // Div nodes.
  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
                          bool* common_factor_is_denominator,
                          std::vector<string>* ctrl_deps) const {
    CHECK(common_factors->empty());
    CHECK_NOTNULL(common_factor_is_denominator);
    *common_factor_is_denominator = false;

    bool has_mul = false;
    bool has_div = false;
    for (int i = 0; i < node->input_size(); ++i) {
      if (i > 0 && common_factors->empty()) break;
      if (IsControlInput(node->input(i))) {
        ctrl_deps->push_back(node->input(i));
        continue;
      }
      NodeDef* input;
      TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));

      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
          (IsAnyDiv(*input) && has_mul)) {
        // Break if the input is neither a Mul nor a Div, or if there are both
        // Mul and Div ops.
        common_factors->clear();
        break;
      } else if (IsAnyDiv(*input)) {
        has_div = true;
        // In case of possible common dividers, we avoid hoisting out if any
        // input is not float/double, since integer division is not
        // distributive over addition.
        OpInfo::TensorProperties properties0, properties1;
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
        if (properties0.dtype() != DT_FLOAT &&
            properties0.dtype() != DT_DOUBLE &&
            properties1.dtype() != DT_FLOAT &&
            properties1.dtype() != DT_DOUBLE) {
          common_factors->clear();
          break;
        }
      } else if (IsMul(*input)) {
        has_mul = true;
      }

      // We only focus on common factors from denominators if any op is a
      // Div.
      std::set<string> factors_i =
          has_mul ? std::set<string>{input->input(0), input->input(1)}
                  : std::set<string>{input->input(1)};
      if (i == 0) {
        std::swap(*common_factors, factors_i);
      } else {
        std::set<string> intersection;
        std::set_intersection(
            factors_i.begin(), factors_i.end(), common_factors->begin(),
            common_factors->end(),
            std::inserter(intersection, intersection.begin()));
        std::swap(*common_factors, intersection);
      }
      // Collect the control inputs of the absorbed Mul/Div node.
      for (int j = 2; j < input->input_size(); ++j) {
        ctrl_deps->push_back(input->input(j));
      }
    }

    *common_factor_is_denominator = has_div;
    return Status::OK();
  }

  // Gather up the non-shared factors (the y's in the example).
  // Unless the aggregation is Add, we have to make sure that all the y's
  // have the same shape since the other aggregation ops do not support
  // broadcasting.
  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
                          const bool common_factor_is_denominator,
                          bool* shapes_match,
                          std::vector<string>* unique_factors) const {
    *shapes_match = true;
    unique_factors->reserve(node->input_size());

    for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
      const string& input = node->input(i);
      if (IsControlInput(input)) {
        break;
      }
      NodeDef* inner_node;
      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
      const int unique_factor_index =
          common_factor_is_denominator
              ? 0
              : (inner_node->input(0) == common_factor ? 1 : 0);
      unique_factors->push_back(inner_node->input(unique_factor_index));
      if (i > 0 && !IsAdd(*node)) {
        OpInfo::TensorProperties lhs;
        OpInfo::TensorProperties rhs;
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
        *shapes_match = ShapesSymbolicallyEqual(lhs, rhs);
      }
    }
    return Status::OK();
  }

  bool IsRewritten(const NodeDef* node) const {
    // If the graph rewrite happens in multiple passes without graph pruning
    // between them, it's possible that a rewritten node already exists in the
    // graph.
    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
           ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
           ctx().node_map->NodeExists(OuterNodeName(node, true));
  }

  // Names of the nodes that were optimized by this stage.
  std::unordered_set<string> rewritten_nodes_;
};

// Binary associative ops can be re-ordered to minimize the number of broadcasts
// and the size of temporary tensors.
//
// Example: [a, c] - scalars, [b, d] - matrices
//   @  - binary associative op (Add or Mul)
//   @* - broadcast
//
//        @                      @*
//      /   \                  /    \
//    @*     @*      ->      @        @
//   /  \   /  \            / \      / \
//  a    b c    d          a   c    b   d
class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
  }
  ~MinimizeBroadcasts() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsBinaryAssociative(*node)) return false;

    if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
      return false;

    // has a symbolically defined shape with broadcastable inputs
    OpInfo::TensorProperties properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(properties) &&
           HasAllInputsBroadcastableToShape(*node, properties);
  }

 protected:
  bool IsBinaryAssociative(const NodeDef& node) const {
    return IsMul(node) || IsAdd(node);
  }

  bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
    return group.root_node->op() == node.op();
  }

  // Check if a node can be absorbed by the current OptimizedNodesGroup
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!IsSameOp(group, node)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
    if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
      return false;
    }
    if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
      return false;
    }
    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Optimized nodes are updated in place, which would break the graph if the
    // node had multiple output consumers.
    if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape
    OpInfo::TensorProperties properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, properties);
  }

  std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
    std::set<string> sigs;
    for (const auto& ias : inputs) {
      sigs.insert(ShapeSignature(ias.shape));
    }
    return sigs.size();
  }

  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);

    if (CountUniqueShapes(group.inputs) <= 1) {
      VLOG(3) << "Skip min-bcast group with single unique shape";
      // nothing to optimize when all shapes are the same
      return group.root_node->name();
    }

    auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
    auto num_inputs = group.inputs.size();
    CHECK_EQ(num_nodes, num_inputs - 1)
        << "Can't build a tree with " << num_inputs << " inputs, using "
        << num_nodes << " binary op nodes.";

    std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
    std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
                                         group.optimized_nodes.end());

    // sort inputs by their shapes from smallest to largest
    std::stable_sort(add_ops.begin(), add_ops.end(),
                     [](const InputAndShape& lhs, const InputAndShape& rhs) {
                       return CompareSymbolicallyShapedTensorSizes(lhs.shape,
                                                                   rhs.shape);
                     });

    // If there is an odd number of inputs, the last one is the largest, and we
    // want to attach it to the root node, to build a well balanced tree.
    std::deque<InputAndShape> add_ops_leftover;
    if (add_ops.size() % 2 != 0) {
      add_ops_leftover.push_back(add_ops.back());
      add_ops.pop_back();
    }

    // At this point it's guaranteed that add_ops has an even number of inputs.
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();

      NodeDef* node;
      if (!optimized_nodes.empty()) {
        // re-purpose optimized nodes to build a new tree
        node = optimized_nodes.back();
        optimized_nodes.pop_back();
      } else {
        // or use the root node if no optimized nodes are left
        node = group.root_node;
      }
      InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);

      // Pushing the updated node to the back of the deque will create a wide
      // and short tree, pushing to the front will create a tall tree. We
      // prefer a wide tree because it minimizes the potential number of
      // temporary tensors required to keep in memory, though sometimes we can
      // go up to prevent propagating a broadcast from leaves to the root.
      // Example:
      //
      //  inputs: [s, s, s, M] (s - scalar, M - matrix)
      //  @* - op with broadcast
      //
      //  (only push_back)             @*    (push_front first op)
      //                              /  \
      //       @*                    @    M
      //      /  \                  / \
      //     @    @*      ->       @   s
      //    / \   / \             / \
      //   s   s s   M           s   s
      if (add_ops.size() >= 2 &&
          CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
                                               add_ops.at(1).shape)) {
        add_ops.push_front(updated_node);
      } else {
        add_ops.push_back(updated_node);
      }
    } while (add_ops.size() > 1);
    CHECK_EQ(1, add_ops.size());

    // attach the largest tensor to the root op
    if (!add_ops_leftover.empty()) {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops_leftover.front();
      InputAndShape updated_node =
          UpdateInputs(lhs.input, rhs.input, group.root_node);
      add_ops.push_back(updated_node);
    }

    return add_ops.front().input;
  }

  InputAndShape UpdateInputs(const string& input_0, const string& input_1,
                             NodeDef* node) {
    string old_input_0 = node->input(0);
    string old_input_1 = node->input(1);

    // Update inputs only if they changed
    if (old_input_0 != input_0 || old_input_1 != input_1) {
      node->set_input(0, input_0);
      node->set_input(1, input_1);
      // Invalidate node properties (shape)
      ctx().graph_properties->ClearOutputProperties(node->name());
      ctx().graph_properties->ClearInputProperties(node->name());
      // Update the node map
      ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
      ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
      ctx().node_map->AddOutput(NodeName(input_0), node->name());
      ctx().node_map->AddOutput(NodeName(input_1), node->name());
      // Add updated node to optimization queue
      AddToOptimizationQueue(node);
    }

    TensorShapeProto shape;  // shape is not important at this point
    return InputAndShape(node->name(), shape);
  }
};

// Removes pairs of (Conjugate)Transpose nodes that are inverses of each
// other, as well as transposes with an identity permutation.
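// For example (a sketch):
//   Transpose(Transpose(x, perm), inverse_perm)  => x
//   Transpose(x, [0, 1, ..., n-1])               => x  (identity permutation)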
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
  ~RemoveIdentityTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
    NodeDef* tail = node;
    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
                                    *ctx().nodes_to_preserve);
    NodeDef* first_transpose;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));

    NodeDef* node_perm;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
    if (!IsConstant(*node_perm)) {
      return Status::OK();
    }
    std::vector<int64> node_perm_values;
    TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
    if (first_transpose->op() == node->op()) {
      // Remove pairs of transposes that cancel each other.
      NodeDef* first_transpose_perm;
      TF_RETURN_IF_ERROR(
          GetInputNode(first_transpose->input(1), &first_transpose_perm));
      if (!IsConstant(*first_transpose_perm)) {
        return Status::OK();
      }
      std::vector<int64> first_transpose_perm_values;
      TF_RETURN_IF_ERROR(
          GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
      if (AreInversePermutations(node_perm_values,
                                 first_transpose_perm_values)) {
        if (tail == node) {
          // Bypass adjacent pair.
          *simplified_node_name = first_transpose->input(0);
        } else {
          // Bypass pair connected through chain.
          tail->set_input(0, first_transpose->input(0));
          ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
                                      first_transpose->input(0));
          ForwardControlDependencies(tail, {first_transpose});
          *simplified_node_name = node->input(0);
        }
      }
    } else {
      // Remove simple identity transposes.
      if (IsIdentityPermutation(node_perm_values)) {
        *simplified_node_name = node->input(0);
      }
    }
    return Status::OK();
  }

 private:
  Status GetPermutation(const NodeDef& node_perm,
                        std::vector<int64>* perm64) const {
    std::vector<int> perm32;
    if (ValuesFromConstNode(node_perm, &perm32)) {
      perm64->reserve(perm32.size());
      for (int val : perm32) {
        perm64->push_back(static_cast<int64>(val));
      }
      return Status::OK();
    }
    if (ValuesFromConstNode(node_perm, perm64)) {
      return Status::OK();
    }
    return errors::InvalidArgument("Couldn't extract permutation from ",
                                   node_perm.name());
  }

  bool AreInversePermutations(const std::vector<int64>& a,
                              const std::vector<int64>& b) {
    if (a.size() != b.size()) {
      return false;
    }
    for (int i = 0; i < a.size(); ++i) {
      if (a[b[i]] != i) {
        return false;
      }
    }
    return true;
  }

  bool IsIdentityPermutation(const std::vector<int64>& perm) {
    for (int64 i = 0; i < perm.size(); ++i) {
      if (i != perm[i]) {
        return false;
      }
    }
    return true;
  }
};

// An involution is an element-wise function f(x) that is its own inverse,
// i.e. f(f(x)) = x. If we can find a chain of ops
//   f->op1->op2->...->opn->f
// where op1 through opn preserve the values of their inputs, we can remove
// the two instances of the involution from the graph, since they cancel
// each other.
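// For example (a sketch), with the value-preserving op Reshape in between:
//   Neg(Reshape(Neg(x))) => Reshape(x)
// since Neg(Neg(x)) == x.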
class RemoveInvolution : public ArithmeticOptimizerStage {
 public:
  explicit RemoveInvolution(const GraphOptimizerContext& ctx,
                            const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
  ~RemoveInvolution() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsInvolution(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
                                                  *ctx().nodes_to_preserve);

    NodeDef* involution;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));

    if (involution->op() == node->op()) {
      // Skip both *node and *involution since they cancel each other.
      if (tail == node) {
        // The two nodes to eliminate are adjacent.
        *simplified_node_name = involution->input(0);
      } else {
        tail->set_input(0, involution->input(0));
        ctx().node_map->UpdateInput(tail->name(), involution->name(),
                                    involution->input(0));
        *simplified_node_name = node->input(0);
      }
    }

    return Status::OK();
  }
};

// Remove redundant Bitcasts.
// 1) Remove Bitcast whose source type and destination type are equal
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantBitcastStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
  ~RemoveRedundantBitcastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsBitcast(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Bitcast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
    if (input_type == output_type) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    NodeDef* bitcast;
    TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
    NodeDef* operand;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));

    if (IsBitcast(*operand)) {
      AttrSlice operand_attrs(*operand);
      DataType operand_input_type;
      TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
      // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
      bitcast->set_input(0, operand->input(0));
      SetDataTypeToAttr(operand_input_type, "T", bitcast);
      ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
                                  operand->input(0));
      AddToOptimizationQueue(bitcast);
      *simplified_node_name = bitcast->name();
    }

    return Status::OK();
  }
};

// Remove Casts whose source type and destination type are equal.
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
  ~RemoveRedundantCastStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Cast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
    if (input_type == output_type) {
      *simplified_node_name = node->input(0);
    }
    return Status::OK();
  }
};
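
// Folds a negated input into a surrounding Add/Sub node, e.g.
//   a + (-b) => a - b,   a - (-b) => a + b,   (-a) + b => b - a.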
class RemoveNegationStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
  ~RemoveNegationStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAdd(*node) || IsSub(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    bool updated = false;
    if (IsNeg(*y)) {
      // a - (-b) = a + b or a + (-b) = a - b
      ForwardControlDependencies(node, {y});
      ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
      node->set_op(IsAdd(*node) ? "Sub" : "Add");
      node->set_input(1, y->input(0));
      updated = true;
    } else if (IsAdd(*node) && IsNeg(*x)) {
      // (-a) + b = b - a
      ForwardControlDependencies(node, {x});
      ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
      node->set_op("Sub");
      node->mutable_input()->SwapElements(0, 1);
      node->set_input(1, x->input(0));
      updated = true;
    }
    if (updated) {
      AddToOptimizationQueue(node);
    }
    return Status::OK();
  }
};
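
// Removes a LogicalNot by inverting the comparison op that feeds it, e.g.
//   LogicalNot(Equal(x, y)) => NotEqual(x, y)
//   LogicalNot(Less(x, y))  => GreaterEqual(x, y)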
class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
  ~RemoveLogicalNotStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsLogicalNot(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const string node_name = node->name();
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (IsInPreserveSet(*input) ||
        NumNonControlOutputs(*input, *ctx().node_map) > 1) {
      return Status::OK();
    }
    string new_op;
    if (IsEqual(*input)) {
      new_op = "NotEqual";
    } else if (IsNotEqual(*input)) {
      new_op = "Equal";
    } else if (IsLess(*input)) {
      new_op = "GreaterEqual";
    } else if (IsLessEqual(*input)) {
      new_op = "Greater";
    } else if (IsGreater(*input)) {
      new_op = "LessEqual";
    } else if (IsGreaterEqual(*input)) {
      new_op = "Less";
    }
    if (!new_op.empty()) {
      input->set_op(new_op);
      *simplified_node_name = input->name();
    }
    return Status::OK();
  }
};
1383
// This optimization hoists the common prefix of unary ops of the inputs to
// concat out of the concat, for example:
//   Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
// becomes
//   Exp(Sin(Concat([x, y, z]))).
// Similarly, it will hoist the common postfix of unary ops into Split or
// SplitV nodes, for example:
//   [Exp(Sin(y)) for y in Split(x)]
// becomes
//   [y for y in Split(Exp(Sin(x)))]
//
// TODO(rmlarsen): Support casting. We would have to change the type attribute
// on the concat/split node.
// TODO(rmlarsen): Handle Enter/Exit.
class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
 public:
  explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("", ctx, ctx_ext) {}

  ~HoistCWiseUnaryChainsStage() override = default;

  struct ChainLink {
    ChainLink() = default;
    ChainLink(NodeDef* _node, int _port_origin)
        : node(_node), port_origin(_port_origin) {}
    NodeDef* node;    // Node in a chain.
    int port_origin;  // Port on concat/split node from which this chain
                      // originates.

    bool operator<(const ChainLink& other) const {
      if (port_origin < other.port_origin) {
        return true;
      } else if (port_origin > other.port_origin) {
        return false;
      } else {
        return node->name() < other.node->name();
      }
    }
  };

  // We use an ordinary set sorted on port and node name, so the order, and
  // hence the node name used for the hoisted chain, will be deterministic.
  using ChainLinkSet = std::set<ChainLink>;

  bool IsSupported(const NodeDef* node) const override {
    if (IsInPreserveSet(*node)) return false;
    if (IsConcat(*node) && node->attr().count("N") != 0) {
      const int n = node->attr().at("N").i();
      return n > 1;
    } else if ((IsSplit(*node) || IsSplitV(*node)) &&
               node->attr().count("num_split") != 0) {
      const int num_split = node->attr().at("num_split").i();
      if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
        // TODO(rmlarsen): Remove this constraint when we have optimizations
        // in place for merging slices into splits.
        return false;
      }
      return num_split > 1 && !IsAlreadyOptimized(*node);
    }
    return false;
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    node_is_concat_ = IsConcat(*node);
    int prefix_length;
    std::set<string> ctrl_inputs;
    ChainLinkSet tails;
    TF_RETURN_IF_ERROR(
        FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
    if (prefix_length > 0 && !tails.empty()) {
      TF_RETURN_IF_ERROR(
          HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
    }
    return Status::OK();
  }

 private:
  // Returns the length of the common unary chain of ops that can be
  // hoisted to the other side of concat or split.
  Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
                                ChainLinkSet* tails,
                                std::set<string>* ctrl_inputs) const {
    *prefix_length = 0;
    // Follow the chains starting at each concat input or split output as long
    // as all the following conditions hold:
    // 1. The ops in all chains are the same.
    // 2. The ops are unary elementwise ops.
    // 3. The op output has only a single consumer (concat only).
    ChainLinkSet cur_tails;
    TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
    if (cur_tails.size() < 2) {
      return Status::OK();
    }
    ctrl_inputs->clear();
    bool stop = false;
    while (!stop && !cur_tails.empty() &&
           OpsAreSafeToHoist(root_node, cur_tails)) {
      // We found one more link that can be hoisted.
      ++(*prefix_length);
      tails->swap(cur_tails);
      GatherControlInputs(ctrl_inputs, *tails);

      // Advance tail pointers to the next level.
      TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
    }
    return Status::OK();
  }

  // Hoists the chains to the other side of concat or split and attaches the
  // control inputs gathered from them to the concat or split node.
  Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                           std::set<string>* ctrl_inputs, NodeDef* root_node) {
    if (tails.empty()) {
      return Status::OK();
    }
    AddToOptimizationQueue(root_node);
    optimized_nodes_.insert(root_node->name());
    if (node_is_concat_) {
      AddControlInputs(ctrl_inputs, root_node);
      return HoistChainForConcat(prefix_length, tails, root_node);
    } else {
      return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
    }
  }

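  // Collects the control inputs of all nodes in `ops` into `ctrl_inputs`.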
  void GatherControlInputs(std::set<string>* ctrl_inputs,
                           const ChainLinkSet& ops) const {
    for (const auto& link : ops) {
      const NodeDef* node = link.node;
      for (int i = node->input_size() - 1; i >= 0; --i) {
        const string& input = node->input(i);
        if (!IsControlInput(input)) break;
        ctrl_inputs->insert(input);
      }
    }
  }

  void AddControlInputs(std::set<string>* new_ctrl_inputs,
                        NodeDef* node) const {
    for (int i = node->input_size() - 1; i >= 0; --i) {
      const string& existing_input = node->input(i);
      if (!IsControlInput(existing_input)) break;
      new_ctrl_inputs->erase(existing_input);
    }
    for (const string& new_input : *new_ctrl_inputs) {
      ctx().node_map->AddOutput(NodeName(new_input), node->name());
      node->add_input(new_input);
    }
  }

  Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
    if (node_is_concat_) {
      // Handle concat nodes by looking backwards in the graph.
      TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
      const int n = node.attr().at("N").i();
      const int start = node.op() == "Concat" ? 1 : 0;
      const int end = start + n;
      // Set up tail pointers to point to the immediate inputs to Concat.
      for (int input_port = start; input_port < end; ++input_port) {
        if (IsControlInput(node.input(input_port))) {
          return errors::FailedPrecondition(
              "Got control input ", node.input(input_port),
              " where normal input was expected.");
        }
        NodeDef* tail;
        TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
        tails->insert(ChainLink(tail, input_port));
      }
      return Status::OK();
    } else {
      // Handle split nodes by looking forwards in the graph.
      const auto& outputs = ctx().node_map->GetOutputs(node.name());
      for (NodeDef* output : outputs) {
        if (IsControlInput(output->input(0))) continue;
        TensorId tensor_id = ParseTensorName(output->input(0));
        if (tensor_id.node() == node.name()) {
          tails->insert(ChainLink(output, tensor_id.index()));
        } else {
          // This output node has a non-control input other than the split
          // node; abort.
          tails->clear();
          return Status::OK();
        }
      }
    }
    return Status::OK();
  }

  bool OpsAreSafeToHoist(const NodeDef& root_node,
                         const ChainLinkSet& ops) const {
    if (ops.empty()) return true;
    const NodeDef* op0 = ops.begin()->node;
    if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
    for (const auto& link : ops) {
      const NodeDef* op = link.node;
      if (op->device() != root_node.device() || op->op() != op0->op() ||
          IsInPreserveSet(*op)) {
        return false;
      }
      if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
        // TODO(rmlarsen): Allow outgoing control edges.
        return false;
      }
    }
    return true;
  }

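  // Advances the tail pointers one step along each chain: towards the inputs
  // for concat, towards the consumers for split. Sets *stop if some chain
  // cannot be extended any further.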
  Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
                      bool* stop) const {
    *stop = true;
    new_tails->clear();
    for (const auto& link : tails) {
      const NodeDef* tail = link.node;
      if (node_is_concat_) {
        if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
          return Status::OK();
        }
        NodeDef* new_tail;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
        // Remember original port.
        new_tails->insert(ChainLink(new_tail, link.port_origin));
      } else {
        for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
          const TensorId tensor = ParseTensorName(new_tail->input(0));
          if (tensor.node() != tail->name()) {
            return Status::OK();
          }
          // Skip control outputs.
          if (tensor.index() >= 0) {
            // Remember original port.
            new_tails->insert(ChainLink(new_tail, link.port_origin));
          }
        }
      }
    }
    *stop = false;
    return Status::OK();
  }

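  // Hoists the chain shared by all concat inputs: rewires the concat node to
  // consume the chain inputs directly, and reuses the nodes of the first
  // chain to process the concat output.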
  Status HoistChainForConcat(const int prefix_length,
                             const ChainLinkSet& tails,
                             NodeDef* concat_node) {
    const string& concat_name = concat_node->name();
    const int first_input = concat_node->op() == "Concat" ? 1 : 0;
    for (const auto& link : tails) {
      NodeDef* tail = CHECK_NOTNULL(link.node);
      const int concat_port = link.port_origin;
      CHECK_GE(concat_port, 0);
      CHECK_LT(concat_port, concat_node->input_size());
      const string concat_input = concat_node->input(concat_port);
      // Hook the node following tail directly into the concat node.
      const string tail_input = tail->input(0);
      concat_node->set_input(concat_port, tail_input);
      ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);

      if (concat_port == first_input) {
        // Update the consumers of concat to consume the end of the chain
        // instead.
        UpdateConsumers(concat_node, concat_input);
        // Reuse nodes in the first chain to process output of concat.
        tail->set_input(0, concat_name);
        ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
      }
    }
    return Status::OK();
  }

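  // Hoists the chain shared by all split outputs: builds one copy of the
  // common op chain in front of the split node, rewires the split to consume
  // the head of that chain, and connects consumers of the old per-output
  // chains directly to the corresponding split output ports.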
  Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
                            std::set<string>* ctrl_inputs,
                            NodeDef* split_node) {
    // Create a new chain before the split node to process the input tensor.
    const string& split_name = split_node->name();
    auto root_scope_and_name = ParseNodeScopeAndName(split_name);

    // We use the first tail node in the set as a template to get the list of
    // ops to apply (starting from the end).
    NodeDef* cur_tail = tails.begin()->node;
    NodeDef* cur_copy = AddCopyNode(
        OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
    cur_copy->clear_input();

    // Update the split to take its input from the tail of the new chain.
    const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
    const string orig_input = split_node->input(value_slot);
    split_node->set_input(value_slot, cur_copy->name());
    ctx().node_map->UpdateInput(split_node->name(), orig_input,
                                cur_copy->name());
    TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));

    // Now walk backwards creating the rest of the chain.
    while (cur_tail != split_node) {
      NodeDef* new_copy = AddCopyNode(
          OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
      new_copy->clear_input();
      cur_copy->add_input(new_copy->name());
      ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
      cur_copy = new_copy;
      TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
    }
    // Connect the original input to the head of the new chain.
    cur_copy->add_input(orig_input);
    ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
                                 cur_copy->name());
    // Make sure all the control inputs are satisfied before running the first
    // node in the new chain.
    AddControlInputs(ctrl_inputs, cur_copy);

    // Connect all consumers of the tail nodes directly to the
    // output port of Split from which the chain started.
    for (const auto& link : tails) {
      UpdateConsumers(link.node,
                      link.port_origin == 0
                          ? split_name
                          : strings::StrCat(split_name, ":", link.port_origin));
    }
    return Status::OK();
  }

  // Update consumers of node to take new_input as input instead.
  void UpdateConsumers(NodeDef* node, const string& new_input) {
    const string& node_name = node->name();
    const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
    for (NodeDef* consumer : consumers) {
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (consumer->input(i) == node_name) {
          consumer->set_input(i, new_input);
          ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
        }
      }
      AddToOptimizationQueue(consumer);
    }
  }

  bool IsAlreadyOptimized(const NodeDef& node) const {
    return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
  }

 private:
  bool node_is_concat_;
  std::unordered_set<string> optimized_nodes_;
};

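// Bypasses the repeated application of an idempotent op, i.e. rewrites
// f(f(x)) to f(x) when both nodes run the same op on the same device.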
class RemoveIdempotentStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
  ~RemoveIdempotentStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return node->input_size() == 1 && IsIdempotent(*node) &&
           !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (input->op() == node->op() && input->device() == node->device()) {
      *simplified_node_name = node->input(0);
    }
    return Status::OK();
  }
};

// Performs the conversion:
//   Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
 public:
  explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
                                  const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
  ~SqrtDivToRsqrtMulStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyDiv(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    // Optimize only if divisor is a Sqrt whose output is not being consumed
    // elsewhere.
    if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
        (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
      // a / sqrt(b) = a * rsqrt(b)
      node->set_op("Mul");
      y->set_op("Rsqrt");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    }
    return Status::OK();
  }
};

// Performs the conversion:
//   Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
 public:
  explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
  ~FuseSquaredDiffStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsSquare(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
    // Optimize only if base is a Sub whose output is not being consumed
    // elsewhere.
    if (IsSub(*b) && !IsInPreserveSet(*b) &&
        (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
      node->set_op("Identity");
      b->set_op("SquaredDifference");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(b);
    }
    return Status::OK();
  }
};

// Performs the conversion:
//   Log(Softmax(x)) => LogSoftmax(x)
class LogSoftmaxStage : public ArithmeticOptimizerStage {
 public:
  explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
  ~LogSoftmaxStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    // Optimize only if arg is a Softmax whose output is not being consumed
    // elsewhere.
    if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
        (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
      // Log(Softmax(x)) => LogSoftmax(Identity(x))
      node->set_op("LogSoftmax");
      x->set_op("Identity");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(x);
    }
    return Status::OK();
  }
};

// Bypass redundant reshape nodes:
//
//   Reshape                    Reshape  <-+
//      ^                                  |
//      |                                  |
//   Reshape       becomes      Reshape    |
//      ^                                  |
//      |                                  |
//    input                      input  ---+
class RemoveRedundantReshape : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantReshape(const GraphOptimizerContext& ctx,
                                  const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantReshape", ctx, ctx_ext) {}
  ~RemoveRedundantReshape() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsReshape(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));

    // 1. Bypass reshape followed by reshape.
    if (IsReshape(*input) && !HasControlInputs(*input)) {
      node->set_input(0, input->input(0));
      ctx().node_map->UpdateInput(node->name(), input->name(), input->input(0));
      *simplified_node_name = node->name();
      AddToOptimizationQueue(node);
      return Status::OK();
    }

    // 2. If the reshape is a no-op, forward its input to its consumers, unless
    // it anchors a control dependency since we want to make sure that control
    // dependency is triggered.
    if (ReshapeIsIdentity(*node) && !HasControlInputs(*node)) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    return Status::OK();
  }

 private:
  // Returns whether `reshape` is an identity op.
  bool ReshapeIsIdentity(const NodeDef& reshape) {
    OpInfo::TensorProperties reshape_props;
    OpInfo::TensorProperties input_props;

    if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
        !GetTensorProperties(reshape.input(0), &input_props).ok()) {
      return false;
    }

    return ShapesSymbolicallyEqual(input_props.shape(), reshape_props.shape());
  }
};

// Reorder casting and value-preserving ops if beneficial.
//
// Original motivation: A common pattern after the layout optimizer is
// casting an uint8 NHWC image to float before transposing it to NCHW. It
// is beneficial to reorder the cast and the transpose to make the transpose
// process a smaller amount of data. More generally, this optimization converts
//   Op(Cast(tensor, dst_type))
// to
//   Cast(Op(tensor), dst_type)
// when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
// Op, i.e. an op that only reorders the elements in its first input. Similarly,
// this optimization converts
//   Cast(Op(tensor), dst_type)
// to
//   Op(Cast(tensor, dst_type))
// when sizeof(tensor.type) > sizeof(dst_type).
//
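// For example, with an uint8 image being cast to float (illustrative):
//   Transpose(Cast(image_u8, DT_FLOAT)) => Cast(Transpose(image_u8), DT_FLOAT)
// so the transpose moves 1-byte elements instead of 4-byte ones.
//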
class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
 public:
  explicit ReorderCastLikeAndValuePreserving(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
                                 ctx_ext) {}
  ~ReorderCastLikeAndValuePreserving() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsValuePreserving(*node) || IsCastLike(*node)) &&
           !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
           !IsControlFlow(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
    NodeDef* producer;
    TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
    const bool producer_is_cast = IsCastLike(*producer);
    const bool can_optimize =
        !IsCheckNumerics(*producer) &&
        ((producer_is_cast && IsValuePreserving(*consumer)) ||
         (IsValuePreserving(*producer) && IsCastLike(*consumer)));
    if (!can_optimize || IsControlFlow(*producer) ||
        producer->device() != consumer->device()) {
      return Status::OK();
    }

    const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
    const OpDef* cast_like_op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
                                                         &cast_like_op_def));
    DataType cast_src_type;
    TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                        &cast_src_type));
    DataType cast_dst_type;
    TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                         &cast_dst_type));
    if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
      return Status::OK();
    } else if (producer_is_cast &&
               DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
      return Status::OK();
    } else if (!producer_is_cast &&
               DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
      return Status::OK();
    }

    // Check that nodes were not already optimized.
    const string optimized_producer_name = OptimizedNodeName(
        ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
    const string optimized_consumer_name = OptimizedNodeName(
        ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
    const bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_consumer_name) ||
        ctx().node_map->NodeExists(optimized_producer_name);
    if (is_already_optimized) {
      return Status::OK();
    }

    // Add copies of consumer and producer in reverse order.
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
    // Create new producer node.
    NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
    new_producer->set_input(0, producer->input(0));
    ctx().node_map->AddOutput(input->name(), new_producer->name());

    // Create new consumer node.
    NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
    new_consumer->set_input(0, new_producer->name());

    NodeDef* new_value_preserving =
        producer_is_cast ? new_producer : new_consumer;
    const DataType new_input_type =
        producer_is_cast ? cast_src_type : cast_dst_type;
    // Update the input type of the value-preserving node. The input and
    // output types of the cast-like nodes remain the same.
    TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
    // Make sure there is a kernel registered for the value preserving op
    // with the new input type.
    TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
    ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());

    AddToOptimizationQueue(new_producer);
    *simplified_node_name = new_consumer->name();

    return Status::OK();
  }

 private:
  // Sets the type of the first input to dtype.
  Status SetInputType(DataType dtype, NodeDef* node) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
    const OpDef::ArgDef& input_arg = op_def->input_arg(0);
    const string& type_attr_name = input_arg.type_attr();
    if (type_attr_name.empty()) {
      if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
        return errors::InvalidArgument("Could not set input type of ",
                                       node->op(), " op to ",
                                       DataTypeString(dtype));
      } else {
        // Op has fixed input type that already matches dtype.
        return Status::OK();
      }
    }
    SetDataTypeToAttr(dtype, type_attr_name, node);
    return Status::OK();
  }
  // This optimization can be dangerous on devices other than CPU and
  // GPU. The transpose might not be implemented for image.type, or
  // might be slower with image.type than with cast_dst_type.
  bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
    using str_util::StrContains;

    string task;
    string device;

    return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
           (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
  }

  bool IsFixedSizeType(DataType dtype) {
    return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
           !kQuantizedTypes.Contains(dtype);
  }
};

// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorder data (such as reshape and
// transpose). For example, we can optimize
//
//
//           Conv2D                         Conv2D
//          /      \                       /      \
//     Transpose  weights*   ->      Transpose     Mul
//         |                             |        /   \
//        Mul                            |    weights  scale
//       /   \                           |
//   input  scale**                    input
//
//  *) weights must be a const
// **) scale must be a const scalar
//
// When `weights` and `scale` are constant, `Mul` in the optimized graph can be
// constant-folded. Also, weights tend to be smaller than the activations.
//
// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
// Conv?DBackpropInput.
class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
 public:
  explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
  ~FoldMultiplyIntoConv() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConv2D(*node) || IsConv3D(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
#define TF_RETURN_IF_TRUE(...) \
  if ((__VA_ARGS__)) return Status::OK()

    NodeDef* conv = node;

    NodeDef* weights;
    TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));

    // Fold the multiply to conv only when the weights are constant, so the
    // multiply can be constant-folded.
    //
    // TODO(jingyue): When the weights aren't constant, this should also help
    // performance a bit and memory usage a lot, since the weights tend to be
    // smaller than the activations.
    TF_RETURN_IF_TRUE(!IsConstant(*weights));

    // Verify that this node was not already optimized.
    const string scaled_weights_node_name =
        OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
                          strings::StrCat("scaled", "_", conv->name()));

    TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));

    // Find the tail of value preserving chain entering the Conv node.
    NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
                                                  *ctx().nodes_to_preserve);

    NodeDef* source;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));

    // Check that the value preserving chain is the only consumer of the Mul
    // output.
    TF_RETURN_IF_TRUE(!IsMul(*source));
    TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);

    const NodeDef* mul = source;

    // TODO(jingyue): handle the case where `scale` is 0-th operand.
    NodeDef* scale;  // scalar multiplier for the input tensor
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(1), &scale));
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(0), &input));

    // Check that 'scale * weight' can be const folded.
    TF_RETURN_IF_TRUE(!IsConstant(*scale));
    TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype", "value"}));
    TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
    TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
                      weights->attr().at("dtype").type());

    // Check that `scale` is a scalar.
    const TensorProto& scale_tensor = scale->attr().at("value").tensor();
    bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
                             scale_tensor.tensor_shape().dim_size() == 0;
    TF_RETURN_IF_TRUE(!scale_is_a_scalar);

    // At this point all preconditions are met, and we safely do the rewrite.
    VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
            << " mul=" << mul->name() << " weights=" << weights->name();

    // Create new node `scaled_weights`.
    NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
    scaled_weights->set_op("Mul");
    scaled_weights->set_device(weights->device());
    (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
    AddToOptimizationQueue(scaled_weights);

    // Link in its inputs.
    scaled_weights->add_input(conv->input(1));
    ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
    scaled_weights->add_input(mul->input(1));
    ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
    ForwardControlDependencies(scaled_weights, {source});

    // Update `conv`'s weights to `scaled_weights`.
    conv->set_input(1, scaled_weights->name());
    ctx().node_map->UpdateInput(conv->name(), weights->name(),
                                scaled_weights->name());
    AddToOptimizationQueue(conv);

    // Update `tail` node to bypass `mul` because it's folded to the weights.
    tail->set_input(0, mul->input(0));
    ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
    AddToOptimizationQueue(tail);
    *simplified_node_name = conv->name();

    return Status::OK();
#undef TF_RETURN_IF_TRUE
  }
};

// Fold Transpose into matrix multiplication.
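// For example (illustrative):
//   MatMul(Transpose(x), y) => MatMul(x, y) with transpose_a flipped, and
//   BatchMatMul(ConjugateTranspose(x), y) => BatchMatMul(x, y) with adj_x
//   flipped.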
class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
 public:
  explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
  ~FoldTransposeIntoMatMul() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsMatMul(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    NodeDef* a;
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));

    bool is_complex = false;
    if (node->op() != "SparseMatMul") {
      const DataType type = GetDataTypeFromAttr(*node, "T");
      is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
    }

    const std::set<string> foldable_transpose_ops =
        !is_complex ? std::set<string>{"ConjugateTranspose", "Transpose"}
                    : (node->op() == "BatchMatMul"
                           ? std::set<string>{"ConjugateTranspose"}
                           : std::set<string>{"Transpose"});

    const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*a, ctx().node_map);
    const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*b, ctx().node_map);
    if (!a_is_foldable && !b_is_foldable) return Status::OK();

    NodeDef* new_op = AddCopyNode(optimized_node_name, node);

    if (a_is_foldable) {
      const string attr_a =
          node->op() == "BatchMatMul" ? "adj_x" : "transpose_a";
      FlipBooleanAttr(attr_a, new_op);
      new_op->set_input(0, a->input(0));
      ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
    }

    if (b_is_foldable) {
      const string attr_b =
          node->op() == "BatchMatMul" ? "adj_y" : "transpose_b";
      FlipBooleanAttr(attr_b, new_op);
      new_op->set_input(1, b->input(0));
      ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
    }

    std::vector<const NodeDef*> deps_to_forward = {node};
    if (a_is_foldable) deps_to_forward.push_back(a);
    if (b_is_foldable) deps_to_forward.push_back(b);
    ForwardControlDependencies(new_op, deps_to_forward);

    return Status::OK();
  }

 private:
  void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
    const bool old_value =
        !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
    (*node->mutable_attr())[attr_name].set_b(!old_value);
  }

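  // Returns true if perm leaves all outer dimensions in place and swaps only
  // the innermost two, e.g. perm = {0, 1, 3, 2}.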
  template <typename T>
  bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
    const T n = perm.size();
    if (n < 2) {
      return false;
    }
    for (T i = 0; i < n - 2; ++i) {
      if (perm[i] != i) {
        return false;
      }
    }
    return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
  }

  bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
                                  const NodeMap* node_map) {
    if (transpose_node.op() != "Transpose" &&
        transpose_node.op() != "ConjugateTranspose") {
      return false;
    }
    const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
    std::vector<int> perm32;
    if (ValuesFromConstNode(*perm_node, &perm32)) {
      return IsInnerMatrixTranspose(perm32);
    }
    std::vector<int64> perm64;
    if (ValuesFromConstNode(*perm_node, &perm64)) {
      return IsInnerMatrixTranspose(perm64);
    }
    return false;
  }
};

// Fold Conj into Transpose or ConjugateTranspose.
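// For example (illustrative):
//   Conj(Transpose(x))          => ConjugateTranspose(x)
//   Conj(ConjugateTranspose(x)) => Transpose(x)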
class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
 public:
  explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
  ~FoldConjugateIntoTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));

    const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
    const NodeDef* conj_op = node->op() == "Conj" ? node : input;

    if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
        IsConj(*conj_op)) {
      NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);

      // Flip the type of transpose op to absorb the conjugation.
      new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
                                                       : "Transpose");
      new_op->set_input(0, input->input(0));
      ctx().node_map->UpdateInput(new_op->name(), node->name(),
                                  input->input(0));
      ForwardControlDependencies(new_op, {node, input});
      *simplified_node_name = new_op->name();
    }

    return Status::OK();
  }
};

// Replace Mul node with identical inputs with a Square.
class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
  ~ReplaceMulWithSquare() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsMul(*node) && node->input(0) == node->input(1);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(mul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    const DataType type = GetDataTypeFromAttr(*node, "T");
    bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);

    string task;
    string device;
    bool is_on_cpu =
        DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
        str_util::StrContains(device, DEVICE_CPU);

    if (!is_complex || is_on_cpu) {
      NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
      new_square_node->set_op("Square");
      for (int i = 1; i < new_square_node->input_size(); ++i) {
        new_square_node->set_input(i - 1, new_square_node->input(i));
      }
      new_square_node->mutable_input()->RemoveLast();
      for (const string& input : new_square_node->input()) {
        ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
      }
      *simplified_node_name = new_square_node->name();
    }

    return Status::OK();
  }
};

// Simplify aggregation (e.g. AddN) nodes:
//
// 1. Discard aggregate nodes with a single input and no control dependencies.
//
// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
//    deduping or other rewrites) so we can get rid of the sum entirely.
//
//    The expression (using AddN as an example of an aggregate op):
//      AddN(x, x, x, ... ,x)
//           <-- N terms -->
//    can be rewritten to:
//      Mul(Const(N), x)
//
class SimplifyAggregation : public ArithmeticOptimizerStage {
 public:
  explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
  ~SimplifyAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && NumNonControlInputs(*node) > 0 &&
           GetDataTypeFromAttr(*node, "T") !=
               DT_VARIANT;  // TODO(b/119787146): Enable for variants.
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. Discard aggregate nodes with a single input and no control deps.
    if (node->input_size() == 1) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    // 2. Rewrite aggregations of N >= 2 identical terms.

    // All non-control inputs must be identical.
    bool all_equal = true;
    int num_inputs = 1;
    for (int i = 1; i < node->input_size(); ++i) {
      if (IsControlInput(node->input(i))) break;
      ++num_inputs;
      if (node->input(i) != node->input(0)) {
        all_equal = false;
        break;
      }
    }
    if (!all_equal) return Status::OK();

    // And the node should not have been optimized earlier.
    const NodeScopeAndName node_scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string optimized_const_name =
        OptimizedNodeName(node_scope_and_name, "Const");
    const string optimized_mul_name =
        OptimizedNodeName(node_scope_and_name, "Mul");

    bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_const_name) ||
        ctx().node_map->NodeExists(optimized_mul_name);

    if (is_already_optimized) return Status::OK();

    // At this point all preconditions are met, and we safely do the rewrite.
    VLOG(3) << "Simplify aggregation with identical inputs: node="
            << node->name() << " num_inputs=" << num_inputs;

    // 1. Create constant node with value N.
    const auto type = GetDataTypeFromAttr(*node, "T");
    Tensor t(type, TensorShape({}));
    Status status = SetTensorValue(type, num_inputs, &t);
    if (!status.ok()) {
      return errors::Internal("Failed to create const node: ",
                              status.error_message());
    }

    TensorValue value(&t);
    NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
    status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
                                            new_const_node);
    if (!status.ok()) {
      return errors::Internal("Failed to create const node: ",
                              status.error_message());
    }
    new_const_node->set_device(node->device());
    MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
                         ctx().optimized_graph, ctx().node_map);
    AddToOptimizationQueue(new_const_node);

    // 2. Replace the aggregate node with Mul(Const(N), x).
    NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
    new_mul_node->set_op("Mul");
    new_mul_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", new_mul_node);
    new_mul_node->add_input(new_const_node->name());
    ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
    new_mul_node->add_input(node->input(0));
    ctx().node_map->AddOutput(node->input(0), new_mul_node->name());

    ForwardControlDependencies(new_mul_node, {node});
    *simplified_node_name = new_mul_node->name();

    return Status::OK();
  }
};

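// Replaces Pow(x, c), where the exponent c is a constant with the same value
// for every element, by a cheaper equivalent:
//   Pow(x, 2)    => Square(x)
//   Pow(x, 1)    => Identity(x)   (if no broadcasting is involved)
//   Pow(x, 0.5)  => Sqrt(x)
//   Pow(x, 0)    => Const(1)      (if no broadcasting is involved)
//   Pow(x, -0.5) => Rsqrt(x)
//   Pow(x, -1)   => Reciprocal(x)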
class ConvertPowStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertPowStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}

  bool IsSupported(const NodeDef* node) const override {
    return IsPow(*node) &&
           ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const auto& pow_props =
        ctx().graph_properties->GetInputProperties(node->name())[1];
    PartialTensorShape shape(pow_props.shape());
    if (!shape.IsFullyDefined()) {
      // Skip if the exponent shape is not fully defined.
      return Status::OK();
    }
    if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
      Tensor pow(pow_props.dtype(), pow_props.shape());
      if (!pow.FromProto(pow_props.value())) {
        return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                       pow_props.value().DebugString());
      }

      complex128 prev, curr;
      for (int i = 0; i < pow.NumElements(); ++i) {
        if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
          // input data type is not supported by Pow. Skip.
          return Status::OK();
        }
        if (i != 0 && curr != prev) {
          // pow has different values on different elements. Skip.
          return Status::OK();
        }
        prev = curr;
      }
      NodeDef *x, *y;
      TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
      TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
      const auto& value_props =
          ctx().graph_properties->GetInputProperties(node->name())[0];
      const TensorShapeProto& output_shape =
          ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
      if (curr == complex128(2, 0)) {
        node->set_op("Square");
        node->set_input(1, AsControlDependency(y->name()));
        AddToOptimizationQueue(node);
        AddToOptimizationQueue(y);
      } else if (curr == complex128(1, 0) &&
                 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
        // Pow could be used to broadcast, so make sure the shapes of the two
        // arguments are identical before replacing Pow with Identity.
        node->set_op("Identity");
        node->set_input(1, AsControlDependency(y->name()));
        AddToOptimizationQueue(node);
        AddToOptimizationQueue(y);
      } else if (curr == complex128(0.5, 0)) {
        node->set_op("Sqrt");
        node->set_input(1, AsControlDependency(y->name()));
        AddToOptimizationQueue(node);
        AddToOptimizationQueue(y);
      } else if (curr == complex128(0, 0) &&
                 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
        PartialTensorShape shape(value_props.shape());
        if (!shape.IsFullyDefined()) {
          // Skip if the base shape is not fully defined.
          return Status::OK();
        }
        if (TensorShape::IsValid(value_props.shape()) &&
            value_props.has_value()) {
          Tensor base(value_props.dtype(), value_props.shape());
          if (!base.FromProto(value_props.value())) {
            return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                           value_props.value().DebugString());
          }
          node->set_op("Const");
          Tensor c(base.dtype(), base.shape());
          for (int i = 0; i < c.NumElements(); ++i) {
            TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
          }
          (*node->mutable_attr())["dtype"].set_type(base.dtype());
          c.AsProtoTensorContent(
              (*node->mutable_attr())["value"].mutable_tensor());
          node->mutable_attr()->erase("T");
          node->set_input(0, AsControlDependency(x->name()));
          node->set_input(1, AsControlDependency(y->name()));
          AddToOptimizationQueue(node);
          AddToOptimizationQueue(x);
          AddToOptimizationQueue(y);
        }
      } else if (curr == complex128(-0.5, 0)) {
        node->set_op("Rsqrt");
        node->set_input(1, AsControlDependency(y->name()));
        AddToOptimizationQueue(node);
        AddToOptimizationQueue(y);
      } else if (curr == complex128(-1, 0)) {
        node->set_op("Reciprocal");
        node->set_input(1, AsControlDependency(y->name()));
        AddToOptimizationQueue(node);
        AddToOptimizationQueue(y);
      }
    }
    return Status::OK();
  }

 private:
  Status SetElementToOne(int i, Tensor* t) {
    switch (t->dtype()) {
      case DT_INT32:
        t->flat<int32>()(i) = 1;
        return Status::OK();
      case DT_INT64:
        t->flat<int64>()(i) = 1L;
        return Status::OK();
      case DT_FLOAT:
        t->flat<float>()(i) = 1.0f;
        return Status::OK();
      case DT_DOUBLE:
        t->flat<double>()(i) = 1.0;
        return Status::OK();
      case DT_COMPLEX64:
        t->flat<complex64>()(i) = complex64(1);
        return Status::OK();
      case DT_COMPLEX128:
        t->flat<complex128>()(i) = complex128(1);
        return Status::OK();
      default:
        return errors::InvalidArgument("Invalid data type: ", t->dtype());
    }
  }
};

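// Performs the conversion:
//   Log(Add(x, 1)) => Log1p(x)
// when the constant addend is 1 for every element and broadcasting does not
// change the shape of the non-constant argument.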
class ConvertLog1pStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
  ~ConvertLog1pStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (!IsAdd(*input)) {
      return Status::OK();
    }

    if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
      return Status::OK();
    }

    bool modified = false;
    TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
    if (!modified) {
      TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
    }
    if (modified) {
      *simplified_node_name = node->name();
    }
    return Status::OK();
  }

 private:
  Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
                             bool* modified) {
    const auto& t =
        ctx().graph_properties->GetInputProperties(input->name())[i];
    const auto& c =
        ctx().graph_properties->GetInputProperties(input->name())[j];
    for (int k = 0; k < c.shape().dim_size(); ++k) {
      // Skip if c shape is not fully determined.
      if (c.shape().dim(k).size() < 0) {
        return Status::OK();
      }
    }
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return Status::OK();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return Status::OK();
    }
    if (TensorShape::IsValid(c.shape()) && c.has_value()) {
      Tensor constant(c.dtype(), c.shape());
      if (!constant.FromProto(c.value())) {
        return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                       c.value().DebugString());
      }
      complex128 element;
      for (int k = 0; k < constant.NumElements(); ++k) {
        if (!GetElementUnexhaustive(constant, k,
                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                     DT_COMPLEX64, DT_COMPLEX128},
                                    &element)) {
          // input data type is not supported by log1p. Skip.
          return Status::OK();
        }
        if (element != complex128(1)) {
          // current element is not 1. Skip.
          return Status::OK();
        }
      }
      NodeDef *x, *y;
      TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
      TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
      node->set_op("Log1p");
      node->set_input(0, input->input(i));
      node->add_input(AsControlDependency(y->name()));
      ForwardControlDependencies(node, {input});

      AddToOptimizationQueue(node);
      AddToOptimizationQueue(input);
      AddToOptimizationQueue(x);
      AddToOptimizationQueue(y);
      *modified = true;
    }
    return Status::OK();
  }
};

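// Performs the conversion:
//   Sub(Exp(x), 1) => Expm1(x)
// when the constant subtrahend is 1 for every element and broadcasting does
// not change the shape of Exp's argument.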
class ConvertExpm1Stage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
  ~ConvertExpm1Stage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsSub(*node)) return false;

    NodeDef* input;
    if (!GetInputNode(node->input(0), &input).ok()) return false;

    return IsExp(*input);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
      return Status::OK();
    }

    NodeDef* exp;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
    if (!IsExp(*exp)) {
      return Status::OK();
    }

    if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) {
      return Status::OK();
    }

    const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
    const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
    for (int k = 0; k < c.shape().dim_size(); ++k) {
      // Skip if c's shape is not fully determined.
      if (c.shape().dim(k).size() < 0) {
        return Status::OK();
      }
    }
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return Status::OK();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // Skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return Status::OK();
    }
    if (TensorShape::IsValid(c.shape()) && c.has_value()) {
      Tensor constant(c.dtype(), c.shape());
      if (!constant.FromProto(c.value())) {
        return errors::InvalidArgument("Cannot parse tensor from proto: ",
                                       c.value().DebugString());
      }
      complex128 element;
      for (int k = 0; k < constant.NumElements(); ++k) {
        if (!GetElementUnexhaustive(constant, k,
                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                     DT_COMPLEX64, DT_COMPLEX128},
                                    &element)) {
          // Input data type is not supported by expm1. Skip.
          return Status::OK();
        }
        if (element != complex128(1)) {
          // Current element is not 1. Skip.
          return Status::OK();
        }
      }
      NodeDef *exp_input, *ones;
      TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
      TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
      node->set_op("Expm1");
      node->set_input(0, exp->input(0));
      node->set_input(1, AsControlDependency(ones->name()));
      ForwardControlDependencies(node, {exp});

      AddToOptimizationQueue(node);
      AddToOptimizationQueue(exp);
      AddToOptimizationQueue(exp_input);
      AddToOptimizationQueue(ones);
    }
    return Status::OK();
  }
};

// Performs conversions like:
//   Max(Sqrt(x)) => Sqrt(Max(x))
// Checks for a max/min reduction over an element-wise monotonic function,
// such as Sqrt, Sigmoid, Tanh, etc.
class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
 public:
  explicit OptimizeMaxOrMinOfMonotonicStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
                                 ctx_ext) {}
  ~OptimizeMaxOrMinOfMonotonicStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
           IsArgMax(*node) || IsArgMin(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) {
      return Status::OK();
    }
    NodeDef* inner_function;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
    // Optimize only if:
    // 0. inner_function is not in the preserve set,
    // 1. inner_function's op is element-wise monotonic,
    // 2. inner_function's output is not consumed by any other node, and
    // 3. inner_function is monotonically increasing when reduction_node is a
    //    pooling operation, since there are no MinPool operations.
    bool is_non_decreasing = false;
    if (!IsInPreserveSet(*inner_function) &&
        IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
        ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
        (is_non_decreasing || !IsAnyMaxPool(*reduction_node))) {
      // Swap the first inputs of the inner function op & the reduction op.
      NodeDef* inner_input;
      TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
      reduction_node->set_input(0, inner_input->name());
      ctx().node_map->UpdateInput(reduction_node->name(),
                                  inner_function->name(), inner_input->name());
      inner_function->set_input(0, reduction_node->name());
      UpdateConsumers(reduction_node, inner_function->name());
      ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
                                  reduction_node->name());
      if (!is_non_decreasing) {
        // Flip Min <-> Max if the function is non-increasing, e.g.
        //   Max(Neg(x)) = Neg(Min(x)).
        const string opposite = FlipMinMax(*reduction_node);
        reduction_node->set_op(opposite);
      }

      if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
        // ArgMax(Sqrt(x)) = ArgMax(x)
        inner_function->set_op("Identity");
      }

      AddToOptimizationQueue(reduction_node);
      AddToOptimizationQueue(inner_function);
      AddToOptimizationQueue(inner_input);
    }
    return Status::OK();
  }

  void UpdateConsumers(NodeDef* node, const string& new_input) {
    const string& node_name = node->name();
    const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
    for (NodeDef* consumer : consumers) {
      for (int i = 0; i < consumer->input_size(); ++i) {
        if (consumer->input(i) == node_name && consumer->name() != new_input) {
          consumer->set_input(i, new_input);
          ctx().node_map->UpdateInput(consumer->name(), node_name, new_input);
        }
      }
      AddToOptimizationQueue(consumer);
    }
  }

 private:
  string FlipMinMax(const NodeDef& node) {
    const string& op = node.op();
    if (IsAnyMax(node) || IsArgMax(node)) {
      return str_util::StringReplace(op, "Max", "Min", false);
    } else {
      return str_util::StringReplace(op, "Min", "Max", false);
    }
  }
};

// Replace a chain of type- and shape-preserving unary ops with a
// '_UnaryOpsComposition' node.
// TODO(ezhulenev): It should be a part of the remapper optimizer because it
// doesn't have much to do with arithmetic (together with the
// FoldMultiplyIntoConv stage?).
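// Illustrative example of the rewrite:
//   y = Tanh(Relu(Sqrt(x)))
// becomes a single node
//   y = _UnaryOpsComposition(x)  with attr op_names = ["Sqrt", "Relu", "Tanh"]
// (op_names is ordered from the innermost op to the outermost one).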
class UnaryOpsComposition : public ArithmeticOptimizerStage {
 public:
  explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
    // WARN: This should be consistent with unary_ops_composition.cc.
    // clang-format off
    supported_ops_ = {// Ops defined via Eigen scalar ops.
                      {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Acos", {DT_FLOAT, DT_DOUBLE}},
                      {"Acosh", {DT_FLOAT, DT_DOUBLE}},
                      {"Asin", {DT_FLOAT, DT_DOUBLE}},
                      {"Asinh", {DT_FLOAT, DT_DOUBLE}},
                      {"Atan", {DT_FLOAT, DT_DOUBLE}},
                      {"Atanh", {DT_FLOAT, DT_DOUBLE}},
                      {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cosh", {DT_FLOAT, DT_DOUBLE}},
                      {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rint", {DT_FLOAT, DT_DOUBLE}},
                      {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sinh", {DT_FLOAT, DT_DOUBLE}},
                      {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Tan", {DT_FLOAT, DT_DOUBLE}},
                      {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      // Additional ops that are not part of Eigen.
                      {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
    // clang-format on
  }
  ~UnaryOpsComposition() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return CanOptimize(*node) &&
           // Check that this node was not already a root of a fused chain. If
           // graph optimization runs twice without pruning in between,
           // fused_nodes_ will not have this information.
           !ctx().node_map->NodeExists(OptimizedNodeName(*node));
  }

  Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
    DataType dtype = root->attr().at("T").type();

    // Keep a trace of all supported input nodes that can be fused together.
    std::vector<string> op_nodes = {root->name()};
    std::vector<string> op_names = {root->op()};

    // Check if we should follow input(0) while building an op composition.
    const auto predicate_fn = [&](const NodeDef& input) {
      if (input.name() == root->name()) return true;

      bool follow_input_node =
          dtype == GetDataTypeFromAttr(input, "T") &&
          NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
          CanOptimize(input);

      if (follow_input_node) {
        op_nodes.push_back(input.name());
        op_names.push_back(input.op());
      }

      return follow_input_node;
    };

    NodeDef* last_op = GetTailOfChain(
        *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);

    // We were not able to find a chain that can be replaced.
    if (op_names.size() == 1) return Status::OK();

    // Do not add fused nodes to any other chain.
    std::for_each(op_nodes.begin(), op_nodes.end(),
                  [this](const string& name) { AddToFusedNodes(name); });

    // Reverse the trace to get correct composition computation order.
    std::reverse(op_names.begin(), op_names.end());

    VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
            << str_util::Join(op_names, ", ") << "]";

    NodeDef* composition_node = ctx().optimized_graph->add_node();
    composition_node->set_name(OptimizedNodeName(*root));
    composition_node->set_op("_UnaryOpsComposition");
    composition_node->add_input(last_op->input(0));
    composition_node->set_device(root->device());

    auto attr = composition_node->mutable_attr();
    SetAttrValue(dtype, &(*attr)["T"]);
    SetAttrValue(op_names, &(*attr)["op_names"]);

    ctx().node_map->AddNode(composition_node->name(), composition_node);
    ctx().node_map->AddOutput(NodeName(last_op->input(0)),
                              composition_node->name());

    *simplified_node_name = composition_node->name();

    return Status::OK();
  }

 private:
  bool CanOptimize(const NodeDef& node) const {
    DataType dtype = GetDataTypeFromAttr(node, "T");
    if (!IsSupported(node.op(), dtype)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    if (!NodeIsOnCpu(node)) {
      return false;
    }
    if (NodeIsAlreadyFused(node)) {
      return false;
    }
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  // _UnaryOpsComposition is defined only for CPU.
  bool NodeIsOnCpu(const NodeDef& node) const {
    using str_util::StartsWith;

    string task;
    string device;

    return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
           StartsWith(device, DEVICE_CPU);
  }

  bool NodeIsAlreadyFused(const NodeDef& node) const {
    return fused_nodes_.count(node.name()) > 0;
  }

  string OptimizedNodeName(const NodeDef& node) const {
    return strings::StrCat(node.name(), "/unary_ops_composition");
  }

  void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }

  // Checks if an op is supported by _UnaryOpsComposition for the given type.
  bool IsSupported(const string& op_name, DataType dtype) const {
    const auto it = supported_ops_.find(op_name);
    return it != supported_ops_.end() && it->second.count(dtype) > 0;
  }

  std::unordered_map<string, std::set<DataType>> supported_ops_;
  std::unordered_set<string> fused_nodes_;
};

// Replace operations of the form:
//   x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
// with
//   a_i
// when the strided slice index `i` is applied in the k'th axis.
//
// Similarly, replace operations of the form:
//   x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
//   expand_dims(a_i, axis=k)
//
// TODO(ebrevdo): Extend to also replace operations of the form
//   concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
// with
//   a_i,
// when
//   s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
// and slicing is in the k'th axis.
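//
// Illustrative example: with a_0, a_1, a_2 each of shape [D],
//   stack((a_0, a_1, a_2), axis=0)[1, :]   => a_1                    ([D])
//   stack((a_0, a_1, a_2), axis=0)[1:2, :] => expand_dims(a_1, 0)    ([1, D])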
class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
 public:
  explicit RemoveStackStridedSliceSameAxis(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
                                 ctx_ext) {}
  ~RemoveStackStridedSliceSameAxis() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsStridedSlice(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // *node is a StridedSlice NodeDef.
    NodeDef* pack;

    // Get the input and see if it's a Pack op.
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
    if (!IsPack(*pack)) return Status::OK();

    bool return_early;
    PartialTensorShape pack_output_shape;
    int pack_axis;
    TF_RETURN_IF_ERROR(
        CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
    if (return_early) return Status::OK();

    int slice_start_value;
    bool found;
    TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
                                    &slice_start_value, &found));
    if (!found) return Status::OK();

    return RewriteGraph(node, pack, slice_start_value, pack_axis,
                        simplified_node_name);
  }

 protected:
  bool IsReallyConstant(const NodeDef& node) const {
    if (!IsConstant(node)) {
      return false;
    }
    // If the node is fed, it's not constant anymore.
    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
  }

  bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
                          std::vector<int64>* values) {
    if (dtype == DT_INT32) {
      std::vector<int32> values_int32;
      if (!ValuesFromConstNode(node, &values_int32)) {
        return false;
      }
      std::copy(values_int32.begin(), values_int32.end(),
                std::inserter(*values, values->begin()));
      return true;
    } else {
      return ValuesFromConstNode(node, values);
    }
  }

  Status CheckInputs(const NodeDef* node, const NodeDef* pack,
                     PartialTensorShape* pack_output_shape, int* pack_axis,
                     bool* return_early) {
    *return_early = true;
    TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));

    *pack_axis = pack->attr().at("axis").i();
    auto slice_properties =
        ctx().graph_properties->GetInputProperties(node->name());
    if (slice_properties.empty() ||
        slice_properties[0].shape().unknown_rank()) {
      return Status::OK();
    }
    *pack_output_shape = slice_properties[0].shape();
    const int pack_input_rank = pack_output_shape->dims() - 1;
    if (*pack_axis < 0) {
      // The ndims of any input into the Pack op is its output ndims - 1.
      *pack_axis += pack_input_rank;
    }
    if (*pack_axis < 0 || *pack_axis >= pack_input_rank) {
      return errors::InvalidArgument(
          "Pack node (", pack->name(),
          ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
    }
    *return_early = false;
    return Status::OK();
  }

  Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
                      const PartialTensorShape& pack_output_shape,
                      int pack_axis, int* slice_start_value, bool* found) {
    *found = false;
    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
                                "new_axis_mask", "shrink_axis_mask"}));

    const int begin_mask = node->attr().at("begin_mask").i();
    const int end_mask = node->attr().at("end_mask").i();
    const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
    const int new_axis_mask = node->attr().at("new_axis_mask").i();
    const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();

    // Check that the StridedSlice is one of these at pack_axis:
    //   [..., i, ...]
    //   [..., i:i+1, ...]
    //   [..., :1, ...]
    //   [..., -1:, ...]
    //   [..., s_{pack_axis}-1:, ...]
    NodeDef* slice_begin;
    NodeDef* slice_end;
    NodeDef* slice_strides;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));

    for (const auto* n : {slice_begin, slice_end, slice_strides}) {
      if (!IsReallyConstant(*n)) return Status::OK();
    }

    Tensor slice_begin_t;
    Tensor slice_end_t;
    Tensor slice_strides_t;

    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
    if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
    if (!slice_strides_t.FromProto(
            slice_strides->attr().at("value").tensor())) {
      return Status::OK();
    }
    TensorShape processing_shape;
    TensorShape final_shape;
    bool is_identity;
    bool is_simple_slice;
    bool slice_dim0;
    gtl::InlinedVector<int64, 4> slice_begin_vec;
    gtl::InlinedVector<int64, 4> slice_end_vec;
    gtl::InlinedVector<int64, 4> slice_strides_vec;
    TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
        &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
        &processing_shape, &final_shape, &is_identity, &is_simple_slice,
        &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));

    if (!is_simple_slice) return Status::OK();

    int begin_index = -1;
    int64 begin_value = 0;
    for (int i = 0; i < slice_begin_vec.size(); ++i) {
      const int64 v = slice_begin_vec[i];
      if (v != 0) {
        if (begin_index != -1) {
          // At least two start values are nonzero.
          return Status::OK();
        }
        begin_index = i;
        begin_value = v;
      }
    }

    int end_index = -1;
    int64 end_value = 0;
    for (int i = 0; i < slice_end_vec.size(); ++i) {
      const int64 v = slice_end_vec[i];
      if (v != pack_output_shape.dim_size(i)) {
        if (end_index != -1) {
          // At least two end values differ from the dimension size.
          return Status::OK();
        }
        end_index = i;
        end_value = v;
      }
    }

    if (begin_index == -1 && end_index == -1) return Status::OK();
    if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
      // Somehow received different axes for begin/end slicing.
      return Status::OK();
    }
    const int slice_axis = (begin_index == -1) ? end_index : begin_index;
    if (slice_axis != pack_axis) {
      // Not slicing on the same axis as the Pack op.
      return Status::OK();
    }
    *slice_start_value = (begin_index == -1) ? 0 : begin_value;
    const int64 slice_end_value =
        (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
    if (slice_end_value != *slice_start_value + 1) {
      // Not slicing a single value out.
      return Status::OK();
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", slice_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    *found = true;  // slice_start_value is valid.
    return Status::OK();
  }

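  // Rewrites the graph to bypass the Pack + StridedSlice pair: if the sliced
  // Pack input already has the output shape (the pack axis was shrunk away),
  // it is forwarded through an Identity node; otherwise an ExpandDims node
  // restores the size-1 pack axis.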
  Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
                      int slice_start_value, int pack_axis,
                      string* simplified_node_name) {
    OpInfo::TensorProperties input_slice_properties;
    NodeDef* input_slice;
    TF_RETURN_IF_ERROR(
        GetInputNode(pack->input(slice_start_value), &input_slice));
    TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
                                           &input_slice_properties));
    PartialTensorShape input_slice_shape(input_slice_properties.shape());

    OpInfo::TensorProperties output_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(
        strings::StrCat(node->name(), ":", 0), &output_properties));
    PartialTensorShape output_shape(output_properties.shape());
    NodeDef* output =
        AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
    if (input_slice_shape.IsCompatibleWith(output_shape)) {
      output->set_op("Identity");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties.dtype(), "T", output);
      output->add_input(input_slice->name());
    } else {
      NodeDef* axis = AddEmptyNode(
          OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
      axis->set_op("Const");
      axis->set_device(node->device());
      auto axis_attr = axis->mutable_attr();
      SetDataTypeToAttr(DT_INT32, "dtype", axis);
      auto* axis_t = (*axis_attr)["value"].mutable_tensor();
      axis_t->set_dtype(DT_INT32);
      axis_t->add_int_val(pack_axis);
      AddToOptimizationQueue(axis);
      output->set_op("ExpandDims");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties.dtype(), "T", output);
      output->add_input(input_slice->name());
      output->add_input(axis->name());
    }

    // Copy dependencies over.
    ForwardControlDependencies(output, {node, pack});
    AddToOptimizationQueue(output);
    *simplified_node_name = output->name();

    return Status::OK();
  }
};

}  // namespace

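// Buckets nodes into equivalence classes for deduplication: two nodes with
// the same signature are candidates, and SameNode() confirms that they
// compute the same value and are interchangeable.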
class UniqueNodes {
 public:
  NodeDef* FindOrAddRepresentative(NodeDef* node) {
    uint64 sig = ComputeSignature(*node);
    std::vector<NodeDef*>& candidates = rep_[sig];
    for (auto& candidate : candidates) {
      if (SameNode(*candidate, *node)) {
        return candidate;
      }
    }
    candidates.push_back(node);
    return node;
  }

 private:
  uint64 ComputeSignature(const NodeDef& node);
  bool SameNode(const NodeDef& node1, const NodeDef& node2) const;

  absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
  absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
};

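// The signature combines the op, device, inputs, and attributes. Inputs and
// attributes are mixed in with unordered hash combining, so permuting them
// does not change the signature; the exact comparison is left to SameNode().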
uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
  auto it = memoized_signatures_.find(&node);
  if (it != memoized_signatures_.end()) return it->second;

  uint64 h = Hash64(node.op());
  h = Hash64Combine(Hash64(node.device()), h);

  for (const auto& input : node.input()) {
    const TensorId input_tensor = ParseTensorName(input);
    h = Hash64CombineUnordered(
        Hash64(input_tensor.node().data(), input_tensor.node().size()), h);
    h = Hash64CombineUnordered(std::hash<int>()(input_tensor.index()), h);
  }
  for (const auto& attr : node.attr()) {
    h = Hash64CombineUnordered(Hash64(attr.first), h);
    h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h);
  }
  memoized_signatures_.emplace(&node, h);
  return h;
}

bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
  if (node1.op() != node2.op()) {
    return false;
  }
  if (node1.device() != node2.device()) {
    return false;
  }
  if (node1.input_size() != node2.input_size()) {
    return false;
  }
  if (node1.attr_size() != node2.attr_size()) {
    return false;
  }

  // Compare inputs.
  if (IsCommutative(node1)) {
    std::vector<string> inputs1(node1.input().begin(), node1.input().end());
    std::sort(inputs1.begin(), inputs1.end());
    std::vector<string> inputs2(node2.input().begin(), node2.input().end());
    std::sort(inputs2.begin(), inputs2.end());
    return inputs1 == inputs2;
  } else {
    // The order of ordinary (non-control) inputs matters.
    int index = 0;
    for (; index < node1.input_size(); ++index) {
      if (IsControlInput(node1.input(index))) {
        break;
      } else if (node1.input(index) != node2.input(index)) {
        return false;
      }
    }
    // The order of control inputs does not matter.
    if (index < node1.input_size()) {
      std::vector<string> ctrl_inputs1(node1.input().begin() + index,
                                       node1.input().end());
      std::sort(ctrl_inputs1.begin(), ctrl_inputs1.end());
      std::vector<string> ctrl_inputs2(node2.input().begin() + index,
                                       node2.input().end());
      std::sort(ctrl_inputs2.begin(), ctrl_inputs2.end());
      return ctrl_inputs1 == ctrl_inputs2;
    }
  }

  // Compare attributes.
  if (node1.attr().size() != node2.attr().size()) {
    return false;
  }
  for (const auto& attr1 : node1.attr()) {
    auto it = node2.attr().find(attr1.first);
    if (it == node2.attr().end()) return false;
    if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
  }

  return true;
}

bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
  if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
    return false;
  }
  if (IsEnter(node) || IsExit(node)) {
    return false;
  }
  if (node.device().find("SPU") != string::npos) {
    return false;
  }
  // Workaround for Assert and Print mistakenly being labeled as stateful.
  if (IsAssert(node) || IsPrint(node)) {
    return true;
  }
  return IsFreeOfSideEffect(node);
}

void ArithmeticOptimizer::DedupComputations() {
  GraphTopologyView graph_view;
  if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
    return;
  }

  const absl::flat_hash_set<string> ops_to_traverse = {
      "Identity", "IdentityN", "Reshape", "ExpandDims",
      "Enter", "Switch", "Merge"};

  // Populate feeds_inplace_op.
  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;

  for (const NodeDef& root : optimized_graph_->node()) {
    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;

    if (ModifiesInputsInPlace(root)) {
      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
        return node->op() == root.op() || ops_to_traverse.count(node->op()) > 0;
      };

      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
                   DfsPredicates::Advance(is_continue_traversal),
                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
                     feeds_inplace_op.insert(node);
                   }));
    }
  }

  bool stop = true;
  std::set<int> duplicates;
  UniqueNodes nodes;
  do {
    stop = true;
    for (int i = 0; i < optimized_graph_->node_size(); ++i) {
      if (duplicates.find(i) != duplicates.end()) {
        continue;
      }
      NodeDef* node = optimized_graph_->mutable_node(i);
      if (!CanDedup(*node) ||
          feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
        continue;
      }
      NodeDef* rep = nodes.FindOrAddRepresentative(node);
      if (rep == node) {
        continue;
      }
      // If either node or rep feeds an inplace op, deduping them may cause
      // data races. For example: if we dedup nodes initializing two
      // independent inplace accumulations, they will write to the same
      // buffer, clobbering each other's results.
      if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
        continue;
      }
      VLOG(3) << "Remove duplicated node: node=" << node->name()
              << " representative=" << rep->name();
      const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
      std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
      for (NodeDef* fanout : fanouts) {
        for (int i = 0; i < fanout->input_size(); ++i) {
          string* fanout_input = fanout->mutable_input(i);
          const int position =
              NodePositionIfSameNode(*fanout_input, node->name());
          // Update the name in-place.
          if (position < -1) {
            continue;
          } else if (position > 0) {
            *fanout_input = StrCat(rep->name(), ":", position);
          } else if (position == 0) {
            *fanout_input = rep->name();
          } else {
            *fanout_input = StrCat("^", rep->name());
          }
          node_map_->AddOutput(rep->name(), fanout->name());
        }
      }
      duplicates.insert(i);
      stop = false;
    }
  } while (!stop);

  // Delete duplicates.
  if (fetch_nodes_known_ && !duplicates.empty()) {
    EraseNodesFromGraph(duplicates, optimized_graph_);
    // Rebuild the NodeMap, which was invalidated by the node swapping above.
    node_map_.reset(new NodeMap(optimized_graph_));
  }
}

void ArithmeticOptimizer::ForwardControlDependencies(
    NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
  for (const auto& src : src_nodes) {
    for (int i = src->input_size() - 1; i >= 0; --i) {
      if (IsControlInput(src->input(i))) {
        *target_node->add_input() = src->input(i);
        node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
      } else {
        break;
      }
    }
  }
  DedupControlInputs(target_node);
}

Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
  SetVector<NodeDef*> nodes_to_simplify;
  nodes_to_simplify.Reserve(optimized_graph_->node_size());
  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
    nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
  }

  const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                  graph_properties_.get(), node_map_.get(),
                                  &feed_nodes_, opt_level_);
  const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);

  // Stop the pipeline after the first stage that returns a non-empty
  // simplified tensor name.
  const auto stop = [](const string& result) { return !result.empty(); };
  GraphOptimizerStagePipeline<string> pipeline(stop);

  if (options_.combine_add_to_addn && can_use_shapes)
    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
  if (options_.fold_conjugate_into_transpose)
    pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
  if (options_.fold_multiply_into_conv)
    pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
  if (options_.fold_transpose_into_matmul)
    pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
  if (options_.hoist_common_factor_out_of_aggregation && can_use_shapes)
    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
  if (options_.minimize_broadcasts && can_use_shapes)
    pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
  if (options_.remove_identity_transpose && can_use_shapes)
    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
  if (options_.remove_involution)
    pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
  if (options_.remove_redundant_bitcast)
    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_cast)
    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_reshape)
    pipeline.AddStage<RemoveRedundantReshape>(ctx, ctx_ext);
  if (options_.remove_negation)
    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
  if (options_.replace_mul_with_square)
    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
  if (options_.remove_logical_not)
    pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
  if (options_.reorder_cast_like_and_value_preserving)
    pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
  if (options_.simplify_aggregation)
    pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
  if (options_.hoist_cwise_unary_chains)
    pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
  if (options_.convert_sqrt_div_to_rsqrt_mul)
    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
  if (options_.remove_idempotent)
    pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
  if (options_.convert_log1p)
    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
  if (options_.convert_log_softmax)
    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
  if (options_.optimize_max_or_min_of_monotonic)
    pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
  if (options_.convert_expm1)
    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
  if (options_.unary_ops_composition)
    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
  if (options_.remove_stack_strided_slice_same_axis)
    pipeline.AddStage<RemoveStackStridedSliceSameAxis>(ctx, ctx_ext);
  if (options_.fuse_squared_diff)
    pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);

  VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
          << str_util::Join(pipeline.StageNames(), ", ");

  while (!nodes_to_simplify.Empty()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = nodes_to_simplify.PopBack();

    string simplified_tensor = "";
    bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);

    // If the node was not optimized by any of the stages, go to the next one.
    if (!optimized) continue;

    // Re-wire consumers of the old node to the new one.
    if (NodeName(simplified_tensor) != node->name()) {
      // Always consider simplified_tensor for further optimizations.
      NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
      if (simplified_node != nullptr) {
        nodes_to_simplify.PushBack(simplified_node);
      }
      // When `node` is simplified to another node rather than in-place, the
      // consumers of `node` are already redirected to `simplified_tensor`.
      // Re-push the consumers into `nodes_to_simplify` for further
      // optimizations.
      const std::set<NodeDef*> outputs = node_map_->GetOutputs(node->name());
      std::vector<NodeDef*> consumers(outputs.begin(), outputs.end());
      std::sort(consumers.begin(), consumers.end(),
                [](const NodeDef* n1, const NodeDef* n2) {
                  return n1->name() < n2->name();
                });
      for (NodeDef* consumer : consumers) {
        // Update `consumer`'s use of `node` to `simplified_tensor`.
        for (int i = 0; i < consumer->input_size(); ++i) {
          int operand_pos;
          string operand_node_name =
              ParseNodeName(consumer->input(i), &operand_pos);
          if (operand_node_name == node->name()) {
            *consumer->mutable_input(i) =
                (operand_pos < 0
                     ? AsControlDependency(NodeName(simplified_tensor))
                     : simplified_tensor);
          }
        }
        node_map_->UpdateInput(consumer->name(), node->name(),
                               simplified_tensor);
        nodes_to_simplify.PushBack(consumer);
      }
    }
  }
  return Status::OK();
}

Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                     const GrapplerItem& item,
                                     GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  fetch_nodes_known_ = !item.fetch.empty();
  GrapplerItem optimized_item(item);
  optimized_graph_ = &optimized_item.graph;
  node_map_.reset(new NodeMap(optimized_graph_));

  for (const auto& feed : item.feed) {
    feed_nodes_.insert(NodeName(feed.first));
  }

  // Disable restricted graph rewrites.
  options_.unary_ops_composition &=
      item.optimization_options().allow_non_differentiable_rewrites;

  if (options_.dedup_computations) {
    DedupComputations();
  }
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  // Perform a topological sort on the graph in order to help AddOpsRewrite
  // optimize larger subgraphs starting from the roots with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));

  graph_properties_.reset(new GraphProperties(optimized_item));
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  const Status status = graph_properties_->InferStatically(assume_valid_feeds);
  const bool can_use_shapes = status.ok();
  if (!can_use_shapes) {
    VLOG(1) << "Shape inference failed: " << status.error_message();
  }

  // Perform the optimizations.
  TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));

  optimized_graph->Swap(optimized_graph_);
  return Status::OK();
}

void ArithmeticOptimizer::Feedback(Cluster* /*cluster*/,
                                   const GrapplerItem& /*item*/,
                                   const GraphDef& /*optimized_graph*/,
                                   double /*result*/) {
  // Nothing to do for ArithmeticOptimizer.
}

}  // namespace grappler
}  // namespace tensorflow