/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/strided_slice_op.h"

using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {
namespace {

// Mark nodes created or optimized by a stage with a tag.
constexpr char kAddOpsRewriteTag[] =
    "_grappler_ArithmeticOptimizer_AddOpsRewriteStage";
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler_ArithmeticOptimizer_MinimizeBroadcasts";
// Extract values from a Const op to `values`. Returns true on success.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
  if (node.op() != "Const") {
    return false;
  }

  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
    return false;
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = node.attr().at("value").tensor();
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, theoretically the representation of the data
    // could be compressed. So, before copying values to the returned vector,
    // make sure no compression happens.
    const TensorShapeProto& shape = tensor.tensor_shape();
    if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
      values->insert(values->end(), tensor_values->begin(),
                     tensor_values->end());
      return true;
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  if (tensor_content_size > 0) {
    CHECK_EQ(0, tensor_content_size % sizeof(T))
        << "tensor_content_size (" << tensor_content_size
        << ") is not a multiple of " << sizeof(T);
    values->resize(tensor_content_size / sizeof(T));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
    return true;
  }

  return false;
}

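// Adds a control input from `new_input` to `node`, unless an equivalent data
// or control input is already present. Returns true if an input was added.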
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
                          GraphDef* graph, NodeMap* node_map) {
  bool already_exists = false;
  for (const string& input : node->input()) {
    if (input == new_input || AsControlDependency(input) == new_input) {
      already_exists = true;
      break;
    }
  }
  if (!already_exists) {
    const string ctrl_dep =
        ConstantFolding::AddControlDependency(new_input, graph, node_map);
    node->add_input(ctrl_dep);
    node_map->AddOutput(NodeName(new_input), node->name());
  }
  return !already_exists;
}

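// Sets the type of attribute `attr_name` on `node` to `dtype`.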
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
  (*node->mutable_attr())[attr_name].set_type(dtype);
}

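// Returns the tail of a chain of value-preserving ops starting at `node`,
// following first data inputs while each op has a single non-control output
// and is not in `nodes_to_preserve`.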
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_value_preserving_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_value_preserving_non_branching);
}

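// Same as above, but follows a chain of idempotent ops instead.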
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_idempotent_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_idempotent_non_branching);
}

// GetElementUnexhaustive tries to get the value of an element in a tensor and
// turn it into complex128 type. It only checks for a limited number of data
// types, so it is not exhaustive.
bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
                            complex128* element) {
  if (dtypes.find(t.dtype()) == dtypes.end()) return false;
  switch (t.dtype()) {
    case DT_BFLOAT16:
      *element = complex128(t.flat<bfloat16>()(i));
      return true;
    case DT_HALF:
      *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
      return true;
    case DT_INT32:
      *element = complex128(t.flat<int32>()(i));
      return true;
    case DT_INT64:
      *element = complex128(t.flat<int64>()(i));
      return true;
    case DT_FLOAT:
      *element = complex128(t.flat<float>()(i));
      return true;
    case DT_DOUBLE:
      *element = complex128(t.flat<double>()(i));
      return true;
    case DT_COMPLEX64:
      *element = complex128(t.flat<complex64>()(i));
      return true;
    case DT_COMPLEX128:
      *element = t.flat<complex128>()(i);
      return true;
    default:
      return false;
  }
}

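// True if `node` is placed on a CPU device.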
bool NodeIsOnCpu(const NodeDef& node) {
  string task;
  string device;
  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
         absl::StrContains(device, DEVICE_CPU);
}

// True if all regular (non-control) inputs reference the same node, or if
// there are no non-control inputs.
bool AllRegularInputsEqual(const NodeDef& node) {
  if (!HasRegularInputs(node)) return true;
  for (int i = 1; i < node.input_size(); ++i) {
    if (IsControlInput(node.input(i))) {
      break;
    }
    if (node.input(0) != node.input(i)) {
      return false;
    }
  }
  return true;
}

// Replace a node with NoOp and reset shape inference results for it.
void ReplaceWithNoOp(NodeDef* node, const GraphOptimizerContext& ctx) {
  ctx.node_map->RemoveInputs(node->name());
  ctx.graph_properties->ClearInputProperties(node->name());
  ctx.graph_properties->ClearOutputProperties(node->name());
  EraseRegularNodeAttributes(node);
  node->set_op("NoOp");
  node->clear_input();
}

// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  SetVector<NodeDef*>* nodes_to_simplify;
};

// Base class for a single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc.
class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 public:
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
  ~ArithmeticOptimizerStage() override = default;

 protected:
  // A simplification graph rewrite can create additional nodes that are inputs
  // to the final simplified node; they can also be added to the arithmetic
  // optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
    ctx_ext_.nodes_to_simplify->PushBack(node);
  }

  // Update consumers of `node` to take `new_input` as input instead.
  Status UpdateConsumers(NodeDef* node, const string& new_input) {
    const auto consumers = ctx().node_map->GetOutputs(node->name());
    if (consumers.empty()) return Status::OK();
    const TensorId new_tensor = ParseTensorName(new_input);
    for (NodeDef* consumer : consumers) {
      if (consumer->name() == new_tensor.node()) continue;
      bool updated = false;
      for (int i = 0; i < consumer->input_size(); ++i) {
        const TensorId input_tensor = ParseTensorName(consumer->input(i));
        if (input_tensor.node() == node->name()) {
          if (new_tensor.index() < 0 && input_tensor.index() >= 0) {
            // Overwriting a data input with a control input will make the
            // graph invalid.
            return errors::InvalidArgument(
                "Cannot override data input ", input_tensor.ToString(),
                " with control input ", new_tensor.ToString());
          }
          consumer->set_input(i, input_tensor.index() < 0
                                     ? absl::StrCat("^", new_tensor.node())
                                     : new_input);
          ctx().node_map->UpdateInput(consumer->name(), node->name(),
                                      new_input);
          updated = true;
        }
      }
      if (updated) {
        DedupControlInputs(consumer);
        AddToOptimizationQueue(consumer);
      }
    }
    return Status::OK();
  }

  // TODO(ezhulenev): remove this method from ArithmeticOptimizer once all
  // optimizations have been migrated to stages.
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
    for (const auto& src : src_nodes) {
      for (int i = src->input_size() - 1; i >= 0; --i) {
        if (IsControlInput(src->input(i))) {
          *target_node->add_input() = src->input(i);
          ctx().node_map->AddOutput(NodeName(src->input(i)),
                                    target_node->name());
        } else {
          break;
        }
      }
    }
    DedupControlInputs(target_node);
  }

  bool IsReallyConstant(const NodeDef& node) const {
    if (!IsConstant(node)) {
      return false;
    }
    // If the node is fed, it's not constant anymore.
    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
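  // True if `node` has at least one control input.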
  bool IsDrivenByControlDependency(const NodeDef& node) const {
    return std::any_of(
        node.input().begin(), node.input().end(),
        [](const string& input) { return IsControlInput(input); });
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
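  // True if any consumer uses `node` as a control dependency.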
  bool DrivesControlDependency(const NodeDef& node) const {
    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
      for (int i = 0; i < output->input_size(); ++i) {
        const TensorId tensor = ParseTensorName(output->input(i));
        if (tensor.node() == node.name() && tensor.index() < 0) {
          return true;
        }
      }
    }
    return false;
  }

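  // If `node_name_or_input` refers to a Const node that is not fed, extracts
  // its value into `tensor` and returns true.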
  bool GetTensorFromConstNode(const string& node_name_or_input,
                              Tensor* tensor) {
    const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
    return node != nullptr && IsReallyConstant(*node) &&
           CheckAttrExists(*node, "value").ok() &&
           tensor->FromProto(node->attr().at("value").tensor());
  }

 private:
  // Extended context required for ArithmeticOptimizer.
  const ArithmeticOptimizerContext ctx_ext_;
};

// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
// group of nodes from the optimized graph.
//
// * AddOpsRewrite:
//   Rewrite a group of Add/AddN ops with a compact Add/AddN tree.
//
// * MinimizeBroadcasts:
//   Rewrite a group of binary associative ops, reordering
//   inputs, to minimize the cost of broadcasts.
class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
  ~ArithmeticNodesGroupOptimizerStage() override = default;

  // Input name with a statically inferred shape from GraphProperties.
  struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
    string input;
    TensorShapeProto shape;
  };

  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). The
  // subgraph is obtained by graph traversal, starting from a root node.
  struct OptimizedNodesGroup {
    NodeDef* root_node;
    TensorShapeProto root_shape;
    // Optimized nodes that will be updated or removed by the rewrite.
    std::vector<NodeDef*> optimized_nodes;
    // Inputs to optimized nodes.
    std::vector<InputAndShape> inputs;
  };

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    OptimizedNodesGroup group;
    TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));

    if (!group.optimized_nodes.empty()) {
      *simplified_node_name = RewriteOptimizedNodesGroup(group);
    }

    return Status::OK();
  }

 protected:
  // Modify the optimized graph after a nodes group was successfully
  // identified.
  virtual string RewriteOptimizedNodesGroup(
      const OptimizedNodesGroup& group) = 0;

  // Check if an input can become a part of the current optimized nodes group.
  virtual bool IsAbsorbableByOptimizedNodesGroup(
      const OptimizedNodesGroup& group, const NodeDef& node) const = 0;

  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
    std::deque<const string*> input_tensors;
    input_tensors.push_front(&input);

    while (!input_tensors.empty()) {
      const string* input_tensor = input_tensors.front();
      input_tensors.pop_front();

      // Get a node for the input tensor.
      NodeDef* input_node;
      TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));

      if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
        group->optimized_nodes.push_back(input_node);
        for (int i = input_node->input_size() - 1; i >= 0; --i) {
          const string& absorbed_node_input = input_node->input(i);
          // TODO(ezhulenev): support control inputs
          if (IsControlInput(absorbed_node_input)) continue;
          input_tensors.push_front(&absorbed_node_input);
        }
      } else {
        // If the input node can't be absorbed, add it to the
        // OptimizedNodesGroup inputs.
        const OpInfo::TensorProperties* properties;
        TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
        group->inputs.emplace_back(*input_tensor, properties->shape());
      }
    }

    return Status::OK();
  }

  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
    const OpInfo::TensorProperties* root_node_output_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(root_node->name(), &root_node_output_properties));

    group->root_node = root_node;
    group->root_shape = root_node_output_properties->shape();

    group->optimized_nodes.reserve(root_node->input_size());
    for (int i = 0; i < root_node->input_size(); ++i) {
      const string& input_i = root_node->input(i);
      // TODO(ezhulenev): add support for control inputs
      if (IsControlInput(input_i)) continue;
      TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
    }

    return Status::OK();
  }

  // Check if all inputs can be broadcasted to the same shape.
  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
    auto is_broadcastable = [this, &properties](const string& input) {
      const OpInfo::TensorProperties* input_props;
      Status has_input_properties = GetTensorProperties(input, &input_props);
      return has_input_properties.ok() &&
             ShapesBroadcastable(properties, *input_props);
    };
    return std::all_of(node.input().begin(), node.input().end(),
                       is_broadcastable);
  }

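  // Builds a readable signature for a shape, e.g. a [2,3] tensor yields
  // "rank:2:dim:2:3".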
  string ShapeSignature(const TensorShapeProto& shape) const {
    string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
    for (int i = 0; i < shape.dim_size(); ++i)
      strings::StrAppend(&signature, ":", shape.dim(i).size());
    return signature;
  }

  void MarkWithTag(const StringPiece tag, NodeDef* node) {
    AddNodeAttr(tag, true, node);
  }

  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
    AddNodeAttr(tag, true, group.root_node);
    for (NodeDef* optimized_node : group.optimized_nodes) {
      AddNodeAttr(tag, true, optimized_node);
    }
  }

  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
    return group.root_node->device() == node.device();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
    return HasNodeAttr(node, tag);
  }

  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
    return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
  }
};


// Rewrite a tree of Add/AddN ops with a single AddN operation, consuming all
// the original inputs of absorbed nodes.
//
// 1) All nodes must have the same device placement.
//
// 2) If all nodes in an Add/AddN subgraph have a symbolically equal shape, the
//    tree is optimized to a single AddN node.
//
//                AddN_1
//              /   |   \
//         Add_1    z    Add_2    ->    AddN(x, y, z, w, q, e)
//         /  \         /  \
//        x    y       w    Add_3
//                          /  \
//                         q    e
//
// 3) If some nodes have a different shape (it must be broadcastable to the
//    shape of the "root"), the tree is optimized to AddN ops for symbolically
//    equal shapes and a tree of Add ops that minimizes broadcasts.
//
//                AddN_1                             Add
//              /   |   \                           /   \
//         Add_1    z    Add_2         ->        Add     w
//         /  \         /  \                    /   \
//        x    y       w    Add_3    AddN(x, y, q, e)  z
//                          /  \
//                         q    e
class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
  ~AddOpsRewriteStage() override = default;

  // Check if a node can become a root of an AddOpsGroup.
  bool IsSupported(const NodeDef* node) const override {
    if (!CanOptimize(*node)) return false;

    // The shape must be symbolically defined, and all inputs compatible with
    // it.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  // Check if a node can be absorbed by the current OptimizedNodesGroup.
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!CanOptimize(node)) return false;

    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Absorb only nodes with a single output data consumer (if we reached this
    // node from a previously absorbed node or the root node, it means that
    // this node is not used as an input to any other op outside of the group).
    if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  // Node requirements both for a root node and an absorbed node.
  bool CanOptimize(const NodeDef& node) const {
    // TODO(ezhulenev): check if AccumulateNV2 can be supported too
    if (!IsAdd(node) && !IsAddN(node)) {
      return false;
    }
    if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
      return false;
    }
    // TODO(ezhulenev): relax this condition for root node
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  // Rewrite a group of add ops into a single AddN if all input shapes are
  // symbolically equal. If not, create AddN for equal shapes first, and then
  // build an Add tree, minimizing the cost of broadcasts.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size()
            << " num_inputs=" << group.inputs.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kAddOpsRewriteTag);

    // All new nodes will be placed under the scope of a root node.
    auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());

    // Find what shapes are present in the inputs of absorbed nodes.
    std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
    for (const auto& input : group.inputs) {
      shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
    }

    using SigKV = decltype(shape_sig_to_inputs)::value_type;
    VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
            << " unique shapes: "
            << absl::StrJoin(shape_sig_to_inputs, ", ",
                             [](string* out, SigKV p) {
                               strings::StrAppend(out, p.first);
                             });

    // Collect all the shapes from representative elements.
    std::vector<TensorShapeProto> shapes;
    shapes.reserve(shape_sig_to_inputs.size());
    for (const auto& el : shape_sig_to_inputs)
      shapes.push_back(el.second[0].shape);

    // If all inputs have the same shape, rewrite the whole group with a single
    // AddN.
    if (shapes.size() == 1) {
      string node_name = UniqueOptimizedNodeName(root_scope_and_name);
      AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
                                        group.inputs);
      return node_name;
    }

    // For inputs of different shapes:
    // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
    // 2. Build a tree of Add nodes, minimizing cost of broadcast
    std::sort(shapes.begin(), shapes.end(),
              [](const TensorShapeProto& left, const TensorShapeProto& right) {
                return CompareSymbolicallyShapedTensorSizes(left, right);
              });

    // Optimized name for leaf AddN nodes.
    auto leaf_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Leaf_", i));
    };
    // Optimized name for internal nodes of a tree built up from AddN leaves.
    auto internal_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Internal_", i));
    };

    // Add/AddN nodes that must be added to the tree.
    std::deque<InputAndShape> add_ops;

    // Prepare leaf AddN nodes for inputs of equal shape.
    for (int i = 0, end = shapes.size(); i < end; ++i) {
      const auto node_name = leaf_node_name(i);
      const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
      add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
                                                          node_name, inputs));
    }

    // Build up a tree of Add ops.
    int internal_nodes = 0;
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();
      string name = add_ops.empty()
                        ? UniqueOptimizedNodeName(root_scope_and_name)
                        : internal_node_name(internal_nodes++);
      InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
      add_ops.push_front(add);
    } while (add_ops.size() > 1);

    InputAndShape optimized_root_node = add_ops.front();
    return optimized_root_node.input;
  }

  // Add 'AddN' node to aggregate inputs of symbolically equal shape.
  InputAndShape AddInputsOfSymbolicallyEqualShape(
      const NodeDef& root_node, const string& node_name,
      const std::vector<InputAndShape>& inputs) {
    CHECK(!inputs.empty()) << "Inputs must be non-empty";

    // Do not create redundant AddN nodes.
    if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
      return inputs[0];
    }

    // Get the shape from a representative element.
    auto shape = inputs[0].shape;

    // Copy attributes from the root node.
    DataType dtype = root_node.attr().at("T").type();

    // Add a new AddN node.
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op("AddN");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["N"].set_i(inputs.size());

    for (const auto& inputAndShape : inputs) {
      ctx().node_map->AddOutput(inputAndShape.input, node_name);
      node->add_input(inputAndShape.input);
    }

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(node_name, shape);
  }

  // Add a single 'Add' node to sum two inputs.
  InputAndShape AddAggregatedInputs(const NodeDef& root_node,
                                    const string& node_name,
                                    const InputAndShape& left,
                                    const InputAndShape& right) {
    // Copy attributes from the root node.
    DataType dtype = root_node.attr().at("T").type();

    // Add a new Add node.
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op((dtype == DT_STRING || dtype == DT_STRING_REF) ? "Add"
                                                                : "AddV2");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    node->add_input(left.input);
    node->add_input(right.input);

    ctx().node_map->AddOutput(left.input, node_name);
    ctx().node_map->AddOutput(right.input, node_name);

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(
        node_name, TensorShapeProto());  // shape is not important at this point
  }
};

// Use the distributive property of multiplication and division over addition,
// along with commutativity of the former, to hoist common factors/denominators
// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
// This pattern occurs frequently in regularization terms for the gradients
// during training.
//
// For example, we can rewrite an expression of the form:
//   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
//   Mul(x, AddN(y1, y2, y3, ... yn))
// For division, we can rewrite
//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
// to:
//   Div(AddN(y1, y2, y3, ... yn), x)
class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 public:
  explicit HoistCommonFactorOutOfAggregation(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
  ~HoistCommonFactorOutOfAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
           !IsRewritten(node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    bool common_factor_is_denominator = false;
    std::set<string> common_factors;
    std::vector<string> ctrl_deps;
    TF_RETURN_IF_ERROR(GetCommonFactors(
        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));

    if (common_factors.size() == 1) {
      const string& common_factor = *common_factors.begin();

      // Gather up the non-shared factors.
      bool shapes_match = true;
      std::vector<string> unique_factors;
      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
                                          common_factor_is_denominator,
                                          &shapes_match, &unique_factors));

      if (shapes_match) {
        NodeDef* input_0;
        TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));

        // Use a copy of the first node for the outer multiplication/division.
        NodeDef* new_outer_node = AddCopyNode(
            OuterNodeName(node, common_factor_is_denominator), input_0);
        // And a copy of the aggregation node as one of the inner operands.
        NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);

        new_outer_node->set_device(node->device());
        if (common_factor_is_denominator) {
          new_outer_node->set_input(0, new_add_node->name());
          new_outer_node->set_input(1, common_factor);
        } else {
          new_outer_node->set_input(0, common_factor);
          new_outer_node->set_input(1, new_add_node->name());
        }

        ctx().node_map->AddOutput(common_factor, new_outer_node->name());
        ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());

        // Hoist non-shared factors up into the new AddN node.
        for (int i = 0, end = unique_factors.size(); i < end; ++i) {
          const string& unique_factor_i = unique_factors[i];
          new_add_node->set_input(i, unique_factor_i);
          ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
        }

        // Add control deps on the add node.
        for (const string& ctrl_dep : ctrl_deps) {
          *new_add_node->add_input() = ctrl_dep;
          ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
        }

        // Optimize the new inner aggregation node.
        AddToOptimizationQueue(new_add_node);
        // Do not optimize the same node twice.
        rewritten_nodes_.insert(node->name());
        *simplified_node_name = new_outer_node->name();
      }
    }
    return Status::OK();
  }

 private:
  // Get a name for the new outer node.
  string OuterNodeName(const NodeDef* node, bool is_div) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return is_div ? OptimizedNodeName(scope_and_name, "Div")
                  : OptimizedNodeName(scope_and_name, "Mul");
  }

  // Get a name for the new inner Add node.
  string InnerAddNodeName(const NodeDef* node) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return OptimizedNodeName(scope_and_name, "AddV2");
  }

  // Determine the set of common factors if the input nodes are all Mul or
  // Div nodes.
  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
                          bool* common_factor_is_denominator,
                          std::vector<string>* ctrl_deps) const {
    CHECK(common_factors->empty());
    CHECK_NOTNULL(common_factor_is_denominator);
    *common_factor_is_denominator = false;

    bool has_mul = false;
    bool has_div = false;
    for (int i = 0; i < node->input_size(); ++i) {
      if (i > 0 && common_factors->empty()) break;
      if (IsControlInput(node->input(i))) {
        ctrl_deps->push_back(node->input(i));
        continue;
      }
      NodeDef* input;
      TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));

      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
          (IsAnyDiv(*input) && has_mul)) {
        // Break if the input is neither a Mul nor a Div, or if there are both
        // Mul & Div ops.
        common_factors->clear();
        break;
      } else if (IsAnyDiv(*input)) {
        has_div = true;
        // In case of possible common dividers, we avoid hoisting out if any
        // input is not float/double, since integer division is not
        // distributive over addition.
        const OpInfo::TensorProperties* properties0;
        const OpInfo::TensorProperties* properties1;
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
        if (properties0->dtype() != DT_FLOAT &&
            properties0->dtype() != DT_DOUBLE &&
            properties1->dtype() != DT_FLOAT &&
            properties1->dtype() != DT_DOUBLE) {
          common_factors->clear();
          break;
        }
      } else if (IsMul(*input)) {
        has_mul = true;
      }

      // We only focus on common factors from denominators if any op is a
      // Div.
      std::set<string> factors_i =
          has_mul ? std::set<string>{input->input(0), input->input(1)}
                  : std::set<string>{input->input(1)};
      if (i == 0) {
        std::swap(*common_factors, factors_i);
      } else {
        std::set<string> intersection;
        std::set_intersection(
            factors_i.begin(), factors_i.end(), common_factors->begin(),
            common_factors->end(),
            std::inserter(intersection, intersection.begin()));
        std::swap(*common_factors, intersection);
      }
      for (int i = 2; i < input->input_size(); ++i) {
        ctrl_deps->push_back(input->input(i));
      }
    }

    *common_factor_is_denominator = has_div;
    return Status::OK();
  }

  // Gather up the non-shared factors (the y's in the example).
  // Unless the aggregation is Add, we have to make sure that all the y's
  // have the same shape since the other aggregation ops do not support
  // broadcasting.
  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
                          const bool common_factor_is_denominator,
                          bool* shapes_match,
                          std::vector<string>* unique_factors) const {
    *shapes_match = true;
    unique_factors->reserve(node->input_size());

    for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
      const string& input = node->input(i);
      if (IsControlInput(input)) {
        break;
      }
      NodeDef* inner_node;
      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
      const int unique_factor_index =
          common_factor_is_denominator
              ? 0
              : (inner_node->input(0) == common_factor ? 1 : 0);
      unique_factors->push_back(inner_node->input(unique_factor_index));
      if (i > 0 && !IsAdd(*node)) {
        const OpInfo::TensorProperties* lhs;
        const OpInfo::TensorProperties* rhs;
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
        *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs);
      }
    }
    return Status::OK();
  }

  bool IsRewritten(const NodeDef* node) const {
    // If the graph rewrite happens in multiple passes without graph pruning
    // between them, it's possible that a rewritten node already exists in the
    // graph.
    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
           ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
           ctx().node_map->NodeExists(OuterNodeName(node, true)) ||
           ctx().node_map->NodeExists(InnerAddNodeName(node));
  }

  // Keep names of the nodes that were optimized by this stage.
  std::unordered_set<string> rewritten_nodes_;
};

// Binary associative ops can be re-ordered to minimize the number of
// broadcasts and the size of temporary tensors.
//
// Example: [a, c] - scalars, [b, d] - matrices
//   @  - binary associative op (Add or Mul)
//   @* - broadcast
//
//           @                            @*
//         /   \                        /    \
//       @*     @*         ->         @        @
//      /  \   /  \                 /   \    /   \
//     a    b c    d               a     c  b     d
class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
  }
  ~MinimizeBroadcasts() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsBinaryAssociative(*node)) return false;

    if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
      return false;

    // Must have a symbolically defined shape with broadcastable inputs.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  bool IsBinaryAssociative(const NodeDef& node) const {
    return IsMul(node) || IsAdd(node);
  }

  bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
    return group.root_node->op() == node.op();
  }

  // Check if a node can be absorbed by the current OptimizedNodesGroup.
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!IsSameOp(group, node)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
    if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
      return false;
    }
    if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
      return false;
    }
    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Optimized nodes are updated in place, and that would break the graph if
    // the node had multiple output consumers.
    if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

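  // Returns the number of unique shape signatures among `inputs`.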
  std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
    std::set<string> sigs;
    for (const auto& ias : inputs) {
      sigs.insert(ShapeSignature(ias.shape));
    }
    return sigs.size();
  }

  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);

    if (CountUniqueShapes(group.inputs) <= 1) {
      VLOG(3) << "Skip min-bcast group with single unique shape";
      // Nothing to optimize when all shapes are the same.
      return group.root_node->name();
    }

    auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
    auto num_inputs = group.inputs.size();
    CHECK_EQ(num_nodes, num_inputs - 1)
        << "Can't build a tree with " << num_inputs << " inputs, using "
        << num_nodes << " binary op nodes.";

    std::deque<InputAndShape> add_ops(group.inputs.begin(),
                                      group.inputs.end());
    std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
                                         group.optimized_nodes.end());

    // Sort inputs by their shape, from smallest to largest.
    std::stable_sort(add_ops.begin(), add_ops.end(),
                     [](const InputAndShape& lhs, const InputAndShape& rhs) {
                       return CompareSymbolicallyShapedTensorSizes(lhs.shape,
                                                                   rhs.shape);
                     });

    // If there is an odd number of inputs, the last one is the largest, and we
    // want to attach it to the root node to build a well-balanced tree.
    std::deque<InputAndShape> add_ops_leftover;
    if (add_ops.size() % 2 != 0) {
      add_ops_leftover.push_back(add_ops.back());
      add_ops.pop_back();
    }

    // At this point it's guaranteed that add_ops has an even number of inputs.
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();

      NodeDef* node;
      if (!optimized_nodes.empty()) {
        // Re-purpose optimized nodes to build a new tree.
        node = optimized_nodes.back();
        optimized_nodes.pop_back();
      } else {
        // Or use the root node if no optimized nodes are left.
        node = group.root_node;
      }
      InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);

      // Pushing the updated node to the back of the deque will create a wide
      // and short tree; pushing to the front will create a tall tree. We
      // prefer a wide tree: it minimizes the potential number of temporary
      // tensors required to keep in memory, though sometimes we can go up to
      // prevent propagating a broadcast from leaves to the root. Example:
      //
      // inputs: [s, s, s, M] (s - scalar, M - matrix)
      // @* - op with broadcast
      //
      //  (only push_back)                           @*  (push_front first op)
      //                                           /    \
      //        @*                                @      M
      //      /    \                            /   \
      //     @      @*           ->            @     s
      //   /   \   /  \                      /   \
      //  s     s s    M                    s     s
      if (add_ops.size() >= 2 &&
          CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
                                               add_ops.at(1).shape)) {
        add_ops.push_front(updated_node);
      } else {
        add_ops.push_back(updated_node);
      }
    } while (add_ops.size() > 1);
    CHECK_EQ(1, add_ops.size());

    // Attach the largest tensor to the root op.
    if (!add_ops_leftover.empty()) {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops_leftover.front();
      InputAndShape updated_node =
          UpdateInputs(lhs.input, rhs.input, group.root_node);
      add_ops.push_back(updated_node);
    }

    return add_ops.front().input;
  }

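  // Points `node` at the new data inputs `input_0` and `input_1`, invalidating
  // its cached shape properties and updating the node map if anything changed.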
  InputAndShape UpdateInputs(const string& input_0, const string& input_1,
                             NodeDef* node) {
    string old_input_0 = node->input(0);
    string old_input_1 = node->input(1);

    // Update inputs only if they changed.
    if (old_input_0 != input_0 || old_input_1 != input_1) {
      node->set_input(0, input_0);
      node->set_input(1, input_1);
      // Invalidate node properties (shape).
      ctx().graph_properties->ClearOutputProperties(node->name());
      ctx().graph_properties->ClearInputProperties(node->name());
      // Update the node map.
      ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
      ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
      ctx().node_map->AddOutput(NodeName(input_0), node->name());
      ctx().node_map->AddOutput(NodeName(input_1), node->name());
      // Add the updated node to the optimization queue.
      AddToOptimizationQueue(node);
    }

    TensorShapeProto shape;  // shape is not important at this point
    return InputAndShape(node->name(), shape);
  }
};

// Removes inverse transpose nodes.
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
  ~RemoveIdentityTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
    NodeDef* tail = node;
    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
                                    *ctx().nodes_to_preserve);
    NodeDef* first_transpose;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));

    NodeDef* node_perm;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
    if (!IsConstant(*node_perm)) {
      return Status::OK();
    }
    std::vector<int64> node_perm_values;
    TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
    if (first_transpose->op() == node->op()) {
      // Remove pairs of transposes that cancel each other.
      NodeDef* first_transpose_perm;
      TF_RETURN_IF_ERROR(
          GetInputNode(first_transpose->input(1), &first_transpose_perm));
      if (!IsConstant(*first_transpose_perm)) {
        return Status::OK();
      }
      std::vector<int64> first_transpose_perm_values;
      TF_RETURN_IF_ERROR(
          GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
      if (AreInversePermutations(node_perm_values,
                                 first_transpose_perm_values)) {
        if (tail == node) {
          // Bypass adjacent pair.
          *simplified_node_name = first_transpose->input(0);
        } else {
          // Bypass pair connected through chain.
          tail->set_input(0, first_transpose->input(0));
          ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
                                      first_transpose->input(0));
          ForwardControlDependencies(tail, {first_transpose});
          *simplified_node_name = node->input(0);
        }
      }
    } else {
      // Remove simple identity transposes.
      if (IsIdentityPermutation(node_perm_values)) {
        if (IsConjugateTranspose(*node)) {
          const NodeScopeAndName transpose =
              ParseNodeScopeAndName(node->name());
          const string optimized_node_name = OptimizedNodeName(transpose);
          NodeDef* new_op = AddCopyNode(optimized_node_name, node);
          new_op->set_op("Conj");
          new_op->mutable_input()->RemoveLast();
          new_op->mutable_attr()->erase("Tperm");
          ForwardControlDependencies(new_op, {node});
          *simplified_node_name = new_op->name();
        } else {
          *simplified_node_name = node->input(0);
        }
      }
    }
    return Status::OK();
  }

 private:
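  // Extracts a permutation from a Const node, widening int32 values to int64
  // if necessary.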
  Status GetPermutation(const NodeDef& node_perm,
                        std::vector<int64>* perm64) const {
    std::vector<int> perm32;
    if (ValuesFromConstNode(node_perm, &perm32)) {
      perm64->reserve(perm32.size());
      for (int val : perm32) {
        perm64->push_back(static_cast<int64>(val));
      }
      return Status::OK();
    }
    if (ValuesFromConstNode(node_perm, perm64)) {
      return Status::OK();
    }
    return errors::InvalidArgument("Couldn't extract permutation from ",
                                   node_perm.name());
  }

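  // True iff `b` is the inverse of `a`, i.e. a[b[i]] == i for all i.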
  bool AreInversePermutations(const std::vector<int64>& a,
                              const std::vector<int64>& b) {
    if (a.size() != b.size()) {
      return false;
    }
    for (int i = 0, end = a.size(); i < end; ++i) {
      if (a[b[i]] != i) {
        return false;
      }
    }
    return true;
  }

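  // True iff `perm` is the identity permutation, i.e. perm[i] == i for all i.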
IsIdentityPermutation(const std::vector<int64> & perm)1277 bool IsIdentityPermutation(const std::vector<int64>& perm) {
1278 for (int64_t i = 0, end = perm.size(); i < end; ++i) {
1279 if (i != perm[i]) {
1280 return false;
1281 }
1282 }
1283 return true;
1284 }
1285 };
1286
1287 // An involution is an element-wise function f(x) that is its own inverse,
1288 // i.e. f(f(x)) = x. If we can find a chain of ops
1289 // f->op1->op2->...opn->f
1290 // where op1 through opn preserve the values of their inputs, we can remove
1291 // the two instances of the involution from the graph, since they cancel
1292 // each other.
1293 class RemoveInvolution : public ArithmeticOptimizerStage {
1294 public:
RemoveInvolution(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1295 explicit RemoveInvolution(const GraphOptimizerContext& ctx,
1296 const ArithmeticOptimizerContext& ctx_ext)
1297 : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
1298 ~RemoveInvolution() override = default;
1299
IsSupported(const NodeDef * node) const1300 bool IsSupported(const NodeDef* node) const override {
1301 return IsInvolution(*node);
1302 }
1303
TrySimplify(NodeDef * node,string * simplified_node_name)1304 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1305 NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
1306 *ctx().nodes_to_preserve);
1307
1308 NodeDef* involution;
1309 TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));
1310
1311 if (involution->op() == node->op()) {
1312 // Skip both *node and *involution since they cancel each other.
1313 if (tail == node) {
1314 // The two nodes to eliminate are adjacent.
1315 *simplified_node_name = involution->input(0);
1316 } else {
1317 tail->set_input(0, involution->input(0));
1318 ctx().node_map->UpdateInput(tail->name(), involution->name(),
1319 involution->input(0));
1320 *simplified_node_name = node->input(0);
1321 }
1322 }
1323
1324 return Status::OK();
1325 }
1326 };
1327
1328 // Remove redundant Bitcasts.
1329 // 1) Remove Bitcast whose source type and destination type are equal
1330 // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantBitcastStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
  ~RemoveRedundantBitcastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsBitcast(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Bitcast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
    if ((input_type == output_type) && !IsInPreserveSet(*node)) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    NodeDef* bitcast;
    TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
    NodeDef* operand;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));

    if (IsBitcast(*operand) && !IsInPreserveSet(*operand)) {
      AttrSlice operand_attrs(*operand);
      DataType operand_input_type;
      TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
      // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
      // Record the old input name before overwriting it, so the node map's
      // output edges stay consistent.
      const string old_input = bitcast->input(0);
      bitcast->set_input(0, operand->input(0));
      SetDataTypeToAttr(operand_input_type, "T", bitcast);
      ctx().node_map->UpdateInput(bitcast->name(), old_input,
                                  operand->input(0));
      AddToOptimizationQueue(bitcast);
      *simplified_node_name = bitcast->name();
    }

    return Status::OK();
  }
};

// Remove Casts whose source type and destination type are equal.
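// E.g. Cast(x) with SrcT == DstT == DT_FLOAT is a no-op, so consumers are
// simply rewired to read x directly.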
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
  ~RemoveRedundantCastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsCast(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Cast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
    if (input_type == output_type) {
      *simplified_node_name = node->input(0);
    }
    return Status::OK();
  }
};

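// Rewrite Add/Sub with a negated operand into a single op; a sketch of the
// cases handled below:
//   Add(a, Neg(b))  =>  Sub(a, b)
//   Sub(a, Neg(b))  =>  AddV2(a, b)
//   Add(Neg(a), b)  =>  Sub(b, a)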
class RemoveNegationStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
  ~RemoveNegationStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsAdd(*node) || IsSub(*node)) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    bool updated = false;
    if (IsNeg(*y)) {
      // a - (-b) = a + b or a + (-b) = a - b
      ForwardControlDependencies(node, {y});
      ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
      node->set_op(IsAdd(*node) ? "Sub" : "AddV2");
      node->set_input(1, y->input(0));
      updated = true;
    } else if (IsAdd(*node) && IsNeg(*x)) {
      // (-a) + b = b - a
      ForwardControlDependencies(node, {x});
      ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
      node->set_op("Sub");
      node->mutable_input()->SwapElements(0, 1);
      node->set_input(1, x->input(0));
      updated = true;
    }
    if (updated) {
      AddToOptimizationQueue(node);
    }
    return Status::OK();
  }
};

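// Absorb a LogicalNot into a preceding comparison by flipping the comparison
// op, e.g. LogicalNot(Less(x, y)) => GreaterEqual(x, y). This only fires when
// the LogicalNot is the comparison's sole non-control consumer.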
class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
  ~RemoveLogicalNotStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsLogicalNot(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const string node_name = node->name();
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (IsInPreserveSet(*input) ||
        NumNonControlOutputs(*input, *ctx().node_map) > 1) {
      return Status::OK();
    }
    string new_op;
    if (IsEqual(*input)) {
      new_op = "NotEqual";
    } else if (IsNotEqual(*input)) {
      new_op = "Equal";
    } else if (IsLess(*input)) {
      new_op = "GreaterEqual";
    } else if (IsLessEqual(*input)) {
      new_op = "Greater";
    } else if (IsGreater(*input)) {
      new_op = "LessEqual";
    } else if (IsGreaterEqual(*input)) {
      new_op = "Less";
    }
    if (!new_op.empty()) {
      input->set_op(new_op);
      *simplified_node_name = input->name();
    }
    return Status::OK();
  }
};

// This optimization hoists the common prefix of unary ops of the inputs to
// concat out of the concat, for example:
//   Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
// becomes
//   Exp(Sin(Concat([x, y, z]))).
// Similarly, it will hoist the common postfix of unary ops into Split or
// SplitV nodes, for example:
//   [Exp(Sin(y)) for y in Split(x)]
// becomes
//   [y for y in Split(Exp(Sin(x)))]
//
// TODO(rmlarsen): Support casting. We would have to change the type attribute
// on the concat/split node.
// TODO(rmlarsen): Handle Enter/Exit.
class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
 public:
  explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("", ctx, ctx_ext) {}

  ~HoistCWiseUnaryChainsStage() override = default;

  struct ChainLink {
    ChainLink() = default;
    ChainLink(NodeDef* _node, int _port_origin)
        : node(_node), port_origin(_port_origin) {}
    NodeDef* node;    // Node in a chain.
    int port_origin;  // Port on concat/split node from which this chain
                      // originates.

    bool operator<(const ChainLink& other) const {
      if (port_origin < other.port_origin) {
        return true;
      } else if (port_origin > other.port_origin) {
        return false;
      } else {
        return node->name() < other.node->name();
      }
    }
  };

  // We use an ordinary set sorted on port and node name, so the order, and
  // hence the node name used for the hoisted chain, will be deterministic.
  using ChainLinkSet = std::set<ChainLink>;

  bool IsSupported(const NodeDef* node) const override {
    if (IsInPreserveSet(*node)) return false;
    if (IsConcat(*node) && node->attr().count("N") != 0) {
      const int n = node->attr().at("N").i();
      return n > 1 && FirstNInputsAreUnique(*node, n);
    } else if ((IsSplit(*node) || IsSplitV(*node)) &&
               node->attr().count("num_split") != 0) {
      const int num_split = node->attr().at("num_split").i();
      if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
        // TODO(rmlarsen): Remove this constraint when we have optimizations
        // in place for merging slices into splits.
        return false;
      }
      if (NumControlOutputs(*node, *ctx().node_map) > 0) {
        // TODO(ezhulenev): Unary ops after Split might have a control path to
        // the Split node, and we currently do not properly handle cycles.
        return false;
      }
      return num_split > 1 && !IsAlreadyOptimized(*node);
    }
    return false;
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    node_is_concat_ = IsConcat(*node);
    int prefix_length;
    std::set<string> ctrl_inputs;
    ChainLinkSet tails;
    TF_RETURN_IF_ERROR(
        FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
    if (prefix_length > 0 && !tails.empty()) {
      TF_RETURN_IF_ERROR(
          HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
    }
    return Status::OK();
  }

 private:
  bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
    if (n > node.input_size()) return false;
    absl::flat_hash_set<string> unique_inputs;
    const int start = node.op() == "Concat" ? 1 : 0;
    const int end = start + n;
    for (int i = start; i < end; ++i) {
      unique_inputs.insert(node.input(i));
    }
    int unique_input_size = unique_inputs.size();
    return unique_input_size == n;
  }

  // Returns the length of the common unary chain of ops that can be
  // hoisted to the other side of concat or split.
  Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
                                ChainLinkSet* tails,
                                std::set<string>* ctrl_inputs) const {
    *prefix_length = 0;
    // Follow the chains starting at each concat input or split output as long
    // as all the following conditions hold:
    //   1. The ops in all chains are the same.
    //   2. The ops are unary elementwise ops.
    //   3. The op output has only a single consumer (concat only).
    ChainLinkSet cur_tails;
    TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
    if (cur_tails.size() < 2) {
      return Status::OK();
    }
    ctrl_inputs->clear();
    bool stop = false;
    while (!stop && !cur_tails.empty() &&
           OpsAreSafeToHoist(root_node, cur_tails)) {
      // We found one more link that can be hoisted.
      ++(*prefix_length);
      tails->swap(cur_tails);
      GatherControlInputs(ctrl_inputs, *tails);

      // Advance tail pointers to the next level.
      TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
    }
    return Status::OK();
  }

  // Hoists the chains to the other side of concat or split and attaches the
  // control inputs gathered from them to the concat or split node.
  Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                           std::set<string>* ctrl_inputs, NodeDef* root_node) {
    VLOG(3) << "Hoist unary op chain:"
            << " root=" << root_node->DebugString()
            << " prefix_length=" << prefix_length << " ctrl_inputs=["
            << absl::StrJoin(*ctrl_inputs, ", ") << "]";

    if (tails.empty()) {
      return Status::OK();
    }
    AddToOptimizationQueue(root_node);
    optimized_nodes_.insert(root_node->name());
    if (node_is_concat_) {
      AddControlInputs(ctrl_inputs, root_node);
      return HoistChainForConcat(prefix_length, tails, root_node);
    } else {
      return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
    }
  }

  void GatherControlInputs(std::set<string>* ctrl_inputs,
                           const ChainLinkSet& ops) const {
    for (const auto& link : ops) {
      const NodeDef* node = link.node;
      for (int i = node->input_size() - 1; i >= 0; --i) {
        const string& input = node->input(i);
        if (!IsControlInput(input)) break;
        ctrl_inputs->insert(input);
      }
    }
  }

  void AddControlInputs(std::set<string>* new_ctrl_inputs,
                        NodeDef* node) const {
    for (int i = node->input_size() - 1; i >= 0; --i) {
      const string& existing_input = node->input(i);
      if (!IsControlInput(existing_input)) break;
      new_ctrl_inputs->erase(existing_input);
    }
    for (const string& new_input : *new_ctrl_inputs) {
      ctx().node_map->AddOutput(NodeName(new_input), node->name());
      node->add_input(new_input);
    }
  }

  Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
    if (node_is_concat_) {
      // Handle concat nodes by looking backwards in the graph.
      TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
      const int n = node.attr().at("N").i();
      const int start = node.op() == "Concat" ? 1 : 0;
      const int end = start + n;
      if (end > node.input_size()) {
        return errors::FailedPrecondition("Got attr N=", n,
                                          " without enough inputs.");
      }
      // Set up tail pointers to point to the immediate inputs to Concat.
      for (int input_port = start; input_port < end; ++input_port) {
        if (IsControlInput(node.input(input_port))) {
          return errors::FailedPrecondition(
              "Got control input ", node.input(input_port),
              " where normal input was expected.");
        }
        NodeDef* tail;
        TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
        tails->insert(ChainLink(tail, input_port));
      }
      return Status::OK();
    } else {
      // Handle split nodes by looking forwards in the graph.
      const auto& outputs = ctx().node_map->GetOutputs(node.name());
      for (NodeDef* output : outputs) {
        if (output->input_size() == 0 || IsControlInput(output->input(0))) {
          continue;
        }
        TensorId tensor_id = ParseTensorName(output->input(0));
        if (tensor_id.node() == node.name()) {
          tails->insert(ChainLink(output, tensor_id.index()));
        } else {
          // This output node has a non-control input other than the split
          // node, abort.
          tails->clear();
          return Status::OK();
        }
      }
    }
    return Status::OK();
  }

  bool OpsAreSafeToHoist(const NodeDef& root_node,
                         const ChainLinkSet& ops) const {
    if (ops.empty()) return true;
    const NodeDef* op0 = ops.begin()->node;
    if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
    for (const auto& link : ops) {
      const NodeDef* op = link.node;
      if (op->device() != root_node.device() || op->op() != op0->op() ||
          IsInPreserveSet(*op)) {
        return false;
      }
      if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
        // TODO(rmlarsen): Allow outgoing control edges.
        return false;
      }
      // Do not hoist Relu if it can be fused with its predecessors. This is
      // important because remapping runs after the arithmetic optimizer.
      if (IsRelu(*op) || IsRelu6(*op)) {
        NodeDef* operand = nullptr;
        if (!GetInputNode(op->input(0), &operand).ok()) {
          return false;
        }
        if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
          return false;
        }
      }
    }
    return true;
  }

  Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
                      bool* stop) const {
    *stop = true;
    new_tails->clear();
    for (const auto& link : tails) {
      const NodeDef* tail = link.node;
      if (node_is_concat_) {
        if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
          return Status::OK();
        }
        NodeDef* new_tail;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
        // Remember original port.
        new_tails->insert(ChainLink(new_tail, link.port_origin));
      } else {
        for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
          const TensorId tensor = ParseTensorName(new_tail->input(0));
          if (tensor.node() != tail->name()) {
            return Status::OK();
          }
          // Skip control outputs.
          if (tensor.index() >= 0) {
            // Remember original port.
            new_tails->insert(ChainLink(new_tail, link.port_origin));
          }
        }
      }
    }
    *stop = false;
    return Status::OK();
  }

  Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
                             NodeDef* concat_node) {
    const string& concat_name = concat_node->name();
    const int first_input = concat_node->op() == "Concat" ? 1 : 0;
    for (const auto& link : tails) {
      NodeDef* tail = CHECK_NOTNULL(link.node);
      const int concat_port = link.port_origin;
      CHECK_GE(concat_port, 0);
      CHECK_LT(concat_port, concat_node->input_size());
      const string concat_input = concat_node->input(concat_port);
      // Hook the node following tail directly into the concat node.
      const string tail_input = tail->input(0);
      concat_node->set_input(concat_port, tail_input);
      ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);

      if (concat_port == first_input) {
        // Update the consumers of concat to consume the end of the chain
        // instead.
        TF_RETURN_IF_ERROR(UpdateConsumers(concat_node, concat_input));
        // Reuse nodes in the first chain to process output of concat.
        tail->set_input(0, concat_name);
        ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
      }
    }
    return Status::OK();
  }

  Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
                            std::set<string>* ctrl_inputs,
                            NodeDef* split_node) {
    // Create a new chain before the split node to process the input tensor.
    const string& split_name = split_node->name();
    auto root_scope_and_name = ParseNodeScopeAndName(split_name);

    // We use the first tail node in the set as a template to get the list of
    // ops to apply (starting from the end).
    NodeDef* cur_tail = tails.begin()->node;
    NodeDef* cur_copy = AddCopyNode(
        OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
    cur_copy->clear_input();

    // Update the split to take its input from the tail of the new chain.
    const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
    const string orig_input = split_node->input(value_slot);
    split_node->set_input(value_slot, cur_copy->name());
    ctx().node_map->UpdateInput(split_node->name(), orig_input,
                                cur_copy->name());
    TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));

    // Now walk backwards creating the rest of the chain.
    while (cur_tail != split_node) {
      NodeDef* new_copy = AddCopyNode(
          OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
      new_copy->clear_input();
      cur_copy->add_input(new_copy->name());
      ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
      cur_copy = new_copy;
      TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
    }
    // Connect the original input to the head of the new chain.
    cur_copy->add_input(orig_input);
    ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
                                 cur_copy->name());
    // Make sure all the control inputs are satisfied before running the first
    // node in the new chain.
    AddControlInputs(ctrl_inputs, cur_copy);

    // Connect all consumers of the tail nodes directly to the
    // output port of Split from which the chain started.
    for (const auto& link : tails) {
      TF_RETURN_IF_ERROR(UpdateConsumers(
          link.node, link.port_origin == 0
                         ? split_name
                         : strings::StrCat(split_name, ":", link.port_origin)));
    }
    return Status::OK();
  }

  bool IsAlreadyOptimized(const NodeDef& node) const {
    return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
  }

 private:
  bool node_is_concat_;
  std::unordered_set<string> optimized_nodes_;
};
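
// Note on the Concat-side bookkeeping above: in
//   Concat([Exp(x), Exp(y)])  =>  Exp(Concat([x, y]))
// the chain hanging off the first input port is reused to process the Concat
// output, while the chains on the remaining ports are bypassed and left to
// dead-code elimination.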
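// Collapse two stacked copies of an idempotent op: for any unary op f with
// f(f(x)) == f(x) (as classified by IsIdempotent), the outer node is bypassed
// in favor of the inner one, provided both run the same op on the same device.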
class RemoveIdempotentStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
  ~RemoveIdempotentStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return node->input_size() == 1 && IsIdempotent(*node) &&
           !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (input->op() == node->op() && input->device() == node->device()) {
      *simplified_node_name = node->input(0);
    }
    return Status::OK();
  }
};

// Performs the conversion:
//   Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
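// The Xdivy variant is handled as well:
//   Xdivy(x, Sqrt(y)) => MulNoNan(Rsqrt(y), x)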
class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
 public:
  explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
                                  const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
  ~SqrtDivToRsqrtMulStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    // Note: we do not rewrite div_no_nan(a, sqrt(b)) => mul_no_nan(a,
    // rsqrt(b)), because for b == 0 the former returns 0 while the latter
    // would compute a * Inf.
    return IsAnyDiv(*node) && !IsDivNoNan(*node) && !IsFloorDiv(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    // Optimize only if divisor is a Sqrt whose output is not being consumed
    // elsewhere.
    if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
        (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
      if (IsXdivy(*node)) {
        // xdivy(a, sqrt(b)) => mul_no_nan(rsqrt(b), a)
        node->set_op("MulNoNan");
        node->mutable_input()->SwapElements(0, 1);
      } else {
        // div(a, sqrt(b)) => mul(a, rsqrt(b))
        node->set_op("Mul");
      }
      y->set_op("Rsqrt");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    }
    return Status::OK();
  }
};

// Performs the following conversion for real types:
//   Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
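// The rewrite is skipped for complex inputs, where SquaredDifference computes
// conj(x - y) * (x - y) rather than (x - y)^2.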
class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
 public:
  explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
  ~FuseSquaredDiffStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsSquare(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
    // Optimize only if base is a Sub whose output is not being consumed
    // elsewhere.
    if (IsSub(*b) && !IsInPreserveSet(*b) &&
        (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
      // For complex, SquaredDiff computes conj(x-y)*(x-y), so this rewrite is
      // invalid.
      const DataType type = GetDataTypeFromAttr(*b, "T");
      if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128))
        return Status::OK();
      node->set_op("Identity");
      b->set_op("SquaredDifference");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(b);
    }
    return Status::OK();
  }
};

// Performs the conversion:
//   Log(Softmax(x)) => LogSoftmax(x)
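// Besides saving a node, the fused LogSoftmax is numerically more stable than
// taking the Log of a Softmax that may underflow to 0.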
class LogSoftmaxStage : public ArithmeticOptimizerStage {
 public:
  explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
  ~LogSoftmaxStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    // Optimize only if arg is a Softmax whose output is not being consumed
    // elsewhere.
    if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
        (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
      // Log(Softmax(x)) => LogSoftmax(Identity(x))
      node->set_op("LogSoftmax");
      x->set_op("Identity");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(x);
    }
    return Status::OK();
  }
};

// Bypass redundant reshape nodes:
//
//   Reshape                    Reshape  <-+
//      ^                                  |
//      |                                  |
//   Reshape       becomes      Reshape    |
//      ^                                  |
//      |                                  |
//   input                      input  ----+
//
// Additionally, Reshape and BroadcastTo nodes where the
// input and target shapes are equal are bypassed.
//
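// For example, Reshape_2(Relu(Reshape_1(x))) becomes Reshape_2(Relu(x)):
// only the outermost Reshape determines the final shape, so Reshape_1 is
// rewritten to a NoOp and bypassed (names illustrative).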
class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantReshapeOrBroadcastTo(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
                                 ctx_ext) {}
  ~RemoveRedundantReshapeOrBroadcastTo() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsReshape(*node) || IsBroadcastTo(*node);
  }

  // TODO(rmlarsen): Handle unary ops with multiple outputs.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. If the reshape is a no-op, forward its input to its consumers, unless
    // it anchors a control dependency since we want to make sure that control
    // dependency is triggered.
    if (!IsInPreserveSet(*node) && InputMatchesTargetShape(*node) &&
        !HasControlInputs(*node)) {
      *simplified_node_name = node->input(0);
      return Status::OK();
    }

    // 2. Bypass reshape followed by reshape, possibly separated by a simple
    // chain of unary elementwise ops that are not outputs.
    if (IsReshape(*node)) {
      bool skip = false;
      gtl::InlinedVector<const NodeDef*, 4> nodes_in_chain;
      const auto predicate_fn = [this, node, &skip,
                                 &nodes_in_chain](const NodeDef& input) {
        nodes_in_chain.push_back(&input);
        if ((input.name() != node->name() &&
             NumNonControlOutputs(input, *ctx().node_map) > 1) ||
            IsInPreserveSet(input) || ModifiesFrameInfo(input)) {
          skip = true;
          return false;
        }
        return IsUnaryElementWise(input);
      };

      // Walk up the input chain until we find a node that is not unary
      // element-wise. If it is another Reshape node, we can bypass it.
      NodeDef* tail =
          GetTailOfChain(*node, *ctx().node_map,
                         /*follow_control_input*/ false, predicate_fn);

      if (!skip && tail != nullptr && !IsInPreserveSet(*tail)) {
        NodeDef* reshape_to_bypass;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &reshape_to_bypass));
        if (reshape_to_bypass == nullptr ||
            (!IsReshape(*reshape_to_bypass) ||
             NumNonControlOutputs(*reshape_to_bypass, *ctx().node_map) > 1 ||
             IsInPreserveSet(*reshape_to_bypass))) {
          return Status::OK();
        }
        // Clear invalid shape inference results of the nodes in the chain.
        for (const NodeDef* node_in_chain : nodes_in_chain) {
          ctx().graph_properties->ClearInputProperties(node_in_chain->name());
          if (node_in_chain != node) {
            ctx().graph_properties->ClearOutputProperties(
                node_in_chain->name());
          }
        }
        // We now have
        //   reshape_to_bypass -> tail -> ... -> node
        // where tail may be equal to node.
        TF_RETURN_IF_ERROR(
            UpdateConsumers(reshape_to_bypass, reshape_to_bypass->input(0)));
        ForwardControlDependencies(tail, {reshape_to_bypass});
        // Change the bypassed reshape to NoOp.
        ReplaceWithNoOp(reshape_to_bypass, ctx());
        *simplified_node_name = node->name();
        return Status::OK();
      }
    }

    return Status::OK();
  }

 private:
  // Returns whether `reshape` is an identity op.
  bool InputMatchesTargetShape(const NodeDef& reshape) {
    const OpInfo::TensorProperties* reshape_props;
    const OpInfo::TensorProperties* input_props;
    if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
        !GetTensorProperties(reshape.input(0), &input_props).ok()) {
      return false;
    }

    return ShapesSymbolicallyEqual(input_props->shape(),
                                   reshape_props->shape());
  }
};

// Reorder casting and value-preserving ops if beneficial.
//
// Original motivation: A common pattern after the layout optimizer is
// casting a uint8 NHWC image to float before transposing it to NCHW. It
// is beneficial to reorder the cast and the transpose to make the transpose
// process a smaller amount of data. More generally, this optimization converts
//   Op(Cast(tensor, dst_type))
// to
//   Cast(Op(tensor), dst_type)
// when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
// Op, i.e. an op that only reorders the elements in its first input. Similarly,
// this optimization converts
//   Cast(Op(tensor), dst_type)
// to
//   Op(Cast(tensor, dst_type))
// when sizeof(tensor.type) > sizeof(dst_type)
//
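// For instance, with a 1000-element uint8 image (sizes illustrative),
//   Transpose(Cast(img, DT_FLOAT)) => Cast(Transpose(img), DT_FLOAT)
// moves 1KB through the transpose instead of 4KB.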
class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
 public:
  explicit ReorderCastLikeAndValuePreserving(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
                                 ctx_ext) {}
  ~ReorderCastLikeAndValuePreserving() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsValuePreserving(*node) || IsCastLike(*node)) &&
           !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
           !IsControlFlow(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
    NodeDef* producer;

    if (consumer->input_size() < 1) {
      return errors::FailedPrecondition("Node ", consumer->name(),
                                        " lacks inputs");
    }

    TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
    const bool producer_is_cast = IsCastLike(*producer);
    const bool can_optimize =
        !IsCheckNumerics(*producer) &&
        ((producer_is_cast && IsValuePreserving(*consumer)) ||
         (IsValuePreserving(*producer) && IsCastLike(*consumer)));
    if (!can_optimize || IsControlFlow(*producer) ||
        IsInPreserveSet(*producer) ||
        producer->device() != consumer->device()) {
      return Status::OK();
    }

    const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
    const OpDef* cast_like_op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
                                                         &cast_like_op_def));
    DataType cast_src_type;
    TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                        &cast_src_type));
    DataType cast_dst_type;
    TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                         &cast_dst_type));
    if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
      return Status::OK();
    } else if (producer_is_cast &&
               DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
      return Status::OK();
    } else if (!producer_is_cast &&
               DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
      return Status::OK();
    }

    // Check that nodes were not already optimized.
    const string optimized_producer_name = OptimizedNodeName(
        ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
    const string optimized_consumer_name = OptimizedNodeName(
        ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
    const bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_consumer_name) ||
        ctx().node_map->NodeExists(optimized_producer_name);
    if (is_already_optimized) {
      return Status::OK();
    }

    // Add copies of consumer and producer in reverse order.
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
    // Create new producer node.
    NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
    new_producer->set_input(0, producer->input(0));
    ctx().node_map->AddOutput(input->name(), new_producer->name());

    // Create new consumer node.
    NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
    new_consumer->set_input(0, new_producer->name());

    NodeDef* new_value_preserving =
        producer_is_cast ? new_producer : new_consumer;
    const DataType new_input_type =
        producer_is_cast ? cast_src_type : cast_dst_type;
    // Update the input type of the value-preserving node. The input and
    // output types of the cast-like nodes remain the same.
    TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
    // Make sure there is a kernel registered for the value preserving op
    // with the new input type.
    TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
    ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());

    AddToOptimizationQueue(new_producer);
    *simplified_node_name = new_consumer->name();

    return Status::OK();
  }

 private:
  // Sets the type of the first input to dtype.
  Status SetInputType(DataType dtype, NodeDef* node) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
    const OpDef::ArgDef& input_arg = op_def->input_arg(0);
    const string& type_attr_name = input_arg.type_attr();
    if (type_attr_name.empty()) {
      if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
        return errors::InvalidArgument("Could not set input type of ",
                                       node->op(), " op to ",
                                       DataTypeString(dtype));
      } else {
        // Op has fixed input type that already matches dtype.
        return Status::OK();
      }
    }
    SetDataTypeToAttr(dtype, type_attr_name, node);
    return Status::OK();
  }
  // This optimization can be dangerous on devices other than CPU and
  // GPU. The transpose might not be implemented for image.type, or
  // might be slower with image.type than with cast_dst_type.
  bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
    using absl::StrContains;

    string task;
    string device;

    return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
           (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
  }

  bool IsFixedSizeType(DataType dtype) {
    return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
           !kQuantizedTypes.Contains(dtype);
  }
};

// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorder data (such as Reshape and
// Transpose). For example, we can optimize
//
//
//           Conv2D                             Conv2D
//          /      \                           /      \
//   Transpose      weights*    ->      Transpose      Mul
//       |                                  |         /   \
//      Mul                                 |    weights   scale
//     /   \                                |
//   input  scale**                       input
//
// *) weights must be a const
// **) scale must be a const scalar
//
// When `weights` and `scale` are constant, `Mul` in the optimized graph can be
// constant-folded, also weights tend to be smaller than the activations.
//
// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
// Conv?DBackpropInput.
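// The rewrite is sound because convolution is linear in each argument: for a
// scalar s, Conv2D(s * x, w) == Conv2D(x, s * w), so the scalar migrates onto
// the constant weights where it can be folded away.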
class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
 public:
  explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
  ~FoldMultiplyIntoConv() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConv2D(*node) || IsConv3D(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
#define TF_RETURN_IF_TRUE(...) \
  if ((__VA_ARGS__)) return Status::OK()

    NodeDef* conv = node;

    NodeDef* weights;
    TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));

    // Fold the multiply to conv only when the weights are constant, so the
    // multiply can be constant-folded.
    //
    // TODO(jingyue): When the weights aren't constant, this should also help
    // performance a bit and memory usage a lot, since the weights tend to be
    // smaller than the activations.
    TF_RETURN_IF_TRUE(!IsConstant(*weights));

    // Verify that this node was not already optimized.
    const string scaled_weights_node_name =
        OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
                          strings::StrCat("scaled", "_", conv->name()));

    TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));

    // Find the tail of value preserving chain entering the Conv node.
    NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
                                                  *ctx().nodes_to_preserve);

    NodeDef* source;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));

    // Check that the value preserving chain is the only consumer of the Mul
    // output.
    TF_RETURN_IF_TRUE(!IsAnyMul(*source));
    TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
    // And that Mul is not in the preserve set.
    TF_RETURN_IF_TRUE(IsInPreserveSet(*source));

    const NodeDef* mul = source;
    int input_idx = 0;
    int scale_idx = 1;
    NodeDef* scale;  // scalar multiplier for the input tensor
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
    if (!IsConstant(*scale) && IsConstant(*input)) {
      VLOG(3) << "Swapped inputs to mul";
      std::swap(scale_idx, input_idx);
      std::swap(scale, input);
    }
    TF_RETURN_IF_TRUE(!IsConstant(*scale));

    // Check that one of the inputs to mul is a constant scalar.
    const TensorProto& scale_tensor = scale->attr().at("value").tensor();
    bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
                             scale_tensor.tensor_shape().dim_size() == 0;
    TF_RETURN_IF_TRUE(!scale_is_a_scalar);

    // Check that 'scale * weight' can be const folded.
    TF_RETURN_IF_TRUE(!IsConstant(*scale));
    TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
    TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
    TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
                      weights->attr().at("dtype").type());

    // At this point all preconditions are met, and we safely do the rewrite.
    VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
            << " mul=" << mul->name() << " weights=" << weights->name();

    // Create new node `scaled_weights`.
    NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
    scaled_weights->set_op(source->op());
    scaled_weights->set_device(weights->device());
    (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
    AddToOptimizationQueue(scaled_weights);

    // Link in its inputs.
    scaled_weights->add_input(conv->input(1));
    ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
    scaled_weights->add_input(mul->input(scale_idx));
    ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
    ForwardControlDependencies(scaled_weights, {source});

    // Update `conv`'s weights to `scaled_weights`.
    conv->set_input(1, scaled_weights->name());
    ctx().node_map->UpdateInput(conv->name(), weights->name(),
                                scaled_weights->name());
    AddToOptimizationQueue(conv);

    // Update `tail` node to bypass `mul` because it's folded to the weights.
    tail->set_input(0, mul->input(input_idx));
    ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
    AddToOptimizationQueue(tail);
    *simplified_node_name = conv->name();

    return Status::OK();
#undef TF_RETURN_IF_TRUE
  }
};

// Fold Transpose into matrix multiplication.
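// E.g. MatMul(Transpose(a), b) => MatMul(a, b) with transpose_a flipped, and
// BatchMatMul(ConjugateTranspose(x), y) => BatchMatMul(x, y) with adj_x
// flipped, so the transposed operand need not be materialized.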
class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
 public:
  explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
  ~FoldTransposeIntoMatMul() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    NodeDef* a;
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));

    bool is_complex = false;
    if (node->op() != "SparseMatMul") {
      const DataType type = GetDataTypeFromAttr(*node, "T");
      is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
    }

    const std::set<string> foldable_transpose_ops =
        !is_complex
            ? std::set<string>{"ConjugateTranspose", "Transpose"}
            : (IsAnyBatchMatMul(*node) ? std::set<string>{"ConjugateTranspose"}
                                       : std::set<string>{"Transpose"});

    const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*a, ctx().node_map);
    const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*b, ctx().node_map);
    if (!a_is_foldable && !b_is_foldable) return Status::OK();

    NodeDef* new_op = AddCopyNode(optimized_node_name, node);

    if (a_is_foldable) {
      const string attr_a = IsAnyBatchMatMul(*node) ? "adj_x" : "transpose_a";
      FlipBooleanAttr(attr_a, new_op);
      new_op->set_input(0, a->input(0));
      ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
    } else {
      ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
    }

    if (b_is_foldable) {
      const string attr_b = IsAnyBatchMatMul(*node) ? "adj_y" : "transpose_b";
      FlipBooleanAttr(attr_b, new_op);
      new_op->set_input(1, b->input(0));
      ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
    } else {
      ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
    }

    std::vector<const NodeDef*> deps_to_forward = {node};
    if (a_is_foldable) deps_to_forward.push_back(a);
    if (b_is_foldable) deps_to_forward.push_back(b);
    ForwardControlDependencies(new_op, deps_to_forward);
    *simplified_node_name = new_op->name();

    return Status::OK();
  }

 private:
  void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
    const bool old_value =
        !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
    (*node->mutable_attr())[attr_name].set_b(!old_value);
  }

  template <typename T>
  bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
    const T n = perm.size();
    if (n < 2) {
      return false;
    }
    for (T i = 0; i < n - 2; ++i) {
      if (perm[i] != i) {
        return false;
      }
    }
    return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
  }

  bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
                                  const NodeMap* node_map) {
    if (transpose_node.op() != "Transpose" &&
        transpose_node.op() != "ConjugateTranspose") {
      return false;
    }
    const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
    std::vector<int> perm32;
    if (ValuesFromConstNode(*perm_node, &perm32)) {
      return IsInnerMatrixTranspose(perm32);
    }
    std::vector<int64> perm64;
    if (ValuesFromConstNode(*perm_node, &perm64)) {
      return IsInnerMatrixTranspose(perm64);
    }
    return false;
  }
};
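// Fuse a Conj with an adjacent Transpose or ConjugateTranspose, in either
// order, e.g.:
//   Conj(Transpose(x))          => ConjugateTranspose(x)
//   Conj(ConjugateTranspose(x)) => Transpose(x)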
class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
 public:
  explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
  ~FoldConjugateIntoTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));

    const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
    const NodeDef* conj_op = node->op() == "Conj" ? node : input;

    if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
        IsConj(*conj_op)) {
      NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);

      // Flip the type of transpose op to absorb the conjugation.
      new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
                                                       : "Transpose");
      new_op->set_input(0, input->input(0));
      ctx().node_map->UpdateInput(new_op->name(), node->name(),
                                  input->input(0));
      ForwardControlDependencies(new_op, {node, input});
      *simplified_node_name = new_op->name();
    }

    return Status::OK();
  }
};

// Replace a Mul whose two inputs are identical with a Square.
class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
  ~ReplaceMulWithSquare() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!node || node->input_size() < 2) {
      // Invalid node
      return false;
    }

    return IsAnyMul(*node) && node->input(0) == node->input(1);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(mul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();

    const DataType type = GetDataTypeFromAttr(*node, "T");
    bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);

    if (!is_complex || NodeIsOnCpu(*node)) {
      NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
      new_square_node->set_op("Square");
      for (int i = 1; i < new_square_node->input_size(); ++i) {
        new_square_node->set_input(i - 1, new_square_node->input(i));
      }
      new_square_node->mutable_input()->RemoveLast();
      for (const string& input : new_square_node->input()) {
        ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
      }
      *simplified_node_name = new_square_node->name();
    }

    return Status::OK();
  }
};

// Replace a Mul that broadcasts its input against a tensor of all ones with a
// Tile. E.g. replace
//
//   input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output
//   Ones (1x22x2x48x2x64) -^
//
// with
//
//   input -> Tile(1x22x2x48x2x64) -> output
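//
// The Tile multiples are the elementwise ratio of the output and input
// shapes, here (1,22,2,48,2,64) / (1,22,1,48,1,64) = (1,1,2,1,2,1); this is
// also why input, ones, and output must all have the same rank.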
class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithBroadcastByTile(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx,
                                 ctx_ext) {}
  ~ReplaceMulWithBroadcastByTile() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsMul(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef *input, *ones;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
    if (IsInPreserveSet(*node) || IsInPreserveSet(*input) ||
        IsInPreserveSet(*ones)) {
      return Status::OK();
    }

    // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc
    if (IsConstant(*input) || !IsOnes(*ones)) return Status::OK();

    // Avoid optimizing the same node twice.
    const NodeScopeAndName scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile");
    const string const_node_name = OptimizedNodeName(scope_and_name, "Const");
    if (ctx().node_map->NodeExists(tile_node_name) ||
        ctx().node_map->NodeExists(const_node_name)) {
      return Status::OK();
    }

    const std::vector<OpInfo::TensorProperties>& props =
        ctx().graph_properties->GetInputProperties(node->name());
    if (props.size() != 2) return Status::OK();

    // Ignore ops where the shape doesn't change.
    const TensorShapeProto& input_shape = props[0].shape();
    const TensorShapeProto& ones_shape = props[1].shape();
    TensorShapeProto output_shape;
    if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) {
      return Status::OK();
    }
    if (ShapesSymbolicallyEqual(input_shape, output_shape)) {
      return Status::OK();
    }

    // All inputs must have the same input/output dimensions.
    if (input_shape.dim_size() != output_shape.dim_size() ||
        ones_shape.dim_size() != output_shape.dim_size())
      return Status::OK();

    // At this point all preconditions are met. Can proceed with rewrite.
    VLOG(3) << "Simplify multiply with all ones input: node=" << node->name()
            << "@" << output_shape << " ones=" << ones->name() << "@"
            << ones_shape << " input=" << input->name() << "@" << input_shape;

    // 1. Create a constant node with the correct tile multiples.
    Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()}));
    for (int i = 0; i < output_shape.dim_size(); ++i) {
      int64_t size = output_shape.dim(i).size() / input_shape.dim(i).size();
      if (TF_PREDICT_FALSE(size >= INT_MAX)) {
        return Status(error::OUT_OF_RANGE, "int32 overflow");
      }
      multiples.flat<int32>()(i) = static_cast<int32>(size);
    }

    NodeDef* const_node = AddEmptyNode(const_node_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        const_node->name(), TensorValue(&multiples), const_node));
    const_node->set_device(node->device());
    ForwardControlDependencies(const_node, {ones});
    AddToOptimizationQueue(const_node);

    // 2. Replace the multiply node with Tile(input, multiples).
    const DataType type = GetDataTypeFromAttr(*node, "T");
    NodeDef* tile_node = AddEmptyNode(tile_node_name);
    tile_node->set_op("Tile");
    tile_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", tile_node);
    SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node);
    tile_node->add_input(input->name());
    tile_node->add_input(const_node->name());

    ForwardControlDependencies(tile_node, {node});
    *simplified_node_name = tile_node->name();

    return Status::OK();
  }

 protected:
  bool IsOnes(const NodeDef& node) const {
    if (!IsReallyConstant(node)) return false;
    if (node.attr().at("dtype").type() != DT_FLOAT) return false;

    Tensor tensor;
    if (!tensor.FromProto(node.attr().at("value").tensor())) {
      return false;
    }

    auto values = tensor.flat<float>();
    for (int i = 0; i < tensor.NumElements(); ++i) {
      if (values(i) != 1.0f) {
        return false;
      }
    }

    return true;
  }
};

// Image upsampling often produces an unnecessary reshape that is difficult to
// eliminate in other stages. This stage reduces the number of dimensions
// involved, allowing the reshape to be removed.
//
// For example, given
//   B,W,H,C -> Reshape(B,W,1,H,1,C) -> Tile(1,1,2,1,2,1) -> Reshape(B,2W,2H,C)
// this pass converts the sequence to
//   B,W,H,C -> Reshape(B,W,H,C) -> Tile(1,1,2,2) -> Reshape(B,2W,2H,C)
//
// The first reshape is now redundant and can be removed in a later pass.
//
// Note: This only optimizes the simple (but extremely common) case of 2D
// upsampling.
//
// TODO(kkiningh): Generalize to more complex upsampling patterns.
2696 class ReduceUpsamplingDims : public ArithmeticOptimizerStage {
2697 public:
ReduceUpsamplingDims(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2698 explicit ReduceUpsamplingDims(const GraphOptimizerContext& ctx,
2699 const ArithmeticOptimizerContext& ctx_ext)
2700 : ArithmeticOptimizerStage("ReduceUpsamplingDims", ctx, ctx_ext) {}
2701 ~ReduceUpsamplingDims() override = default;
2702
IsSupported(const NodeDef * node) const2703 bool IsSupported(const NodeDef* node) const override {
2704 return IsReshape(*node) && !IsInPreserveSet(*node);
2705 }
2706
TrySimplify(NodeDef * node,string * simplified_node_name)2707 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2708 NodeDef* tile;
2709 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &tile));
2710 if (!IsTile(*tile) || IsInPreserveSet(*tile)) {
2711 return Status::OK();
2712 }
2713
2714 if (NumNonControlOutputs(*tile, *ctx().node_map) != 1) {
2715 // Optimization is only worthwile when there is a single output from Tile.
2716 // Otherwise, we need to insert addtional Reshape ops that can't be easily
2717 // removed.
2718 return Status::OK();
2719 }
2720
2721 NodeDef* reshape;
2722 TF_RETURN_IF_ERROR(GetInputNode(tile->input(0), &reshape));
2723 if (!IsReshape(*reshape) || IsInPreserveSet(*reshape)) {
2724 return Status::OK();
2725 }
2726
2727 NodeDef* multiples;
2728 TF_RETURN_IF_ERROR(GetInputNode(tile->input(1), &multiples));
2729
2730 NodeDef* shape;
2731 TF_RETURN_IF_ERROR(GetInputNode(reshape->input(1), &shape));
2732
2733 // Avoid optimizing the same nodes twice
2734 const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
2735 const string new_reshape_name =
2736 OptimizedNodeName(scope_and_name, "Reshape");
2737 const string new_tile_name = OptimizedNodeName(scope_and_name, "Tile");
2738 const string new_multiples_name =
2739 OptimizedNodeName(scope_and_name, "Multiples");
2740 const string new_shape_name = OptimizedNodeName(scope_and_name, "Shape");
2741 if (ctx().node_map->NodeExists(new_reshape_name) ||
2742 ctx().node_map->NodeExists(new_tile_name) ||
2743 ctx().node_map->NodeExists(new_shape_name) ||
2744 ctx().node_map->NodeExists(new_multiples_name)) {
2745 return Status::OK();
2746 }
2747
2748 // Compute updated multiples/shape values.
2749 AttrValue new_multiples_attr;
2750 if (!CreateUpdatedMultiplesProto(multiples,
2751 new_multiples_attr.mutable_tensor())) {
2752 return Status::OK();
2753 }
2754 AttrValue new_shape_attr;
2755 if (!CreateUpdatedShapeProto(shape, new_shape_attr.mutable_tensor())) {
2756 return Status::OK();
2757 }
2758
2759 // At this point the graph is validated and can be updated
2760 // Note: We can assume shape/multiples are DT_INT32 only at this point since
2761 // they're checked in CreateUpdated*Proto()
2762
2763 // 1. Create the constant nodes used by the new Reshape/Tile nodes
2764 NodeDef* new_multiples = AddEmptyNode(new_multiples_name);
2765 new_multiples->set_op("Const");
2766 SetDataTypeToAttr(DT_INT32, "dtype", new_multiples);
2767 new_multiples->mutable_attr()->insert({"value", new_multiples_attr});
2768 new_multiples->set_device(multiples->device());
2769
2770 NodeDef* new_shape = AddEmptyNode(new_shape_name);
2771 new_shape->set_op("Const");
2772 SetDataTypeToAttr(DT_INT32, "dtype", new_shape);
2773 new_shape->mutable_attr()->insert({"value", new_shape_attr});
2774 new_shape->set_device(shape->device());
2775
2776 // 2. Create the new Reshape/Tile nodes
2777 NodeDef* new_reshape = AddEmptyNode(new_reshape_name);
2778 CopyReshapeWithInput(reshape, new_reshape, /*input=*/reshape->input(0),
2779 /*shape=*/new_shape->name());
2780 NodeDef* new_tile = AddEmptyNode(new_tile_name);
2781 CopyTileWithInput(tile, new_tile, /*input=*/new_reshape->name(),
2782 /*multiples=*/new_multiples->name());
2783
2784 // 3. Update consumer of original Tile node and add control
2785 node->set_input(0, new_tile->name());
2786 ctx().node_map->UpdateInput(node->name(), tile->name(), new_tile->name());
2787
2788 ForwardControlDependencies(new_tile, {tile});
2789 ForwardControlDependencies(new_multiples, {multiples});
2790 ForwardControlDependencies(new_reshape, {reshape});
2791 ForwardControlDependencies(new_shape, {shape});
2792
2793 *simplified_node_name = node->name();
2794 return Status::OK();
2795 }
2796
2797 private:
2798   bool CreateUpdatedMultiplesProto(const NodeDef* node, TensorProto* proto) {
2799 Tensor multiples;
2800 if (!GetTensorFromConstNode(node->name(), &multiples)) {
2801 return false;
2802 }
2803
2804 // Dimensions should be [X, Y, N, 1, M, 1]
2805 if (multiples.dtype() != DT_INT32 || multiples.NumElements() != 6) {
2806 return false;
2807 }
2808
2809 const auto& multiples_values = multiples.flat<int32>();
2810 if (multiples_values(3) != 1 || multiples_values(5) != 1) {
2811 return false;
2812 }
2813
2814 // Convert to [X, Y, N, M]
2815 Tensor new_multiples(DT_INT32, {4});
2816 new_multiples.flat<int32>()(0) = multiples_values(0);
2817 new_multiples.flat<int32>()(1) = multiples_values(1);
2818 new_multiples.flat<int32>()(2) = multiples_values(2);
2819 new_multiples.flat<int32>()(3) = multiples_values(4);
2820
2821 new_multiples.AsProtoTensorContent(proto);
2822 return true;
2823 }
2824
2825   bool CreateUpdatedShapeProto(const NodeDef* node, TensorProto* proto) {
2826 Tensor shape;
2827 if (!GetTensorFromConstNode(node->name(), &shape)) {
2828 return false;
2829 }
2830
2831 // Dimensions should be [B, W, 1, H, 1, C]
2832 if (shape.dtype() != DT_INT32 || shape.NumElements() != 6) {
2833 return false;
2834 }
2835
2836 const auto& shape_values = shape.flat<int32>();
2837 if (shape_values(2) != 1 || shape_values(4) != 1) {
2838 return false;
2839 }
2840
2841 // Convert to [B, W, H, C]
2842 Tensor new_shape(DT_INT32, {4});
2843 new_shape.flat<int32>()(0) = shape_values(0);
2844 new_shape.flat<int32>()(1) = shape_values(1);
2845 new_shape.flat<int32>()(2) = shape_values(3);
2846 new_shape.flat<int32>()(3) = shape_values(5);
2847
2848 new_shape.AsProtoTensorContent(proto);
2849 return true;
2850 }
2851
2852   void CopyReshapeWithInput(const NodeDef* reshape, NodeDef* new_reshape,
2853 const string& input, const string& shape) {
2854 new_reshape->set_op("Reshape");
2855 new_reshape->set_device(reshape->device());
2856 SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "T"), "T", new_reshape);
2857 SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "Tshape"), "Tshape",
2858 new_reshape);
2859
2860 new_reshape->add_input(input);
2861 ctx().node_map->AddOutput(NodeName(input), new_reshape->name());
2862 new_reshape->add_input(shape);
2863 ctx().node_map->AddOutput(NodeName(shape), new_reshape->name());
2864
2865 AddToOptimizationQueue(new_reshape);
2866 }
2867
2868   void CopyTileWithInput(const NodeDef* tile, NodeDef* new_tile,
2869 const string& input, const string& multiples) {
2870 new_tile->set_op("Tile");
2871 new_tile->set_device(tile->device());
2872 SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "T"), "T", new_tile);
2873 SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "Tmultiples"), "Tmultiples",
2874 new_tile);
2875
2876 new_tile->add_input(input);
2877 ctx().node_map->AddOutput(NodeName(input), new_tile->name());
2878 new_tile->add_input(multiples);
2879 ctx().node_map->AddOutput(NodeName(multiples), new_tile->name());
2880
2881 AddToOptimizationQueue(new_tile);
2882 }
2883 };
2884
2885 // Replace a sequence of Pack ops that have identical inputs with Tile.
2886 // For example, given a Tensor X with shape (I,J,K)
2887 // Let P(x, n) = Pack([x, x], axis=n)
2888 //
2889 // P(P(X, 2), 1)
2890 // = Tile(Reshape(Tile(Reshape(x,
2891 // [I, J, 1, K]), [1, 1, 2, 1]),
2892 // [I, 1, J, 2, K]), [1, 2, 1, 1, 1]))
2893 // = Tile(Reshape(x,
2894 // [I, 1, J, 1, K]), [1, 2, 1, 2, 1])
2895 // = Reshape(Tile(x, [1, 2, 2]), [I, 2, J, 2, K])
2896 //
2897 // The outermost reshape is often redundant and can be removed in another pass
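//
// For instance (illustrative shapes): with X of shape (2, 3, 5),
// P(P(X, 2), 1) has shape (2, 2, 3, 2, 5); the rewrite computes
// Tile(X, [1, 2, 2]) of shape (2, 6, 10) and reshapes it to (2, 2, 3, 2, 5).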
2898 class ReplacePackWithTileReshape : public ArithmeticOptimizerStage {
2899 public:
2900   explicit ReplacePackWithTileReshape(const GraphOptimizerContext& ctx,
2901 const ArithmeticOptimizerContext& ctx_ext)
2902 : ArithmeticOptimizerStage("ReplacePackWithTileReshape", ctx, ctx_ext) {}
2903 ~ReplacePackWithTileReshape() override = default;
2904
2905   bool IsSupported(const NodeDef* node) const override {
2906 return IsPack(*node) && NumNonControlInputs(*node) > 1 &&
2907 !IsInPreserveSet(*node);
2908 }
2909
2910   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2911 // 1. traverse the chain of Pack ops to get the original input
2912 NodeDef* input = node;
2913 std::vector<const NodeDef*> chain;
2914 while (IsPack(*input) && NumNonControlInputs(*input) > 1 &&
2915 !IsInPreserveSet(*input)) {
2916 // Only pack operations with all identical inputs are supported
2917 if (!AllRegularInputsEqual(*input)) {
2918 break;
2919 }
2920 chain.push_back(input);
2921 TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &input));
2922 }
2923
2924 // Must be at least one Pack operation to consider for replacement
2925 if (chain.empty()) {
2926 return Status::OK();
2927 }
2928
2929 // Avoid optimizing the same node twice
2930 const NodeScopeAndName node_scope_and_name =
2931 ParseNodeScopeAndName(node->name());
2932 const string new_const_name =
2933 OptimizedNodeName(node_scope_and_name, "Multiples");
2934 const string new_tile_name = OptimizedNodeName(node_scope_and_name, "Tile");
2935 const string new_shape_name =
2936 OptimizedNodeName(node_scope_and_name, "Shape");
2937 const string new_reshape_name =
2938 OptimizedNodeName(node_scope_and_name, "Reshape");
2939 if (ctx().node_map->NodeExists(new_const_name) ||
2940 ctx().node_map->NodeExists(new_tile_name) ||
2941 ctx().node_map->NodeExists(new_shape_name) ||
2942 ctx().node_map->NodeExists(new_reshape_name)) {
2943 return Status::OK();
2944 }
2945
2946 // 2. Calculate the multiples and shape tensor using the chain
2947 const OpInfo::TensorProperties* input_props;
2948 TF_RETURN_IF_ERROR(GetTensorProperties(input->name(), &input_props));
2949 const TensorShapeProto& input_shape = input_props->shape();
2950 if (!PartialTensorShape(input_shape).IsFullyDefined()) {
2951 return Status::OK();
2952 }
2953 Tensor multiples(DT_INT32, TensorShape({input_shape.dim_size()}));
2954 TF_RETURN_IF_ERROR(CalculateMultiplesFromChain(chain, &multiples));
2955
2956 const OpInfo::TensorProperties* output_props;
2957 TF_RETURN_IF_ERROR(GetTensorProperties(node->name(), &output_props));
2958 const TensorShapeProto& output_shape = output_props->shape();
2959 if (!PartialTensorShape(output_shape).IsFullyDefined()) {
2960 return Status::OK();
2961 }
2962 Tensor output_shape_tensor(DT_INT32,
2963 TensorShape({output_shape.dim_size()}));
2964 for (int i = 0; i < output_shape.dim_size(); ++i) {
2965 output_shape_tensor.flat<int32>()(i) = output_shape.dim(i).size();
2966 }
2967
2968 // 3. Create constant node with correct multiples value
2969 NodeDef* new_const_node = AddEmptyNode(new_const_name);
2970 TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
2971 new_const_node->name(), TensorValue(&multiples), new_const_node));
2972 new_const_node->set_device(node->device());
2973 MaybeAddControlInput(input->name(), new_const_node, ctx().optimized_graph,
2974 ctx().node_map);
2975 AddToOptimizationQueue(new_const_node);
2976
2977 // 4. Replace the Pack node with Tile(Const(N), input);
2978 DataType dtype = GetDataTypeFromAttr(*node, "T");
2979 NodeDef* new_tile_node = AddEmptyNode(new_tile_name);
2980 new_tile_node->set_op("Tile");
2981 new_tile_node->set_device(node->device());
2982 SetDataTypeToAttr(dtype, "T", new_tile_node);
2983 SetDataTypeToAttr(DT_INT32, "Tmultiples", new_tile_node);
2984 new_tile_node->add_input(input->name());
2985 ctx().node_map->AddOutput(input->name(), new_tile_node->name());
2986 new_tile_node->add_input(new_const_node->name());
2987 ctx().node_map->AddOutput(new_const_node->name(), new_tile_node->name());
2988
2989 // Tile inherits all control dependencies from the original pack chain
2990 ForwardControlDependencies(new_tile_node, chain);
2991 AddToOptimizationQueue(new_tile_node);
2992
2993 // 5. Add a new Reshape node to preserve the existing shape
2994 NodeDef* new_shape_node = AddEmptyNode(new_shape_name);
2995 TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
2996 new_shape_node->name(), TensorValue(&output_shape_tensor),
2997 new_shape_node));
2998 new_shape_node->set_device(node->device());
2999 MaybeAddControlInput(input->name(), new_shape_node, ctx().optimized_graph,
3000 ctx().node_map);
3001 AddToOptimizationQueue(new_shape_node);
3002
3003 NodeDef* new_reshape_node = AddEmptyNode(new_reshape_name);
3004 new_reshape_node->set_op("Reshape");
3005 new_reshape_node->set_device(node->device());
3006 SetDataTypeToAttr(dtype, "T", new_reshape_node);
3007 SetDataTypeToAttr(DT_INT32, "Tshape", new_reshape_node);
3008 new_reshape_node->add_input(new_tile_node->name());
3009 ctx().node_map->AddOutput(new_tile_node->name(), new_reshape_node->name());
3010 new_reshape_node->add_input(new_shape_node->name());
3011 ctx().node_map->AddOutput(new_shape_node->name(), new_reshape_node->name());
3012
3013 *simplified_node_name = new_reshape_node->name();
3014
3015 return Status::OK();
3016 }
3017
3018 protected:
3019   Status CalculateMultiplesFromChain(const std::vector<const NodeDef*>& chain,
3020 Tensor* multiples) {
3021 // Keep track of how the multiples correspond to each shape dimension.
3022 // For example, given Stack([x, x], axis=1) with rank(x) = 3, we start with
3023 // multiples=[1, 1, 1], dims=[0, 1, 2]
3024 // After processing the stack op:
3025 // multiples=[1, 2, 1], dims=[0, 1, 1, 2]
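//
// For the class comment's example P(P(X, 2), 1) with rank(X) = 3, the chain
// is processed innermost Pack first (illustrative trace):
//   start:        multiples=[1, 1, 1], dims=[0, 1, 2]
//   Pack(axis=2): multiples=[1, 1, 2], dims=[0, 1, 2, 2]
//   Pack(axis=1): multiples=[1, 2, 2], dims=[0, 1, 1, 2, 2]
// which matches the Tile(x, [1, 2, 2]) in the class comment.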
3026 std::vector<int32> dims(multiples->NumElements());
3027 std::iota(dims.begin(), dims.end(), 0);
3028
3029 for (int i = 0; i < multiples->NumElements(); ++i) {
3030 multiples->flat<int32>()(i) = 1;
3031 }
3032
3033 for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
3034 AttrSlice attrs(**it);
3035 int64_t axis, n;
3036 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &axis));
3037 TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &n));
3038
3039 if (axis >= dims.size()) {
3040 // We don't handle the case where Pack is performed on the last axis,
3041 // e.g. Pack([x, x], axis=3) where rank(x) == 3
3042 return Status(error::OUT_OF_RANGE, "axis value out of range of dims");
3043 }
3044
3045 int64_t m = multiples->flat<int32>()(dims[axis]) * n;
3046 if (TF_PREDICT_FALSE(m > INT_MAX)) {
3047 return Status(error::OUT_OF_RANGE, "int32 overflow");
3048 }
3049 multiples->flat<int32>()(dims[axis]) = static_cast<int32>(m);
3050
3051 // Copy index from immediate right of inserted axis
3052 dims.insert(dims.begin() + axis, dims[axis]);
3053 }
3054
3055 return Status::OK();
3056 }
3057 };
3058
3059 // Simplify aggregation (e.g. AddN) nodes:
3060 //
3061 // 1. Discard aggregate nodes with a single input and no control dependencies.
3062 //
3063 // 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
3064 // deduping or other rewrites) so we can get rid of the sum entirely.
3065 //
3066 // The expression (using AddN as an example of an aggregate op):
3067 // AddN(x, x, x, ... ,x)
3068 // <-- N terms -->
3069 // can be rewritten to:
3070 // Mul(Const(N), x)
3071 //
3072 class SimplifyAggregation : public ArithmeticOptimizerStage {
3073 public:
3074   explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
3075 const ArithmeticOptimizerContext& ctx_ext)
3076 : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
3077 ~SimplifyAggregation() override = default;
3078
3079   bool IsSupported(const NodeDef* node) const override {
3080 return IsAggregate(*node) && HasRegularInputs(*node) &&
3081 GetDataTypeFromAttr(*node, "T") !=
3082 DT_VARIANT; // TODO(b/119787146): Enable for variants.
3083 }
3084
3085   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3086 // 1. Discard aggregate nodes with a single input and no control deps.
3087 if (node->input_size() == 1) {
3088 *simplified_node_name = node->input(0);
3089 return Status::OK();
3090 }
3091
3092 // 2. Rewrite aggregations of N >= 2 identical terms.
3093
3094 // All non-control inputs must be identical.
3095 bool all_equal = true;
3096 int num_inputs = 1;
3097 for (int i = 1; i < node->input_size(); ++i) {
3098 if (IsControlInput(node->input(i))) break;
3099 ++num_inputs;
3100 if (node->input(i) != node->input(0)) {
3101 all_equal = false;
3102 break;
3103 }
3104 }
3105 if (!all_equal) return Status::OK();
3106
3107 // And the node should not have been optimized earlier.
3108 const NodeScopeAndName node_scope_and_name =
3109 ParseNodeScopeAndName(node->name());
3110 const string optimized_const_name =
3111 OptimizedNodeName(node_scope_and_name, "Const");
3112 const string optimized_mul_name =
3113 OptimizedNodeName(node_scope_and_name, "Mul");
3114
3115 bool is_already_optimized =
3116 ctx().node_map->NodeExists(optimized_const_name) ||
3117 ctx().node_map->NodeExists(optimized_mul_name);
3118
3119 if (is_already_optimized) return Status::OK();
3120
3121 // At this point all preconditions are met, and we safely do the rewrite.
3122 VLOG(3) << "Simplify aggregation with identical inputs: node="
3123 << node->name() << " num_inputs=" << num_inputs;
3124
3125 // 1. Create constant node with value N.
3126 const auto type = GetDataTypeFromAttr(*node, "T");
3127 Tensor t(type, TensorShape({}));
3128 Status status = SetTensorValue(type, num_inputs, &t);
3129 if (!status.ok()) {
3130 return errors::Internal("Failed to create const node: ",
3131 status.error_message());
3132 }
3133
3134 TensorValue value(&t);
3135 NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
3136 status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
3137 new_const_node);
3138 if (!status.ok()) {
3139 return errors::Internal("Failed to create const node: ",
3140 status.error_message());
3141 }
3142 new_const_node->set_device(node->device());
3143 MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
3144 ctx().optimized_graph, ctx().node_map);
3145 AddToOptimizationQueue(new_const_node);
3146
3147 // 2. Replace the aggregate node with Mul(Const(N), x).
3148 NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
3149 new_mul_node->set_op("Mul");
3150 new_mul_node->set_device(node->device());
3151 SetDataTypeToAttr(type, "T", new_mul_node);
3152 new_mul_node->add_input(new_const_node->name());
3153 ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
3154 new_mul_node->add_input(node->input(0));
3155 ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
3156
3157 ForwardControlDependencies(new_mul_node, {node});
3158 *simplified_node_name = new_mul_node->name();
3159
3160 return Status::OK();
3161 }
3162 };
3163
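// Replaces Pow(x, c) with a cheaper equivalent when the exponent `c` is a
// constant whose elements all have the same value:
//   Pow(x, 2)    => Square(x)
//   Pow(x, 3)    => Mul(x, Square(x))  (CPU only)
//   Pow(x, 1)    => Identity(x)        (only when no broadcasting is involved)
//   Pow(x, 0.5)  => Sqrt(x)
//   Pow(x, 0)    => Const(1)           (only for fully defined, non-broadcast
//                                       output shapes)
//   Pow(x, -0.5) => Rsqrt(x)
//   Pow(x, -1)   => Reciprocal(x)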
3164 class ConvertPowStage : public ArithmeticOptimizerStage {
3165 public:
3166   explicit ConvertPowStage(const GraphOptimizerContext& ctx,
3167 const ArithmeticOptimizerContext& ctx_ext)
3168 : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
3169
3170   bool IsSupported(const NodeDef* node) const override {
3171 return IsPow(*node) &&
3172 ctx().graph_properties->HasOutputProperties(node->name()) &&
3173 ctx().graph_properties->HasInputProperties(node->name());
3174 }
3175
3176   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3177 Tensor pow;
3178 if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
3179 complex128 prev, curr;
3180 for (int i = 0; i < pow.NumElements(); ++i) {
3181 if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
3182 // input data type is not supported by Pow. Skip.
3183 return Status::OK();
3184 }
3185 if (i != 0 && curr != prev) {
3186 // pow has different values on different elements. Skip.
3187 return Status::OK();
3188 }
3189 prev = curr;
3190 }
3191 NodeDef *x, *y;
3192 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
3193 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
3194 const auto& value_props =
3195 ctx().graph_properties->GetInputProperties(node->name())[0];
3196 const TensorShapeProto& output_shape =
3197 ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
3198 if (curr == complex128(2, 0)) {
3199 node->set_op("Square");
3200 node->set_input(1, AsControlDependency(y->name()));
3201 AddToOptimizationQueue(node);
3202 AddToOptimizationQueue(y);
3203 } else if (curr == complex128(3, 0)) {
3204 // TODO(courbet): Use 'Cube' when it's added to TF ops.
3205 if (NodeIsOnCpu(*node)) {
3206 // We create an inner square node: inner_square = square(x)
3207 const NodeScopeAndName scope_and_name =
3208 ParseNodeScopeAndName(node->name());
3209 const string inner_square_name =
3210 OptimizedNodeName(scope_and_name, "_inner");
3211 NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
3212 if (inner_square_node == nullptr) {
3213 inner_square_node = AddCopyNode(inner_square_name, node);
3214 inner_square_node->set_op("Square");
3215 inner_square_node->mutable_input()->RemoveLast();
3216 }
3217 ctx().node_map->AddOutput(x->name(), inner_square_node->name());
3218 // We modify `node`: node = mul(x, inner_square);
3219 node->set_op("Mul");
3220 node->set_input(1, inner_square_node->name());
3221 node->add_input(AsControlDependency(y->name()));
3222
3223 AddToOptimizationQueue(node);
3224 AddToOptimizationQueue(inner_square_node);
3225 AddToOptimizationQueue(y);
3226 }
3227 } else if (curr == complex128(1, 0) &&
3228 ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
3229 // Pow could be used to broadcast, so make sure the shapes of the two
3230 // arguments are identical before replacing Pow with Identity.
3231 node->set_op("Identity");
3232 node->set_input(1, AsControlDependency(y->name()));
3233 AddToOptimizationQueue(node);
3234 AddToOptimizationQueue(y);
3235 } else if (curr == complex128(0.5, 0)) {
3236 node->set_op("Sqrt");
3237 node->set_input(1, AsControlDependency(y->name()));
3238 AddToOptimizationQueue(node);
3239 AddToOptimizationQueue(y);
3240 } else if (curr == complex128(0, 0) &&
3241 ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
3242 PartialTensorShape(output_shape).IsFullyDefined()) {
3243 const auto dtype = node->attr().at("T").type();
3244 Tensor ones(dtype, output_shape);
3245 for (int i = 0; i < ones.NumElements(); ++i) {
3246 TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
3247 }
3248 node->set_op("Const");
3249 (*node->mutable_attr())["dtype"].set_type(dtype);
3250 node->mutable_attr()->erase("T");
3251 ones.AsProtoTensorContent(
3252 (*node->mutable_attr())["value"].mutable_tensor());
3253 node->set_input(0, AsControlDependency(x->name()));
3254 node->set_input(1, AsControlDependency(y->name()));
3255 AddToOptimizationQueue(node);
3256 AddToOptimizationQueue(x);
3257 AddToOptimizationQueue(y);
3258 } else if (curr == complex128(-0.5, 0)) {
3259 node->set_op("Rsqrt");
3260 node->set_input(1, AsControlDependency(y->name()));
3261 AddToOptimizationQueue(node);
3262 AddToOptimizationQueue(y);
3263 } else if (curr == complex128(-1, 0)) {
3264 node->set_op("Reciprocal");
3265 node->set_input(1, AsControlDependency(y->name()));
3266 AddToOptimizationQueue(node);
3267 AddToOptimizationQueue(y);
3268 }
3269 return Status::OK();
3270 }
3271
3272 private:
3273   Status SetElementToOne(int i, Tensor* t) {
3274 switch (t->dtype()) {
3275 case DT_INT32:
3276 t->flat<int32>()(i) = 1;
3277 return Status::OK();
3278 case DT_INT64:
3279 t->flat<int64>()(i) = 1L;
3280 return Status::OK();
3281 case DT_FLOAT:
3282 t->flat<float>()(i) = 1.0f;
3283 return Status::OK();
3284 case DT_DOUBLE:
3285 t->flat<double>()(i) = 1.0;
3286 return Status::OK();
3287 case DT_COMPLEX64:
3288 t->flat<complex64>()(i) = complex64(1);
3289 return Status::OK();
3290 case DT_COMPLEX128:
3291 t->flat<complex128>()(i) = complex128(1);
3292 return Status::OK();
3293 default:
3294 return errors::InvalidArgument("Invalid data type: ", t->dtype());
3295 }
3296 }
3297 };
3298
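// Replaces Log(Add(x, 1)) or Log(Add(1, x)) with Log1p(x) when the constant
// argument is exactly 1 and adding it does not broadcast x to a larger shape.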
3299 class ConvertLog1pStage : public ArithmeticOptimizerStage {
3300 public:
3301   explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
3302 const ArithmeticOptimizerContext& ctx_ext)
3303 : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
3304 ~ConvertLog1pStage() override = default;
3305
3306   bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
3307
3308   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3309 NodeDef* input;
3310 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
3311 if (!IsAdd(*input)) {
3312 return Status::OK();
3313 }
3314
3315 if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
3316 return Status::OK();
3317 }
3318
3319 bool modified = false;
3320 TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
3321 if (!modified) {
3322 TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
3323 }
3324 if (modified) {
3325 *simplified_node_name = node->name();
3326 }
3327 return Status::OK();
3328 }
3329
3330 private:
3331   Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
3332 bool* modified) {
3333 const auto& t =
3334 ctx().graph_properties->GetInputProperties(add_node->name())[i];
3335 const auto& c =
3336 ctx().graph_properties->GetInputProperties(add_node->name())[j];
3337 for (int k = 0; k < c.shape().dim_size(); ++k) {
3338 // Skip if c shape is not fully determined.
3339 if (c.shape().dim(k).size() < 0) {
3340 return Status::OK();
3341 }
3342 }
3343 TensorShapeProto broadcast_shape;
3344 if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
3345 return Status::OK();
3346 }
3347 if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
3348 // skip if the non-constant tensor doesn't have the same shape after
3349 // broadcast.
3350 return Status::OK();
3351 }
3352 Tensor constant;
3353 if (GetTensorFromConstNode(add_node->input(j), &constant)) {
3354 complex128 element;
3355 // TODO(rmlarsen): Refactor the more general IsOnes from
3356 // constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
3357 // or (preferably) add a pass to canonicalize Sub(x, -1) to Add(x, 1),
3358 // and Neg(-1) to 1.
3359 for (int k = 0; k < constant.NumElements(); ++k) {
3360 if (!GetElementUnexhaustive(constant, k,
3361 {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
3362 DT_COMPLEX64, DT_COMPLEX128},
3363 &element)) {
3364 // input data type is not supported by log1p. Skip.
3365 return Status::OK();
3366 }
3367 if (element != complex128(1)) {
3368 // current element is not 1. Skip.
3369 return Status::OK();
3370 }
3371 }
3372 NodeDef *x, *y;
3373 TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
3374 TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
3375 node->set_op("Log1p");
3376 node->set_input(0, add_node->input(i));
3377 node->add_input(AsControlDependency(y->name()));
3378 ForwardControlDependencies(node, {add_node});
3379
3380 AddToOptimizationQueue(node);
3381 AddToOptimizationQueue(add_node);
3382 AddToOptimizationQueue(x);
3383 AddToOptimizationQueue(y);
3384 *modified = true;
3385 }
3386 return Status::OK();
3387 }
3388 };
3389
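// Replaces Sub(Exp(x), 1) with Expm1(x) when the constant argument is exactly
// 1 and the subtraction does not broadcast Exp(x) to a larger shape.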
3390 class ConvertExpm1Stage : public ArithmeticOptimizerStage {
3391 public:
3392   explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
3393 const ArithmeticOptimizerContext& ctx_ext)
3394 : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
3395 ~ConvertExpm1Stage() override = default;
3396
3397   bool IsSupported(const NodeDef* node) const override {
3398 if (!IsSub(*node)) return false;
3399
3400 NodeDef* input;
3401 if (!GetInputNode(node->input(0), &input).ok()) return false;
3402
3403 return IsExp(*input);
3404 }
3405
3406   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3407 if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
3408 return Status::OK();
3409 }
3410 const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
3411 const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
3412 TensorShapeProto broadcast_shape;
3413 if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
3414 return Status::OK();
3415 }
3416 if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
3417 // skip if the non-constant tensor doesn't have the same shape after
3418 // broadcast.
3419 return Status::OK();
3420 }
3421 Tensor constant;
3422 if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
3423 // TODO(rmlarsen): Use the more general IsOnes helper here.
3424 complex128 element;
3425 for (int k = 0; k < constant.NumElements(); ++k) {
3426 if (!GetElementUnexhaustive(constant, k,
3427 {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
3428 DT_COMPLEX64, DT_COMPLEX128},
3429 &element)) {
3430 // input data type is not supported by expm1. Skip.
3431 return Status::OK();
3432 }
3433 if (element != complex128(1)) {
3434 // current element is not 1. Skip.
3435 return Status::OK();
3436 }
3437 }
3438 NodeDef* exp;
3439 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
3440 NodeDef *exp_input, *ones;
3441 TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
3442 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
3443 node->set_op("Expm1");
3444 node->set_input(0, exp->input(0));
3445 node->set_input(1, AsControlDependency(ones->name()));
3446 ForwardControlDependencies(node, {exp});
3447
3448 AddToOptimizationQueue(node);
3449 AddToOptimizationQueue(exp);
3450 AddToOptimizationQueue(exp_input);
3451 AddToOptimizationQueue(ones);
3452 *simplified_node_name = node->name();
3453 return Status::OK();
3454 }
3455 };
3456
3457 // Performs conversions like:
3458 // Max(Sqrt(x)) => Sqrt(Max(x))
3459 // Checks for a max/min reduction over element-wise monotonic functions, such
3460 // as Sqrt, Sigmoid, Tanh, etc.
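// For non-increasing functions the reduction is flipped, e.g.
//   Max(Neg(x)) => Neg(Min(x))
// and for ArgMax/ArgMin the inner function can be dropped entirely, e.g.
//   ArgMax(Sqrt(x)) => ArgMax(x)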
3461 class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
3462 public:
3463   explicit OptimizeMaxOrMinOfMonotonicStage(
3464 const GraphOptimizerContext& ctx,
3465 const ArithmeticOptimizerContext& ctx_ext)
3466 : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
3467 ctx_ext) {}
3468 ~OptimizeMaxOrMinOfMonotonicStage() override = default;
3469
3470   bool IsSupported(const NodeDef* node) const override {
3471 return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
3472 IsArgMax(*node) || IsArgMin(*node);
3473 }
3474
3475   Status TrySimplify(NodeDef* reduction_node,
3476 string* simplified_node_name) override {
3477 if (IsInPreserveSet(*reduction_node)) {
3478 return Status::OK();
3479 }
3480
3481 NodeDef* inner_function;
3482 TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
3483
3484 NodeDef* inner_function_input = nullptr;
3485 if (inner_function->input_size() > 0) {
3486 TF_RETURN_IF_ERROR(
3487 GetInputNode(inner_function->input(0), &inner_function_input));
3488 }
3489
3490 // Optimize only if:
3491 // 0. inner_function is not in the preserve set,
3492 // 1. inner_function's Op is element-wise monotonic
3493 // 2. inner_function's output is not being consumed elsewhere.
3494 // 3. is monotonic increasing if reduction_node is a pooling operation
3495 // since we don't have MinPool operations.
3496 // 4. inner_function is not a Relu node with an input from FusedBatchNorm
3497 // or BiasAdd. This pattern will be fused later by remapper.
3498 auto can_be_fused_by_remapper = [](const NodeDef& consumer,
3499 const NodeDef& producer) -> bool {
3500 if (IsRelu(consumer) || IsRelu6(consumer)) {
3501 if (IsFusedBatchNorm(producer) || IsBiasAdd(producer)) {
3502 return true;
3503 }
3504 }
3505 return false;
3506 };
3507 bool is_non_decreasing = false;
3508 if (!IsInPreserveSet(*inner_function) &&
3509 IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
3510 ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
3511 (is_non_decreasing || !IsAnyMaxPool(*reduction_node)) &&
3512 !can_be_fused_by_remapper(*inner_function, *inner_function_input)) {
3513 // Swap the first inputs of the inner function Op & the reduction Op.
3514 NodeDef* inner_input;
3515 TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
3516 reduction_node->set_input(0, inner_input->name());
3517 ctx().node_map->UpdateInput(reduction_node->name(),
3518 inner_function->name(), inner_input->name());
3519 inner_function->set_input(0, reduction_node->name());
3520 TF_RETURN_IF_ERROR(
3521 UpdateConsumers(reduction_node, inner_function->name()));
3522 ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
3523 reduction_node->name());
3524 if (!is_non_decreasing) {
3525 // Flip Min<->Max if the function is non-increasing, e.g.
3526 // Max(Neg(x)) = Neg(Min(x)).
3527 const string opposite = FlipMinMax(*reduction_node);
3528 reduction_node->set_op(opposite);
3529 }
3530
3531 if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
3532 // ArgMax(Sqrt(x)) = ArgMax(x)
3533 inner_function->set_op("Identity");
3534 }
3535
3536 AddToOptimizationQueue(reduction_node);
3537 AddToOptimizationQueue(inner_function);
3538 AddToOptimizationQueue(inner_input);
3539 }
3540 return Status::OK();
3541 }
3542
3543 private:
3544   string FlipMinMax(const NodeDef& node) {
3545 const string& op = node.op();
3546 if (IsAnyMax(node) || IsArgMax(node)) {
3547 return str_util::StringReplace(op, "Max", "Min", false);
3548 } else {
3549 return str_util::StringReplace(op, "Min", "Max", false);
3550 }
3551 }
3552 };
3553
3554 // Replace a chain of type&shape preserving unary ops with a
3555 // '_UnaryOpsComposition' node.
3556 // TODO(ezhulenev): This should be part of the remapper optimizer, because it
3557 // doesn't have much to do with arithmetic (together with FoldMultiplyIntoConv stage?).
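// For example (illustrative), a CPU-resident chain y = Relu(Tanh(Sqrt(x)))
// whose intermediate results have no other consumers is collapsed into a
// single _UnaryOpsComposition node with op_names = ["Sqrt", "Tanh", "Relu"],
// applied in that order.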
3558 class UnaryOpsComposition : public ArithmeticOptimizerStage {
3559 public:
3560   explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
3561 const ArithmeticOptimizerContext& ctx_ext)
3562 : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
3563 // WARN: This should be consistent with unary_ops_composition.cc.
3564 // clang-format off
3565 supported_ops_ = {// Ops defined via Eigen scalar ops.
3566 {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3567 {"Acos", {DT_FLOAT, DT_DOUBLE}},
3568 {"Acosh", {DT_FLOAT, DT_DOUBLE}},
3569 {"Asin", {DT_FLOAT, DT_DOUBLE}},
3570 {"Asinh", {DT_FLOAT, DT_DOUBLE}},
3571 {"Atan", {DT_FLOAT, DT_DOUBLE}},
3572 {"Atanh", {DT_FLOAT, DT_DOUBLE}},
3573 {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3574 {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3575 {"Cosh", {DT_FLOAT, DT_DOUBLE}},
3576 {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3577 {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3578 {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3579 {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3580 {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3581 {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3582 {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3583 {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3584 {"Rint", {DT_FLOAT, DT_DOUBLE}},
3585 {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3586 {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3587 {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3588 {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3589 {"Sinh", {DT_FLOAT, DT_DOUBLE}},
3590 {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3591 {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3592 {"Tan", {DT_FLOAT, DT_DOUBLE}},
3593 {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3594 // Additional ops that are not part of Eigen.
3595 {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3596 {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3597 {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3598 {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
3599 // clang-format on
3600 }
3601 ~UnaryOpsComposition() override = default;
3602
3603   bool IsSupported(const NodeDef* node) const override {
3604 return CanOptimize(*node) &&
3605 // Check that this node was not already a root of a fused chain. If
3606 // graph optimization runs twice without pruning in between,
3607 // fused_nodes_ will not have this information.
3608 !ctx().node_map->NodeExists(OptimizedNodeName(*node));
3609 }
3610
3611   Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
3612 TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
3613 DataType dtype = root->attr().at("T").type();
3614
3615 // Keep a trace of all supported input nodes that can be fused together.
3616 std::vector<string> op_nodes = {root->name()};
3617 std::vector<string> op_names = {root->op()};
3618
3619 // Check if we should follow input(0) while building an op composition.
3620 const auto predicate_fn = [&](const NodeDef& input) {
3621 if (input.name() == root->name()) return true;
3622
3623 bool follow_input_node =
3624 dtype == GetDataTypeFromAttr(input, "T") &&
3625 NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
3626 CanOptimize(input);
3627
3628 if (follow_input_node) {
3629 op_nodes.push_back(input.name());
3630 op_names.push_back(input.op());
3631 }
3632
3633 return follow_input_node;
3634 };
3635
3636 NodeDef* last_op = GetTailOfChain(
3637 *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);
3638
3639 // We were not able to find a chain that can be replaced.
3640 if (op_names.size() == 1) return Status::OK();
3641
3642 // Do not add fused nodes to any other chain.
3643 std::for_each(op_nodes.begin(), op_nodes.end(),
3644 [this](const string& name) { AddToFusedNodes(name); });
3645
3646 // Reverse the trace to get correct composition computation order.
3647 std::reverse(op_names.begin(), op_names.end());
3648
3649 VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
3650 << absl::StrJoin(op_names, ", ") << "]";
3651
3652 NodeDef* composition_node = ctx().optimized_graph->add_node();
3653 composition_node->set_name(OptimizedNodeName(*root));
3654 composition_node->set_op("_UnaryOpsComposition");
3655 composition_node->add_input(last_op->input(0));
3656 composition_node->set_device(root->device());
3657
3658 auto attr = composition_node->mutable_attr();
3659 SetAttrValue(dtype, &(*attr)["T"]);
3660 SetAttrValue(op_names, &(*attr)["op_names"]);
3661
3662 ctx().node_map->AddNode(composition_node->name(), composition_node);
3663 ctx().node_map->AddOutput(NodeName(last_op->input(0)),
3664 composition_node->name());
3665
3666 *simplified_node_name = composition_node->name();
3667
3668 return Status::OK();
3669 }
3670
3671 private:
3672   bool CanOptimize(const NodeDef& node) const {
3673 DataType dtype = GetDataTypeFromAttr(node, "T");
3674 if (!IsSupported(node.op(), dtype)) {
3675 return false;
3676 }
3677 if (IsInPreserveSet(node)) {
3678 return false;
3679 }
3680 if (!NodeIsOnCpu(node)) {
3681 return false;
3682 }
3683 if (NodeIsAlreadyFused(node)) {
3684 return false;
3685 }
3686 return !(IsDrivenByControlDependency(node) ||
3687 DrivesControlDependency(node));
3688 }
3689
3690   bool NodeIsAlreadyFused(const NodeDef& node) const {
3691 return fused_nodes_.count(node.name()) > 0;
3692 }
3693
3694   string OptimizedNodeName(const NodeDef& node) const {
3695 return strings::StrCat(node.name(), "/unary_ops_composition");
3696 }
3697
3698   void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }
3699
3700 // Check if an op is supported by the _UnaryOpsComposition for the given type.
3701   bool IsSupported(const string& op_name, DataType dtype) const {
3702 const auto it = supported_ops_.find(op_name);
3703 return it != supported_ops_.end() && it->second.count(dtype) > 0;
3704 }
3705
3706 std::unordered_map<string, std::set<DataType>> supported_ops_;
3707 std::unordered_set<string> fused_nodes_;
3708 };
3709
3710 // Replace operations of the form:
3711 // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
3712 // with
3713 // a_i
3714 // when the strided slice index `i` is applied in the k'th axis.
3715 //
3716 // Similarly, replace operations of the form:
3717 // x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
3718 // with
3719 // expand_dims(a_i, axis=k)
3720 // where the slice operator can be StridedSlice or Slice.
3721 //
3722 // TODO(ebrevdo): Extend to also replace operations of the form
3723 // concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
3724 // with
3725 // a_i,
3726 // when
3727 // s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
3728 // and slicing is in the k'th axis.
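//
// For example (illustrative values), with a = [1, 2] and b = [3, 4]:
//   stack((a, b), axis=0)[1, :]   => b
//   stack((a, b), axis=0)[1:2, :] => expand_dims(b, axis=0)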
3729 class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
3730 public:
3731   explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
3732 const ArithmeticOptimizerContext& ctx_ext)
3733 : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
3734 ctx_ext) {}
3735 ~RemoveStackSliceSameAxis() override = default;
3736
3737   bool IsSupported(const NodeDef* node) const override {
3738 return (IsStridedSlice(*node) || IsSlice(*node)) && !IsInPreserveSet(*node);
3739 }
3740
3741   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3742 // *node is a StridedSlice or Slice NodeDef.
3743 NodeDef* pack;
3744
3745 // Get the input and see if it's a Pack op.
3746 TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
3747 if (!IsPack(*pack)) return Status::OK();
3748
3749 bool return_early;
3750 PartialTensorShape pack_output_shape;
3751 int pack_axis;
3752 TF_RETURN_IF_ERROR(
3753 CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
3754 if (return_early) return Status::OK();
3755
3756 int64_t slice_start_value;
3757 bool found;
3758 bool must_expand_dims;
3759 TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
3760 &slice_start_value, &found,
3761 &must_expand_dims));
3762 if (!found) return Status::OK();
3763
3764 return RewriteGraph(node, pack, slice_start_value, pack_axis,
3765 must_expand_dims, simplified_node_name);
3766 }
3767
3768 protected:
3769   Status CheckInputs(const NodeDef* node, const NodeDef* pack,
3770 PartialTensorShape* pack_output_shape, int* pack_axis,
3771 bool* return_early) {
3772 *return_early = true;
3773 TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
3774
3775 *pack_axis = pack->attr().at("axis").i();
3776 auto slice_properties =
3777 ctx().graph_properties->GetInputProperties(node->name());
3778 if (slice_properties.empty() ||
3779 slice_properties[0].shape().unknown_rank()) {
3780 return Status::OK();
3781 }
3782 *pack_output_shape = slice_properties[0].shape();
3783 const int pack_output_rank = pack_output_shape->dims();
3784 if (*pack_axis < 0) {
3785 *pack_axis += pack_output_rank;
3786 }
3787 if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
3788 return errors::InvalidArgument(
3789 "Pack node (", pack->name(),
3790 ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
3791 }
3792 *return_early = false;
3793 return Status::OK();
3794 }
3795
3796   Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
3797 const PartialTensorShape& pack_output_shape,
3798 int pack_axis, int64* slice_start_value, bool* found,
3799 bool* must_expand_dims) {
3800 *found = false;
3801 if (IsSlice(*node)) {
3802 *must_expand_dims = true;
3803 return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
3804 slice_start_value, found);
3805 } else {
3806 return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
3807 slice_start_value, found, must_expand_dims);
3808 }
3809 }
3810
3811   Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
3812 const PartialTensorShape& pack_output_shape,
3813 int pack_axis, int64* slice_start_value,
3814 bool* found) {
3815 NodeDef* slice_begin;
3816 NodeDef* slice_size;
3817 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3818 TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
3819 for (const auto* n : {slice_begin, slice_size}) {
3820 if (!IsReallyConstant(*n)) return Status::OK();
3821 }
3822
3823 Tensor slice_begin_t;
3824 Tensor slice_size_t;
3825 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3826 if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3827 return Status::OK();
3828 }
3829 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
3830 if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
3831 return Status::OK();
3832 }
3833
3834 auto copy_tensor_values_to_vector =
3835 [node](const Tensor& t, gtl::InlinedVector<int64, 4>* vec) {
3836 if (t.dtype() == DT_INT32) {
3837 auto t_flat = t.flat<int32>();
3838 vec->assign(&t_flat(0), &t_flat(t.NumElements()));
3839 } else if (t.dtype() == DT_INT64) {
3840 auto t_flat = t.flat<int64>();
3841 vec->assign(&t_flat(0), &t_flat(t.NumElements()));
3842 } else {
3843 return errors::InvalidArgument("Node ", node->name(),
3844 " has invalid type for Index attr: ",
3845 DataTypeString(t.dtype()));
3846 }
3847 return Status::OK();
3848 };
3849
3850 gtl::InlinedVector<int64, 4> slice_begin_vec;
3851 gtl::InlinedVector<int64, 4> slice_size_vec;
3852 TF_RETURN_IF_ERROR(
3853 copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
3854 TF_RETURN_IF_ERROR(
3855 copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));
3856
3857 if (slice_begin_vec.size() != slice_size_vec.size()) {
3858 return errors::InvalidArgument("Node ", node->name(),
3859 " has mismatched lengths for begin (",
3860 slice_begin_vec.size(), ") and size (",
3861 slice_size_vec.size(), ") vectors.");
3862 }
3863 int slice_begin_vec_size = slice_begin_vec.size();
3864 if (!pack_output_shape.unknown_rank() &&
3865 slice_begin_vec_size != pack_output_shape.dims()) {
3866 return Status::OK();
3867 }
3868 if (pack_axis >= slice_begin_vec_size) {
3869 return errors::InvalidArgument(
3870 "Input to node ", node->name(), " had pack_axis ", pack_axis,
3871 " but rank was ", slice_begin_vec_size, ".");
3872 }
3873
3874 *slice_start_value = slice_begin_vec[pack_axis];
3875 if (slice_size_vec[pack_axis] != 1) {
3876 // Not slicing a single value out.
3877 return Status::OK();
3878 }
3879
3880 for (int i = 0; i < slice_begin_vec_size; ++i) {
3881 if (i != pack_axis) {
3882 if (slice_begin_vec[i] != 0 ||
3883 !(slice_size_vec[i] == -1 ||
3884 slice_size_vec[i] == pack_output_shape.dim_size(i))) {
3885 // Not slicing on the same axis as the Pack op.
3886 return Status::OK();
3887 }
3888 }
3889 }
3890
3891 if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
3892 return errors::InvalidArgument(
3893 "Node ", node->name(), " requested invalid slice index ",
3894 *slice_start_value, " on axis ", pack_axis,
3895 " from tensor of shape: ", pack_output_shape.DebugString());
3896 }
3897
3898 *found = true; // slice_start_value is valid.
3899 return Status::OK();
3900 }
3901
3902   Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
3903 const PartialTensorShape& pack_output_shape,
3904 int pack_axis, int64* slice_start_value,
3905 bool* found, bool* must_expand_dims) {
3906 TF_RETURN_IF_ERROR(
3907 CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
3908 "new_axis_mask", "shrink_axis_mask"}));
3909
3910 const int begin_mask = node->attr().at("begin_mask").i();
3911 const int end_mask = node->attr().at("end_mask").i();
3912 const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
3913 const int new_axis_mask = node->attr().at("new_axis_mask").i();
3914 const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
3915
3916 // Check that the StridedSlice is one of these at pack_axis:
3917 // [..., i, ...]
3918 // [..., i:i+1, ...]
3919 // [..., :1, ...]
3920 // [..., -1:, ...]
3921 // [..., s_{pack_axis}-1:, ...]
3922 NodeDef* slice_begin;
3923 NodeDef* slice_end;
3924 NodeDef* slice_strides;
3925 TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3926 TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
3927 TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
3928
3929 for (const auto* n : {slice_begin, slice_end, slice_strides}) {
3930 if (!IsReallyConstant(*n)) return Status::OK();
3931 }
3932
3933 Tensor slice_begin_t;
3934 Tensor slice_end_t;
3935 Tensor slice_strides_t;
3936
3937 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3938 if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3939 return Status::OK();
3940 }
3941 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
3942 if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
3943 return Status::OK();
3944 }
3945 TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
3946 if (!slice_strides_t.FromProto(
3947 slice_strides->attr().at("value").tensor())) {
3948 return Status::OK();
3949 }
3950 TensorShape processing_shape;
3951 TensorShape final_shape;
3952 bool is_identity;
3953 bool is_simple_slice;
3954 bool slice_dim0;
3955 gtl::InlinedVector<int64, 4> slice_begin_vec;
3956 gtl::InlinedVector<int64, 4> slice_end_vec;
3957 gtl::InlinedVector<int64, 4> slice_strides_vec;
3958 TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
3959 &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
3960 begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
3961 &processing_shape, &final_shape, &is_identity, &is_simple_slice,
3962 &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
3963
3964 if (!is_simple_slice) return Status::OK();
3965
3966 int begin_index = -1;
3967 int64_t begin_value = 0;
3968 for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
3969 const int64_t v = slice_begin_vec[i];
3970 if (v != 0) {
3971 if (begin_index != -1) {
3972 // At least two start values that are nonzero.
3973 return Status::OK();
3974 }
3975 begin_index = i;
3976 begin_value = v;
3977 }
3978 }
3979
3980 int end_index = -1;
3981 int64_t end_value = 0;
3982 for (int i = 0, end = slice_end_vec.size(); i < end; ++i) {
3983 const int64_t v = slice_end_vec[i];
3984 if (v != pack_output_shape.dim_size(i)) {
3985 if (end_index != -1) {
3986 // At least two end values that don't match the dimension size.
3987 return Status::OK();
3988 }
3989 end_index = i;
3990 end_value = v;
3991 }
3992 }
3993
3994 if (begin_index == -1 && end_index == -1) return Status::OK();
3995 if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
3996 // Somehow received different axes for begin/end slicing
3997 return Status::OK();
3998 }
3999 const int slice_axis = (begin_index == -1) ? end_index : begin_index;
4000 if (slice_axis != pack_axis) {
4001 // Not slicing on the same axis as the Pack op.
4002 return Status::OK();
4003 }
4004 *slice_start_value = (begin_index == -1) ? 0 : begin_value;
4005 const int64_t slice_end_value =
4006 (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
4007 if (slice_end_value != *slice_start_value + 1) {
4008 // Not slicing a single value out.
4009 return Status::OK();
4010 }
4011
4012 if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
4013 return errors::InvalidArgument(
4014 "Node ", node->name(), " requested invalid slice index ",
4015 *slice_start_value, " on axis ", slice_axis,
4016 " from tensor of shape: ", pack_output_shape.DebugString());
4017 }
4018
4019 if (shrink_axis_mask == 0) {
4020 *must_expand_dims = true;
4021 } else if (shrink_axis_mask == (1 << slice_axis)) {
4022 *must_expand_dims = false;
4023 } else {
4024 // Shrinking on a different axis from the one that we are slicing on.
4025 return Status::OK();
4026 }
4027
4028 *found = true; // slice_start_value is valid.
4029 return Status::OK();
4030 }
4031
4032   Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
4033 int64_t slice_start_value, int pack_axis,
4034 bool must_expand_dims, string* simplified_node_name) {
4035 const string& input_slice = pack->input(slice_start_value);
4036
4037 const OpInfo::TensorProperties* input_slice_properties;
4038 TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
4039 &input_slice_properties));
4040 PartialTensorShape input_slice_shape(input_slice_properties->shape());
4041
4042 const OpInfo::TensorProperties* output_properties;
4043 TF_RETURN_IF_ERROR(GetTensorProperties(
4044 strings::StrCat(node->name(), ":", 0), &output_properties));
4045 PartialTensorShape output_shape(output_properties->shape());
4046 NodeDef* output =
4047 AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
4048 if (!must_expand_dims) {
4049 output->set_op("Identity");
4050 output->set_device(node->device());
4051 SetDataTypeToAttr(output_properties->dtype(), "T", output);
4052 output->add_input(input_slice);
4053 } else {
4054 NodeDef* axis = AddEmptyNode(
4055 OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
4056 axis->set_op("Const");
4057 axis->set_device(node->device());
4058 // We need to add a control edge from the input slice to guarantee that the
4059 // axis constant will be executed in the same frame as `input_slice`;
4060 // otherwise, ExpandDims might have mismatched input frames.
4061 axis->add_input(absl::StrCat("^", ParseTensorName(input_slice).node()));
4062 auto axis_attr = axis->mutable_attr();
4063 SetDataTypeToAttr(DT_INT32, "dtype", axis);
4064 auto* axis_t = (*axis_attr)["value"].mutable_tensor();
4065 axis_t->set_dtype(DT_INT32);
4066 axis_t->add_int_val(pack_axis);
4067 AddToOptimizationQueue(axis);
4068 output->set_op("ExpandDims");
4069 output->set_device(node->device());
4070 SetDataTypeToAttr(output_properties->dtype(), "T", output);
4071 SetDataTypeToAttr(DT_INT32, "Tdim", output);
4072 output->add_input(input_slice);
4073 output->add_input(axis->name());
4074 }
4075
4076 // Copy dependencies over.
4077 ForwardControlDependencies(output, {node, pack});
4078 AddToOptimizationQueue(output);
4079 *simplified_node_name = output->name();
4080
4081 return Status::OK();
4082 }
4083 };
4084
4085 // Eliminates unnecessary copies during sparse embedding lookup operations.
4086 //
4087 // For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
4088 // generates code of the form:
4089 //
4090 // embeddings = <a 2D Tensor>
4091 // sparse_ids = <a tf.int64 SparseTensor>
4092 // segment_ids = sparse_ids.indices[:, 0]
4093 // ids, idx = tf.unique(sparse_ids.values)
4094 // gathered_rows = tf.gather(params, ids)
4095 // result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
4096 //
4097 // In this case, all of the work in `tf.unique()` and `tf.gather()`
4098 // can be avoided by passing the full embeddings to
4099 // `tf.sparse.segment_<combiner>()` and performing the same amount of
4100 // computation (but fewer copies and allocations) as follows:
4101 //
4102 // embeddings = <a 2D Tensor>
4103 // sparse_ids = <a tf.int64 SparseTensor>
4104 // segment_ids = sparse_ids.indices[:, 0]
4105 // result = tf.sparse.segment_<combiner>(
4106 // embeddings, sparse_ids.values, segment_ids)
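//
// When the embeddings are read from a resource variable (ResourceGather), a
// ReadVariableOp is inserted so that the sparse segment reduction can consume
// the dense tensor directly.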
4107 class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
4108 public:
4109   explicit SimplifyEmbeddingLookupStage(
4110 const GraphOptimizerContext& ctx,
4111 const ArithmeticOptimizerContext& ctx_ext)
4112 : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
4113 }
4114 ~SimplifyEmbeddingLookupStage() override = default;
4115
IsSupported(const NodeDef * node) const4116 bool IsSupported(const NodeDef* node) const override {
4117 return IsAnySparseSegmentReduction(*node);
4118 }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return Status::OK();

    // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
    // axis.
    NodeDef* gather_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
    if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
        gather_node->device() != reduction_node->device())
      return Status::OK();
    if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
      return Status::OK();

    // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
    // axis.
    NodeDef* unique_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
    if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
        unique_node->device() != gather_node->device())
      return Status::OK();
    if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
      return Status::OK();

    DataType unique_element_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));

    // Input 1 (indices) of the reduction node must be output 1 of the unique
    // node.
    const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
    if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK();
    // Replace input 1 (indices) of the reduction node with input 0 (x) of the
    // unique node, i.e. the original ids before deduplication.
    const string original_indices_input = reduction_node->input(1);
    reduction_node->set_input(1, unique_node->input(0));
    ctx().node_map->UpdateInput(reduction_node->name(), original_indices_input,
                                unique_node->input(0));
    SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);

    // Replace input 0 (data) of the reduction node with input 0 (params) of
    // the gather node.
    const OpInfo::TensorProperties* gather_input_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(gather_node->input(0), &gather_input_properties));
    if (gather_input_properties->dtype() == DT_RESOURCE) {
      // If the input is a ResourceGather, we need to add
      // ReadVariableOp.
      NodeDef* variable_node = nullptr;
      TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(0), &variable_node));
      NodeDef* read_var_node = ctx().optimized_graph->add_node();
      read_var_node->set_name(OptimizedNodeName(
          ParseNodeScopeAndName(reduction_node->name()), "ReadVar"));
      read_var_node->set_op("ReadVariableOp");
      read_var_node->add_input(gather_node->input(0));
      read_var_node->set_device(variable_node->device());

      // The Variable and the Gather node should have the same
      // dtype, but it might not be set on both nodes.
      auto attr = read_var_node->mutable_attr();
      if (variable_node->attr().count("dtype")) {
        SetAttrValue(variable_node->attr().at("dtype").type(),
                     &(*attr)["dtype"]);
      }
      if (gather_node->attr().count("dtype")) {
        SetAttrValue(gather_node->attr().at("dtype").type(), &(*attr)["dtype"]);
      }
      // Copy the _class attr from the Gather node, if it exists, in case of
      // colocation constraints with the variable.
      if (gather_node->attr().count("_class")) {
        (*attr)["_class"] = gather_node->attr().at("_class");
      }
      if (variable_node->attr().count("shape")) {
        SetAttrValue(variable_node->attr().at("shape").shape(),
                     &(*attr)["_output_shapes"]);
      }

      ctx().node_map->AddNode(read_var_node->name(), read_var_node);
      const string original_data_input = reduction_node->input(0);
      reduction_node->set_input(0, read_var_node->name());
      ctx().node_map->UpdateInput(reduction_node->name(), original_data_input,
                                  read_var_node->name());
    } else {
      const string original_data_input = reduction_node->input(0);
      reduction_node->set_input(0, gather_node->input(0));
      ctx().node_map->UpdateInput(reduction_node->name(), original_data_input,
                                  gather_node->input(0));
    }
    *simplified_node_name = reduction_node->name();
    return Status::OK();
  }

 private:
  bool IsAxis0(const NodeDef& node, int axis_input) {
    Tensor axis_tensor;
    if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
      return false;
    if (axis_tensor.NumElements() != 1) return false;
    if (axis_tensor.dtype() == DT_INT32) {
      return axis_tensor.flat<int32>()(0) == 0;
    } else if (axis_tensor.dtype() == DT_INT64) {
      return axis_tensor.flat<int64>()(0) == 0;
    } else {
      return false;
    }
  }
};

// Eliminates unnecessary casts before sparse segment reduction operations.
//
// Existing graphs and library code would often insert a cast from DT_INT64 to
// DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
// Support for DT_INT64 indices and/or segment_ids now exists, so we can
// pass the input directly without a cast.
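//
// For example (illustrative pseudocode; tensor names are hypothetical):
//
//   out = tf.sparse.segment_sum(data, tf.cast(ids64, tf.int32), segment_ids)
//
// can consume `ids64` directly, with the reduction's "Tidx" attr rewritten
// from DT_INT32 to DT_INT64:
//
//   out = tf.sparse.segment_sum(data, ids64, segment_ids)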
class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveCastIntoSegmentReductionStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
                                 ctx_ext) {}
  ~RemoveCastIntoSegmentReductionStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return Status::OK();

    bool optimized = false;

    // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
    // DT_INT64.
    std::array<std::pair<int, string>, 2> input_details = {
        std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};

    for (const auto& input : input_details) {
      int input_index = input.first;
      const string& type_attr_name = input.second;
      NodeDef* cast_node = nullptr;
      TF_RETURN_IF_ERROR(
          GetInputNode(reduction_node->input(input_index), &cast_node));
      DataType original_index_type;
      if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
        const string original_input = reduction_node->input(input_index);
        reduction_node->set_input(input_index, cast_node->input(0));
        ctx().node_map->UpdateInput(reduction_node->name(), original_input,
                                    cast_node->input(0));
        SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
        optimized = true;
      }
    }

    if (optimized) *simplified_node_name = reduction_node->name();
    return Status::OK();
  }

 private:
  bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
    if (!IsCast(node)) return false;
    if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
    return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
  }
};

}  // namespace

Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
  SetVector<NodeDef*> nodes_to_simplify;
  nodes_to_simplify.Reserve(optimized_graph_->node_size());
  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
    nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
  }

  const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                  graph_properties_.get(), node_map_.get(),
                                  &feed_nodes_, opt_level_);
  const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);

  // Stop pipeline after first stage returning non-empty simplified tensor
  // name.
  const auto stop = [](const string& result) { return !result.empty(); };
  GraphOptimizerStagePipeline<string> pipeline(stop);
  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;

  if (options_.combine_add_to_addn && can_use_shapes)
    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
  if (options_.fold_conjugate_into_transpose)
    pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
  if (options_.fold_multiply_into_conv)
    pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
  if (options_.fold_transpose_into_matmul)
    pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
  if (is_aggressive && options_.hoist_common_factor_out_of_aggregation &&
      can_use_shapes)
    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
  if (options_.minimize_broadcasts && can_use_shapes)
    pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
  if (options_.remove_identity_transpose && can_use_shapes)
    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
  if (options_.remove_involution)
    pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
  if (options_.remove_redundant_bitcast)
    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_cast)
    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
  if (options_.replace_pack_with_tile_reshape)
    pipeline.AddStage<ReplacePackWithTileReshape>(ctx, ctx_ext);
  if (options_.replace_mul_with_tile && can_use_shapes)
    pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext);
  if (options_.reduce_upsampling_dims)
    pipeline.AddStage<ReduceUpsamplingDims>(ctx, ctx_ext);
  if (options_.remove_redundant_reshape && can_use_shapes)
    pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
  if (options_.remove_negation)
    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
  if (options_.replace_mul_with_square)
    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
  if (options_.remove_logical_not)
    pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
  if (options_.reorder_cast_like_and_value_preserving)
    pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
  if (options_.simplify_aggregation)
    pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
  if (options_.hoist_cwise_unary_chains)
    pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
  if (options_.convert_sqrt_div_to_rsqrt_mul)
    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
  if (options_.remove_idempotent)
    pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
  if (options_.convert_log1p)
    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
  if (options_.convert_log_softmax)
    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
  if (options_.optimize_max_or_min_of_monotonic)
    pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
  if (options_.convert_expm1)
    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
  if (options_.unary_ops_composition)
    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
  if (options_.remove_stack_slice_same_axis)
    pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
  if (options_.simplify_embedding_lookup)
    pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
  if (options_.remove_cast_into_segment_reduction)
    pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
  if (options_.fuse_squared_diff)
    pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);

  VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
          << absl::StrJoin(pipeline.StageNames(), ", ");

  while (!nodes_to_simplify.Empty()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = nodes_to_simplify.PopBack();

    string simplified_tensor = "";
    bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);

    // If the node was not optimized by any of the stages, go to the next one.
    if (!optimized) continue;

    // re-wire consumers of an old node to the new one
    if (NodeName(simplified_tensor) != node->name()) {
      // Always consider simplified_tensor for further optimizations.
      NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
      if (simplified_node != nullptr) {
        nodes_to_simplify.PushBack(simplified_node);
      }
      // When `node` is simplified to another node rather than in-place, the
      // consumers of `node` are already redirected to `simplified_tensor`.
      // Re-push the consumers into `nodes_to_simplify` for further
      // optimizations.
      const std::vector<NodeDef*> consumers =
          node_map_->GetOutputsOrderedByNodeName(node->name());
      for (NodeDef* consumer : consumers) {
        // Update any use of `node` in `consumer`'s inputs to
        // `simplified_tensor`.
        for (int i = 0; i < consumer->input_size(); ++i) {
          int operand_pos;
          string operand_node_name =
              ParseNodeName(consumer->input(i), &operand_pos);
          if (operand_node_name == node->name()) {
            *consumer->mutable_input(i) =
                (operand_pos < 0
                     ? AsControlDependency(NodeName(simplified_tensor))
                     : simplified_tensor);
          }
        }
        node_map_->UpdateInput(consumer->name(), node->name(),
                               simplified_tensor);
        nodes_to_simplify.PushBack(consumer);
      }
    }
  }
  return Status::OK();
}

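// A minimal usage sketch (hypothetical driver code, not part of this file;
// assumes the standard GraphOptimizer calling convention):
//
//   GrapplerItem item;  // populated with a GraphDef and fetch nodes
//   ArithmeticOptimizer optimizer;
//   GraphDef output;
//   TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));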
Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                     const GrapplerItem& item,
                                     GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  fetch_nodes_known_ = !item.fetch.empty();
  GrapplerItem optimized_item(item);
  optimized_graph_ = &optimized_item.graph;

  node_map_.reset(new NodeMap(optimized_graph_));
  for (const auto& feed : item.feed) {
    feed_nodes_.insert(NodeName(feed.first));
  }

  // Disable restricted graph rewrites.
  options_.unary_ops_composition &=
      item.optimization_options().allow_non_differentiable_rewrites;

  // Perform topological sort on the graph in order to help DedupComputations
  // and AddOpsRewrite to optimize larger subgraphs starting from the roots
  // with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  graph_properties_.reset(new GraphProperties(optimized_item));
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  const Status status =
      graph_properties_->InferStatically(assume_valid_feeds,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false);
  const bool can_use_shapes = status.ok();
  if (!can_use_shapes) {
    VLOG(1) << "Shape inference failed: " << status.error_message();
  }

  // Perform the optimizations.
  TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
  *optimized_graph = std::move(*optimized_graph_);
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow