/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/strided_slice_op.h"

using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {
namespace {

// Mark nodes created or optimized by a stage with a tag.
constexpr char kAddOpsRewriteTag[] =
    "_grappler_ArithmeticOptimizer_AddOpsRewriteStage";
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler_ArithmeticOptimizer_MinimizeBroadcasts";

// Extract values from a Const op to `values`. Returns true on success.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
  if (node.op() != "Const") {
    return false;
  }

  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
    return false;
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = node.attr().at("value").tensor();
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, theoretically the representation of the data
    // could be compressed. So, before copying values to the returned vector,
    // make sure no compression happens.
    const TensorShapeProto& shape = tensor.tensor_shape();
    if (shape.dim_size() == 1 &&
        shape.dim(0).size() == tensor_values->size()) {
      values->insert(values->end(), tensor_values->begin(),
                     tensor_values->end());
      return true;
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  if (tensor_content_size > 0) {
    CHECK_EQ(0, tensor_content_size % sizeof(T))
        << "tensor_content_size (" << tensor_content_size
        << ") is not a multiple of " << sizeof(T);
    values->resize(tensor_content_size / sizeof(T));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
    return true;
  }

  return false;
}

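// Adds a control dependency on `new_input` to `node`, unless an equivalent
// data or control edge already exists. Returns true if a new edge was added.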
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
                          GraphDef* graph, NodeMap* node_map) {
  bool already_exists = false;
  for (const string& input : node->input()) {
    if (input == new_input || AsControlDependency(input) == new_input) {
      already_exists = true;
      break;
    }
  }
  if (!already_exists) {
    const string ctrl_dep =
        ConstantFolding::AddControlDependency(new_input, graph, node_map);
    node->add_input(ctrl_dep);
    node_map->AddOutput(NodeName(new_input), node->name());
  }
  return !already_exists;
}

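// Sets the type-valued attribute `attr_name` on `node` to `dtype`.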
void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
  (*node->mutable_attr())[attr_name].set_type(dtype);
}

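// Walks upstream from `node` through a chain of value-preserving ops and
// returns the tail of the chain (the node closest to the inputs). The chain
// stops at preserved nodes and at nodes with more than one non-control
// output.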
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_value_preserving_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_value_preserving_non_branching);
}

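// Same as above, but the chain is made of idempotent ops, i.e. ops for which
// f(f(x)) == f(x).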
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_idempotent_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_idempotent_non_branching);
}

// GetElementUnexhaustive tries to get the value of an element in a tensor and
// turn it into complex128 type. It only checks for a limited number of data
// types, so it's unexhaustive.
bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
                            complex128* element) {
  if (dtypes.find(t.dtype()) == dtypes.end()) return false;
  switch (t.dtype()) {
    case DT_BFLOAT16:
      *element = complex128(t.flat<bfloat16>()(i));
      return true;
    case DT_HALF:
      *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
      return true;
    case DT_INT32:
      *element = complex128(t.flat<int32>()(i));
      return true;
    case DT_INT64:
      *element = complex128(t.flat<int64_t>()(i));
      return true;
    case DT_FLOAT:
      *element = complex128(t.flat<float>()(i));
      return true;
    case DT_DOUBLE:
      *element = complex128(t.flat<double>()(i));
      return true;
    case DT_COMPLEX64:
      *element = complex128(t.flat<complex64>()(i));
      return true;
    case DT_COMPLEX128:
      *element = t.flat<complex128>()(i);
      return true;
    default:
      return false;
  }
}

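// Returns true if `node` is placed on a CPU device, i.e. its assigned device
// name mentions CPU.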
bool NodeIsOnCpu(const NodeDef& node) {
  string task;
  string device;
  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
         absl::StrContains(device, DEVICE_CPU);
}

// True if all regular (non-control) inputs reference the same node, or if
// there are no non-control inputs.
bool AllRegularInputsEqual(const NodeDef& node) {
  if (!HasRegularInputs(node)) return true;
  for (int i = 1; i < node.input_size(); ++i) {
    if (IsControlInput(node.input(i))) {
      break;
    }
    if (node.input(0) != node.input(i)) {
      return false;
    }
  }
  return true;
}

// Replace a node with NoOp and reset shape inference results for it.
void ReplaceWithNoOp(NodeDef* node, const GraphOptimizerContext& ctx) {
  ctx.node_map->RemoveInputs(node->name());
  ctx.graph_properties->ClearInputProperties(node->name());
  ctx.graph_properties->ClearOutputProperties(node->name());
  EraseRegularNodeAttributes(node);
  node->set_op("NoOp");
  node->clear_input();
}

// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  SetVector<NodeDef*>* nodes_to_simplify;
};

// Base class for a single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc...
class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 public:
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
  ~ArithmeticOptimizerStage() override = default;

 protected:
  // A simplification graph rewrite can create additional nodes that are
  // inputs to the final simplified node; they can also be added to the
  // arithmetic optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
    ctx_ext_.nodes_to_simplify->PushBack(node);
  }

  // Update consumers of node to take new_input as input instead.
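  // If a consumer reads `node` through a control edge, the edge is rewired to
  // a control dependency on the node of `new_input`. Replacing a data input
  // with a control-only `new_input` is rejected as invalid.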
  Status UpdateConsumers(NodeDef* node, const string& new_input) {
    const auto consumers = ctx().node_map->GetOutputs(node->name());
    if (consumers.empty()) return OkStatus();
    const TensorId new_tensor = ParseTensorName(new_input);
    for (NodeDef* consumer : consumers) {
      if (consumer->name() == new_tensor.node()) continue;
      bool updated = false;
      for (int i = 0; i < consumer->input_size(); ++i) {
        const TensorId input_tensor = ParseTensorName(consumer->input(i));
        if (input_tensor.node() == node->name()) {
          if (new_tensor.index() < 0 && input_tensor.index() >= 0) {
            // Overwriting a data input with a control input will make the
            // graph invalid.
            return errors::InvalidArgument(
                "Cannot override data input ", input_tensor.ToString(),
                " with control input ", new_tensor.ToString());
          }
          consumer->set_input(i, input_tensor.index() < 0
                                     ? absl::StrCat("^", new_tensor.node())
                                     : new_input);
          ctx().node_map->UpdateInput(consumer->name(), node->name(),
                                      new_input);
          updated = true;
        }
      }
      if (updated) {
        DedupControlInputs(consumer);
        AddToOptimizationQueue(consumer);
      }
    }
    return OkStatus();
  }

  // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
  // optimizations have been migrated to stages.
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
    for (const auto& src : src_nodes) {
      for (int i = src->input_size() - 1; i >= 0; --i) {
        if (IsControlInput(src->input(i))) {
          *target_node->add_input() = src->input(i);
          ctx().node_map->AddOutput(NodeName(src->input(i)),
                                    target_node->name());
        } else {
          break;
        }
      }
    }
    DedupControlInputs(target_node);
  }

  bool IsReallyConstant(const NodeDef& node) const {
    if (!IsConstant(node)) {
      return false;
    }
    // If the node is fed, it's not constant anymore.
    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool IsDrivenByControlDependency(const NodeDef& node) const {
    return std::any_of(
        node.input().begin(), node.input().end(),
        [](const string& input) { return IsControlInput(input); });
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool DrivesControlDependency(const NodeDef& node) const {
    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
      for (int i = 0; i < output->input_size(); ++i) {
        const TensorId tensor = ParseTensorName(output->input(i));
        if (tensor.node() == node.name() && tensor.index() < 0) {
          return true;
        }
      }
    }
    return false;
  }

  bool GetTensorFromConstNode(const string& node_name_or_input,
                              Tensor* tensor) {
    const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
    return node != nullptr && IsReallyConstant(*node) &&
           CheckAttrExists(*node, "value").ok() &&
           tensor->FromProto(node->attr().at("value").tensor());
  }

 private:
  // Extended context required for ArithmeticOptimizer.
  const ArithmeticOptimizerContext ctx_ext_;
};

// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
// group of nodes from the optimized graph.
//
// * AddOpsRewrite:
//   Rewrite a group of Add/AddN ops with a compact Add/AddN tree.
//
// * MinimizeBroadcasts:
//   Rewrite a group of binary associative ops, reordering
//   inputs, to minimize the cost of broadcast.
class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
  ~ArithmeticNodesGroupOptimizerStage() override = default;

  // Input name with a statically inferred shape from GraphProperties.
  struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
    string input;
    TensorShapeProto shape;
  };

  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). The
  // subgraph is obtained by graph traversal, starting from a root node.
  struct OptimizedNodesGroup {
    NodeDef* root_node;
    TensorShapeProto root_shape;
    // Optimized nodes that will be updated or removed by the rewrite.
    std::vector<NodeDef*> optimized_nodes;
    // Inputs to optimized nodes.
    std::vector<InputAndShape> inputs;
  };

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    OptimizedNodesGroup group;
    TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));

    if (!group.optimized_nodes.empty()) {
      *simplified_node_name = RewriteOptimizedNodesGroup(group);
    }

    return OkStatus();
  }

 protected:
  // Modify the optimized graph after the nodes group has been successfully
  // identified.
  virtual string RewriteOptimizedNodesGroup(
      const OptimizedNodesGroup& group) = 0;

  // Check if an input can become a part of the current optimized nodes group.
  virtual bool IsAbsorbableByOptimizedNodesGroup(
      const OptimizedNodesGroup& group, const NodeDef& node) const = 0;

  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
    std::deque<const string*> input_tensors;
    input_tensors.push_front(&input);

    while (!input_tensors.empty()) {
      const string* input_tensor = input_tensors.front();
      input_tensors.pop_front();

      // Get a node for the input tensor.
      NodeDef* input_node;
      TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));

      if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
        group->optimized_nodes.push_back(input_node);
        for (int i = input_node->input_size() - 1; i >= 0; --i) {
          const string& absorbed_node_input = input_node->input(i);
          // TODO(ezhulenev): support control inputs
          if (IsControlInput(absorbed_node_input)) continue;
          input_tensors.push_front(&absorbed_node_input);
        }
      } else {
        // If the input node can't be absorbed, add it to the
        // OptimizedNodesGroup inputs.
        const OpInfo::TensorProperties* properties;
        TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
        group->inputs.emplace_back(*input_tensor, properties->shape());
      }
    }

    return OkStatus();
  }

  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
    const OpInfo::TensorProperties* root_node_output_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(root_node->name(), &root_node_output_properties));

    group->root_node = root_node;
    group->root_shape = root_node_output_properties->shape();

    group->optimized_nodes.reserve(root_node->input_size());
    for (int i = 0; i < root_node->input_size(); ++i) {
      const string& input_i = root_node->input(i);
      // TODO(ezhulenev): add support for control inputs
      if (IsControlInput(input_i)) continue;
      TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
    }

    return OkStatus();
  }

  // Check if all inputs can be broadcasted to the same shape.
  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
    auto is_broadcastable = [this, &properties](const string& input) {
      const OpInfo::TensorProperties* input_props;
      Status has_input_properties = GetTensorProperties(input, &input_props);
      return has_input_properties.ok() &&
             ShapesBroadcastable(properties, *input_props);
    };
    return std::all_of(node.input().begin(), node.input().end(),
                       is_broadcastable);
  }

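  // Builds a string signature used to bucket inputs with symbolically equal
  // shapes; e.g. a rank-3 shape [32, ?, 4] produces "rank:3:dim:32:-1:4"
  // (unknown dimensions are encoded as -1).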
  string ShapeSignature(const TensorShapeProto& shape) const {
    string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
    for (int i = 0; i < shape.dim_size(); ++i)
      strings::StrAppend(&signature, ":", shape.dim(i).size());
    return signature;
  }

  void MarkWithTag(const StringPiece tag, NodeDef* node) {
    AddNodeAttr(tag, true, node);
  }

  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
    AddNodeAttr(tag, true, group.root_node);
    for (NodeDef* optimized_node : group.optimized_nodes) {
      AddNodeAttr(tag, true, optimized_node);
    }
  }

  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
    return group.root_node->device() == node.device();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
    return HasNodeAttr(node, tag);
  }

  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
    return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
  }
};

// Rewrite a tree of Add/AddN ops with a single AddN operation, consuming all
// the original inputs of absorbed nodes.
//
// 1) All nodes must have the same device placement.
//
// 2) If all nodes in an Add/AddN subgraph have symbolically equal shape, the
//    tree is optimized to a single AddN node.
//
//                AddN_1
//             /    |    \
//       Add_1      z     Add_2     ->  AddN(x, y, z, w, q, e)
//      /     \          /     \
//     x       y        w     Add_3
//                            /   \
//                           q     e
//
// 3) If some nodes have a different shape (it must be broadcastable to the
//    shape of the "root"), the tree is optimized to AddN ops for symbolically
//    equal shapes, plus a tree of Add ops that minimizes broadcasts.
//
//                AddN_1                         Add
//             /    |    \                      /   \
//       Add_1      z     Add_2     ->        Add    w
//      /     \          /     \             /   \
//     x       y        w     Add_3   AddN(x, y, q, e)  z
//                            /   \
//                           q     e
class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
  ~AddOpsRewriteStage() override = default;

  // Check if a node can become a root of an AddOpsGroup.
  bool IsSupported(const NodeDef* node) const override {
    if (!CanOptimize(*node)) return false;

    // The shape must be symbolically defined, and all inputs compatible with
    // it.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  // Check if a node can be absorbed by the current OptimizedNodesGroup.
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!CanOptimize(node)) return false;

    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // The node must have a single output data consumer (presumably, if we
    // reach this node from a previously absorbed node or a root node, it
    // means that this node is not used as an input to any other op outside
    // of the group).
    if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  // Node requirements both for a root node and an absorbed node.
  bool CanOptimize(const NodeDef& node) const {
    // TODO(ezhulenev): check if AccumulateNV2 can be supported too
    if (!IsAdd(node) && !IsAddN(node)) {
      return false;
    }
    if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
      return false;
    }
    // TODO(ezhulenev): relax this condition for root node
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  // Rewrite a group of add ops into a single AddN if all input shapes are
  // symbolically equal. If not, create AddN for equal shapes first, and then
  // build an Add tree, minimizing the cost of broadcasts.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size()
            << " num_inputs=" << group.inputs.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kAddOpsRewriteTag);

    // All new nodes will be placed under the scope of a root node.
    auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());

    // Find what shapes are present in the inputs of absorbed nodes.
    std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
    for (const auto& input : group.inputs) {
      shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
    }

    using SigKV = decltype(shape_sig_to_inputs)::value_type;
    VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
            << " unique shapes: "
            << absl::StrJoin(shape_sig_to_inputs, ", ",
                             [](string* out, SigKV p) {
                               strings::StrAppend(out, p.first);
                             });

    // Collect all the shapes from representative elements.
    std::vector<TensorShapeProto> shapes;
    shapes.reserve(shape_sig_to_inputs.size());
    for (const auto& el : shape_sig_to_inputs)
      shapes.push_back(el.second[0].shape);

    // If all inputs have the same shape, rewrite the whole group with a
    // single AddN.
    if (shapes.size() == 1) {
      string node_name = UniqueOptimizedNodeName(root_scope_and_name);
      AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
                                        group.inputs);
      return node_name;
    }

    // For inputs of different shapes:
    // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
    // 2. Build a tree of Add nodes, minimizing cost of broadcast
    std::sort(shapes.begin(), shapes.end(),
              [](const TensorShapeProto& left, const TensorShapeProto& right) {
                return CompareSymbolicallyShapedTensorSizes(left, right);
              });

    // optimized name for leaf AddN nodes
    auto leaf_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Leaf_", i));
    };
    // optimized name for internal nodes of a tree built up from AddN leaves
    auto internal_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Internal_", i));
    };

    // Add/AddN nodes that must be added to the tree
    std::deque<InputAndShape> add_ops;

    // Prepare leaf AddN nodes for inputs of equal shape
    for (int i = 0, end = shapes.size(); i < end; ++i) {
      const auto node_name = leaf_node_name(i);
      const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
      add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
                                                          node_name, inputs));
    }

    // Build up a tree of Add ops
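    // Example: with sorted leaves [a, b, c] the loop emits
    // Internal_0 = Add(a, b) and then the root-named node Add(Internal_0, c).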
    int internal_nodes = 0;
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();
      string name = add_ops.empty()
                        ? UniqueOptimizedNodeName(root_scope_and_name)
                        : internal_node_name(internal_nodes++);
      InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
      add_ops.push_front(add);
    } while (add_ops.size() > 1);

    InputAndShape optimized_root_node = add_ops.front();
    return optimized_root_node.input;
  }

  // Add an 'AddN' node to aggregate inputs of symbolically equal shape.
  InputAndShape AddInputsOfSymbolicallyEqualShape(
      const NodeDef& root_node, const string& node_name,
      const std::vector<InputAndShape>& inputs) {
    CHECK(!inputs.empty()) << "Inputs must be non-empty";

    // Do not create redundant AddN nodes
    if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
      return inputs[0];
    }

    // get shape from representative element
    auto shape = inputs[0].shape;

    // copy attributes from a root node
    DataType dtype = root_node.attr().at("T").type();

    // add new AddN node
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op("AddN");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["N"].set_i(inputs.size());

    for (const auto& inputAndShape : inputs) {
      ctx().node_map->AddOutput(inputAndShape.input, node_name);
      node->add_input(inputAndShape.input);
    }

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(node_name, shape);
  }

  // Add a single 'Add' node to sum two inputs.
  InputAndShape AddAggregatedInputs(const NodeDef& root_node,
                                    const string& node_name,
                                    const InputAndShape& left,
                                    const InputAndShape& right) {
    // copy attributes from a root node
    DataType dtype = root_node.attr().at("T").type();

    // add new Add node
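    // Note: the legacy "Add" op is kept for string inputs because "AddV2"
    // does not support DT_STRING.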
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op((dtype == DT_STRING || dtype == DT_STRING_REF) ? "Add"
                                                                : "AddV2");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    node->add_input(left.input);
    node->add_input(right.input);

    ctx().node_map->AddOutput(left.input, node_name);
    ctx().node_map->AddOutput(right.input, node_name);

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(
        node_name, TensorShapeProto());  // shape is not important at this point
  }
};

// Use the distributive property of multiplication and division over addition,
// along with commutativity of the former, to hoist common factors/denominators
// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
// This pattern occurs frequently in regularization terms for the gradients
// during training.
//
// For example, we can rewrite an expression of the form:
//   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
//   Mul(x, AddN(y1, y2, y3, ... yn))
// For division, we can rewrite
//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
// to:
//   Div(AddN(y1, y2, y3, ... yn), x)
class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 public:
  explicit HoistCommonFactorOutOfAggregation(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
  ~HoistCommonFactorOutOfAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
           !IsRewritten(node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    bool common_factor_is_denominator = false;
    std::set<string> common_factors;
    std::vector<string> ctrl_deps;
    TF_RETURN_IF_ERROR(GetCommonFactors(
        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));

    if (common_factors.size() == 1) {
      const string& common_factor = *common_factors.begin();

      // Gather up the non-shared factors.
      bool shapes_match = true;
      std::vector<string> unique_factors;
      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
                                          common_factor_is_denominator,
                                          &shapes_match, &unique_factors));

      if (shapes_match) {
        NodeDef* input_0;
        TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));

        // Use a copy of the first node for the outer multiplication/division.
        NodeDef* new_outer_node = AddCopyNode(
            OuterNodeName(node, common_factor_is_denominator), input_0);
        // And a copy of the aggregation node as one of the inner operands.
        NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);

        new_outer_node->set_device(node->device());
        if (common_factor_is_denominator) {
          new_outer_node->set_input(0, new_add_node->name());
          new_outer_node->set_input(1, common_factor);
        } else {
          new_outer_node->set_input(0, common_factor);
          new_outer_node->set_input(1, new_add_node->name());
        }

        ctx().node_map->AddOutput(common_factor, new_outer_node->name());
        ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());

        // Hoist non-shared factors up into the new AddN node.
        for (int i = 0, end = unique_factors.size(); i < end; ++i) {
          const string& unique_factor_i = unique_factors[i];
          new_add_node->set_input(i, unique_factor_i);
          ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
        }

        // Add control deps on the add node.
        for (const string& ctrl_dep : ctrl_deps) {
          *new_add_node->add_input() = ctrl_dep;
          ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
        }

        // optimize new inner aggregation node
        AddToOptimizationQueue(new_add_node);
        // do not optimize the same node twice
        rewritten_nodes_.insert(node->name());
        *simplified_node_name = new_outer_node->name();
      }
    }
    return OkStatus();
  }

 private:
  // Get a name for the new outer node.
  string OuterNodeName(const NodeDef* node, bool is_div) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return is_div ? OptimizedNodeName(scope_and_name, "Div")
                  : OptimizedNodeName(scope_and_name, "Mul");
  }

  // Get a name for the new inner Add node.
  string InnerAddNodeName(const NodeDef* node) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return OptimizedNodeName(scope_and_name, "AddV2");
  }

  // Determine the set of common factors if the input nodes are all Mul or
  // Div nodes.
  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
                          bool* common_factor_is_denominator,
                          std::vector<string>* ctrl_deps) const {
    CHECK(common_factors->empty());
    CHECK_NOTNULL(common_factor_is_denominator);
    *common_factor_is_denominator = false;

    bool has_mul = false;
    bool has_div = false;
    for (int i = 0; i < node->input_size(); ++i) {
      if (i > 0 && common_factors->empty()) break;
      if (IsControlInput(node->input(i))) {
        ctrl_deps->push_back(node->input(i));
        continue;
      }
      NodeDef* input;
      TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));

      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
          (IsAnyDiv(*input) && has_mul)) {
        // Break if the input is neither a Mul nor a Div, or if there are both
        // Mul & Div ops.
        common_factors->clear();
        break;
      } else if (IsAnyDiv(*input)) {
        has_div = true;
        // In case of possible common dividers, we avoid hoisting out if any
        // input is not float/double, since integer division is not
        // distributive over addition.
        const OpInfo::TensorProperties* properties0;
        const OpInfo::TensorProperties* properties1;
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
        if (properties0->dtype() != DT_FLOAT &&
            properties0->dtype() != DT_DOUBLE &&
            properties1->dtype() != DT_FLOAT &&
            properties1->dtype() != DT_DOUBLE) {
          common_factors->clear();
          break;
        }
      } else if (IsMul(*input)) {
        has_mul = true;
      }

      // We only focus on common factors from denominators if any op is a
      // Div.
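      // For a Mul node both operands are candidate factors; for a Div node
      // only the denominator (input 1) can be hoisted.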
      std::set<string> factors_i =
          has_mul ? std::set<string>{input->input(0), input->input(1)}
                  : std::set<string>{input->input(1)};
      if (i == 0) {
        std::swap(*common_factors, factors_i);
      } else {
        std::set<string> intersection;
        std::set_intersection(
            factors_i.begin(), factors_i.end(), common_factors->begin(),
            common_factors->end(),
            std::inserter(intersection, intersection.begin()));
        std::swap(*common_factors, intersection);
      }
      for (int i = 2; i < input->input_size(); ++i) {
        ctrl_deps->push_back(input->input(i));
      }
    }

    *common_factor_is_denominator = has_div;
    return OkStatus();
  }

  // Gather up the non-shared factors (the y's in the example).
  // Unless the aggregation is Add, we have to make sure that all the y's
  // have the same shape since the other aggregation ops do not support
  // broadcasting.
  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
                          const bool common_factor_is_denominator,
                          bool* shapes_match,
                          std::vector<string>* unique_factors) const {
    *shapes_match = true;
    unique_factors->reserve(node->input_size());

    for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
      const string& input = node->input(i);
      if (IsControlInput(input)) {
        break;
      }
      NodeDef* inner_node;
      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
      const int unique_factor_index =
          common_factor_is_denominator
              ? 0
              : (inner_node->input(0) == common_factor ? 1 : 0);
      unique_factors->push_back(inner_node->input(unique_factor_index));
      if (i > 0 && !IsAdd(*node)) {
        const OpInfo::TensorProperties* lhs;
        const OpInfo::TensorProperties* rhs;
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
        *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs);
      }
    }
    return OkStatus();
  }

  bool IsRewritten(const NodeDef* node) const {
    // If the graph rewrite happens in multiple passes without graph pruning
    // between them, it's possible that a rewritten node already exists in the
    // graph.
    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
           ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
           ctx().node_map->NodeExists(OuterNodeName(node, true)) ||
           ctx().node_map->NodeExists(InnerAddNodeName(node));
  }

  // Keep the names of the nodes that were optimized by this stage.
  std::unordered_set<string> rewritten_nodes_;
};

// Binary associative ops can be re-ordered to minimize the number of
// broadcasts and the size of temporary tensors.
//
// Example: [a, c] - scalars, [b, d] - matrices
//   @  - binary associative op (Add or Mul)
//   @* - broadcast
//
//           @                        @*
//         /   \                    /    \
//       @*     @*       ->       @        @
//      /  \   /  \              /  \    /  \
//     a    b c    d            a    c  b    d
class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
  }
  ~MinimizeBroadcasts() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsBinaryAssociative(*node)) return false;

    if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
      return false;

    // Has a symbolically defined shape with broadcastable inputs.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  bool IsBinaryAssociative(const NodeDef& node) const {
    return IsMul(node) || IsAdd(node);
  }

  bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
    return group.root_node->op() == node.op();
  }

  // Check if a node can be absorbed by the current OptimizedNodesGroup.
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!IsSameOp(group, node)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
    if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
      return false;
    }
    if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
      return false;
    }
    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Optimized nodes are updated in place, which would break the graph if
    // the node had multiple output consumers.
    if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
    std::set<string> sigs;
    for (const auto& ias : inputs) {
      sigs.insert(ShapeSignature(ias.shape));
    }
    return sigs.size();
  }

  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);

    if (CountUniqueShapes(group.inputs) <= 1) {
      VLOG(3) << "Skip min-bcast group with single unique shape";
      // Nothing to optimize when all shapes are the same.
      return group.root_node->name();
    }

    auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
    auto num_inputs = group.inputs.size();
    CHECK_EQ(num_nodes, num_inputs - 1)
        << "Can't build a tree with " << num_inputs << " inputs, using "
        << num_nodes << " binary op nodes.";

    std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
    std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
                                         group.optimized_nodes.end());

    // Sort inputs by their shape, from smallest to largest.
    std::stable_sort(add_ops.begin(), add_ops.end(),
                     [](const InputAndShape& lhs, const InputAndShape& rhs) {
                       return CompareSymbolicallyShapedTensorSizes(lhs.shape,
                                                                   rhs.shape);
                     });

    // If there is an odd number of inputs, the last one is the largest, and
    // we want to attach it to the root node, to build a well balanced tree.
    std::deque<InputAndShape> add_ops_leftover;
    if (add_ops.size() % 2 != 0) {
      add_ops_leftover.push_back(add_ops.back());
      add_ops.pop_back();
    }

    // At this point it's guaranteed that add_ops has an even number of
    // inputs.
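    // Example: for sorted inputs [s1, s2, M] the largest input M was set
    // aside above; the loop rewires one of the optimized nodes as
    // n0 = op(s1, s2), and the leftover step below attaches M at the root:
    // root = op(n0, M).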
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();

      NodeDef* node;
      if (!optimized_nodes.empty()) {
        // Re-purpose optimized nodes to build a new tree.
        node = optimized_nodes.back();
        optimized_nodes.pop_back();
      } else {
        // Or use the root node if no optimized nodes are left.
        node = group.root_node;
      }
      InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);

      // Pushing the updated node to the back of the deque will create a wide
      // and short tree, pushing to the front will create a tall tree. We
      // prefer a wide tree because it minimizes the potential number of
      // temporary tensors required to keep in memory, though sometimes we can
      // go up to prevent propagating a broadcast from leaves to the root.
      // Example:
      //
      // inputs: [s, s, s, M] (s - scalar, M - matrix)
      // @* - op with broadcast
      //
      //  (only push_back)                @*  (push_front first op)
      //                                 /  \
      //         @*                     @    M
      //       /    \                  / \
      //      @      @*     ->        @   s
      //     / \    /  \             / \
      //    s   s  s    M           s   s
      if (add_ops.size() >= 2 &&
          CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
                                               add_ops.at(1).shape)) {
        add_ops.push_front(updated_node);
      } else {
        add_ops.push_back(updated_node);
      }
    } while (add_ops.size() > 1);
    CHECK_EQ(1, add_ops.size());

    // Attach the largest tensor to the root op.
    if (!add_ops_leftover.empty()) {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops_leftover.front();
      InputAndShape updated_node =
          UpdateInputs(lhs.input, rhs.input, group.root_node);
      add_ops.push_back(updated_node);
    }

    return add_ops.front().input;
  }

  InputAndShape UpdateInputs(const string& input_0, const string& input_1,
                             NodeDef* node) {
    string old_input_0 = node->input(0);
    string old_input_1 = node->input(1);

    // Update inputs only if they changed.
    if (old_input_0 != input_0 || old_input_1 != input_1) {
      node->set_input(0, input_0);
      node->set_input(1, input_1);
      // Invalidate node properties (shape).
      ctx().graph_properties->ClearOutputProperties(node->name());
      ctx().graph_properties->ClearInputProperties(node->name());
      // Update the node map.
      ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
      ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
      ctx().node_map->AddOutput(NodeName(input_0), node->name());
      ctx().node_map->AddOutput(NodeName(input_1), node->name());
      // Add the updated node to the optimization queue.
      AddToOptimizationQueue(node);
    }

    TensorShapeProto shape;  // shape is not important at this point
    return InputAndShape(node->name(), shape);
  }
};

// Removes inverse transpose nodes.
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
  ~RemoveIdentityTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
    NodeDef* tail = node;
    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
                                    *ctx().nodes_to_preserve);
    NodeDef* first_transpose;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));

    NodeDef* node_perm;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
    if (!IsConstant(*node_perm)) {
      return OkStatus();
    }
    std::vector<int64_t> node_perm_values;
    TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
    if (first_transpose->op() == node->op()) {
      // Remove pairs of transposes that cancel each other.
      NodeDef* first_transpose_perm;
      TF_RETURN_IF_ERROR(
          GetInputNode(first_transpose->input(1), &first_transpose_perm));
      if (!IsConstant(*first_transpose_perm)) {
        return OkStatus();
      }
      std::vector<int64_t> first_transpose_perm_values;
      TF_RETURN_IF_ERROR(
          GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
      if (AreInversePermutations(node_perm_values,
                                 first_transpose_perm_values)) {
        if (tail == node) {
          // Bypass adjacent pair.
          *simplified_node_name = first_transpose->input(0);
        } else {
          // Bypass pair connected through chain.
          tail->set_input(0, first_transpose->input(0));
          ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
                                      first_transpose->input(0));
          ForwardControlDependencies(tail, {first_transpose});
          *simplified_node_name = node->input(0);
        }
      }
    } else {
      // Remove simple identity transposes.
      if (IsIdentityPermutation(node_perm_values)) {
        if (IsConjugateTranspose(*node)) {
          const NodeScopeAndName transpose =
              ParseNodeScopeAndName(node->name());
          const string optimized_node_name = OptimizedNodeName(transpose);
          NodeDef* new_op = AddCopyNode(optimized_node_name, node);
          new_op->set_op("Conj");
          new_op->mutable_input()->RemoveLast();
          new_op->mutable_attr()->erase("Tperm");
          ForwardControlDependencies(new_op, {node});
          *simplified_node_name = new_op->name();
        } else {
          *simplified_node_name = node->input(0);
        }
      }
    }
    return OkStatus();
  }

 private:
  Status GetPermutation(const NodeDef& node_perm,
                        std::vector<int64_t>* perm64) const {
    std::vector<int> perm32;
    if (ValuesFromConstNode(node_perm, &perm32)) {
      perm64->reserve(perm32.size());
      for (int val : perm32) {
        perm64->push_back(static_cast<int64_t>(val));
      }
      return OkStatus();
    }
    if (ValuesFromConstNode(node_perm, perm64)) {
      return OkStatus();
    }
    return errors::InvalidArgument("Couldn't extract permutation from ",
                                   node_perm.name());
  }

  bool AreInversePermutations(const std::vector<int64_t>& a,
                              const std::vector<int64_t>& b) {
    if (a.size() != b.size()) {
      return false;
    }
    for (int i = 0, end = a.size(); i < end; ++i) {
      if (a[b[i]] != i) {
        return false;
      }
    }
    return true;
  }

  bool IsIdentityPermutation(const std::vector<int64_t>& perm) {
    for (int64_t i = 0, end = perm.size(); i < end; ++i) {
      if (i != perm[i]) {
        return false;
      }
    }
    return true;
  }
};

// An involution is an element-wise function f(x) that is its own inverse,
// i.e. f(f(x)) = x. If we can find a chain of ops
//   f->op1->op2->...->opn->f
// where op1 through opn preserve the values of their inputs, we can remove
// the two instances of the involution from the graph, since they cancel
// each other.
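// For example, Conj(Identity(Conj(x))) can be reduced to Identity(x). Typical
// involutions are Conj, Neg, Reciprocal and LogicalNot.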
1293 class RemoveInvolution : public ArithmeticOptimizerStage {
1294 public:
RemoveInvolution(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1295 explicit RemoveInvolution(const GraphOptimizerContext& ctx,
1296 const ArithmeticOptimizerContext& ctx_ext)
1297 : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
1298 ~RemoveInvolution() override = default;
1299
IsSupported(const NodeDef * node) const1300 bool IsSupported(const NodeDef* node) const override {
1301 return IsInvolution(*node);
1302 }
1303
TrySimplify(NodeDef * node,string * simplified_node_name)1304 Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1305 NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
1306 *ctx().nodes_to_preserve);
1307
1308 NodeDef* involution;
1309 TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));
1310
1311 if (involution->op() == node->op()) {
1312 // Skip both *node and *involution since they cancel each other.
1313 if (tail == node) {
1314 // The two nodes to eliminate are adjacent.
1315 *simplified_node_name = involution->input(0);
1316 } else {
1317 tail->set_input(0, involution->input(0));
1318 ctx().node_map->UpdateInput(tail->name(), involution->name(),
1319 involution->input(0));
1320 *simplified_node_name = node->input(0);
1321 }
1322 }
1323
1324 return OkStatus();
1325 }
1326 };

// Remove redundant Bitcasts.
// 1) Remove Bitcast whose source type and destination type are equal
// 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantBitcastStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
  ~RemoveRedundantBitcastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsBitcast(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Bitcast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
    if ((input_type == output_type) && !IsInPreserveSet(*node)) {
      *simplified_node_name = node->input(0);
      return OkStatus();
    }

    NodeDef* bitcast;
    TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
    NodeDef* operand;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));

    if (IsBitcast(*operand) && !IsInPreserveSet(*operand)) {
      AttrSlice operand_attrs(*operand);
      DataType operand_input_type;
      TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
      // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
      bitcast->set_input(0, operand->input(0));
      SetDataTypeToAttr(operand_input_type, "T", bitcast);
      ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
                                  operand->input(0));
      AddToOptimizationQueue(bitcast);
      *simplified_node_name = bitcast->name();
    }

    return OkStatus();
  }
};

// Remove Casts whose source type and destination type are equal.
class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
  ~RemoveRedundantCastStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsCast(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    // Bypass Cast whose source type and destination type are equal.
    AttrSlice attrs(*node);
    DataType input_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
    DataType output_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
    if (input_type == output_type) {
      *simplified_node_name = node->input(0);
    }
    return OkStatus();
  }
};

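// Rewrite Add/Sub ops that have a negated operand:
//   a + (-b) => a - b, a - (-b) => a + b, (-a) + b => b - a.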
class RemoveNegationStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
  ~RemoveNegationStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsAdd(*node) || IsSub(*node)) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    bool updated = false;
    if (IsNeg(*y)) {
      // a - (-b) = a + b or a + (-b) = a - b
      ForwardControlDependencies(node, {y});
      ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
      node->set_op(IsAdd(*node) ? "Sub" : "AddV2");
      node->set_input(1, y->input(0));
      updated = true;
    } else if (IsAdd(*node) && IsNeg(*x)) {
      // (-a) + b = b - a
      ForwardControlDependencies(node, {x});
      ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
      node->set_op("Sub");
      node->mutable_input()->SwapElements(0, 1);
      node->set_input(1, x->input(0));
      updated = true;
    }
    if (updated) {
      AddToOptimizationQueue(node);
    }
    return OkStatus();
  }
};

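// Remove a LogicalNot by inverting the comparison op that feeds it, e.g.
//   LogicalNot(Less(x, y)) => GreaterEqual(x, y).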
class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
  ~RemoveLogicalNotStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsLogicalNot(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const string node_name = node->name();
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (IsInPreserveSet(*input) ||
        NumNonControlOutputs(*input, *ctx().node_map) > 1) {
      return OkStatus();
    }
    string new_op;
    if (IsEqual(*input)) {
      new_op = "NotEqual";
    } else if (IsNotEqual(*input)) {
      new_op = "Equal";
    } else if (IsLess(*input)) {
      new_op = "GreaterEqual";
    } else if (IsLessEqual(*input)) {
      new_op = "Greater";
    } else if (IsGreater(*input)) {
      new_op = "LessEqual";
    } else if (IsGreaterEqual(*input)) {
      new_op = "Less";
    }
    if (!new_op.empty()) {
      input->set_op(new_op);
      *simplified_node_name = input->name();
    }
    return OkStatus();
  }
};

// This optimization hoists the common prefix of unary ops of the inputs to
// concat out of the concat, for example:
//   Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
// becomes
//   Exp(Sin(Concat([x, y, z]))).
// Similarly, it will hoist the common postfix of unary ops into Split or
// SplitV nodes, for example:
//   [Exp(Sin(y)) for y in Split(x)]
// becomes
//   [y for y in Split(Exp(Sin(x)))]
//
// TODO(rmlarsen): Support casting. We would have to change the type attribute
// on the concat/split node.
// TODO(rmlarsen): Handle Enter/Exit.
class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
 public:
  explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("", ctx, ctx_ext) {}

  ~HoistCWiseUnaryChainsStage() override = default;

  struct ChainLink {
    ChainLink() = default;
    ChainLink(NodeDef* _node, int _port_origin)
        : node(_node), port_origin(_port_origin) {}
    NodeDef* node;    // Node in a chain.
    int port_origin;  // Port on concat/split node from which this chain
                      // originates.

    bool operator<(const ChainLink& other) const {
      if (port_origin < other.port_origin) {
        return true;
      } else if (port_origin > other.port_origin) {
        return false;
      } else {
        return node->name() < other.node->name();
      }
    }
  };

  // We use an ordinary set sorted on port and node name, so the order, and
  // hence the node name used for the hoisted chain, will be deterministic.
  using ChainLinkSet = std::set<ChainLink>;

  bool IsSupported(const NodeDef* node) const override {
    if (IsInPreserveSet(*node)) return false;
    if (IsConcat(*node) && node->attr().count("N") != 0) {
      const int n = node->attr().at("N").i();
      return n > 1 && FirstNInputsAreUnique(*node, n);
    } else if ((IsSplit(*node) || IsSplitV(*node)) &&
               node->attr().count("num_split") != 0) {
      const int num_split = node->attr().at("num_split").i();
      if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
        // TODO(rmlarsen): Remove this constraint when we have optimizations
        // in place for merging slices into splits.
        return false;
      }
      if (NumControlOutputs(*node, *ctx().node_map) > 0) {
        // TODO(ezhulenev): Unary ops after Split might have a control path to
        // the Split node, and we currently do not properly handle cycles.
        return false;
      }
      return num_split > 1 && !IsAlreadyOptimized(*node);
    }
    return false;
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    node_is_concat_ = IsConcat(*node);
    int prefix_length;
    std::set<string> ctrl_inputs;
    ChainLinkSet tails;
    TF_RETURN_IF_ERROR(
        FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
    if (prefix_length > 0 && !tails.empty()) {
      TF_RETURN_IF_ERROR(
          HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
    }
    return OkStatus();
  }

 private:
  bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
    if (n > node.input_size()) return false;
    absl::flat_hash_set<string> unique_inputs;
    const int start = node.op() == "Concat" ? 1 : 0;
    const int end = start + n;
    for (int i = start; i < end; ++i) {
      unique_inputs.insert(node.input(i));
    }
    int unique_input_size = unique_inputs.size();
    return unique_input_size == n;
  }

  // Returns the length of the common chain of unary ops that can be hoisted
  // to the other side of concat or split.
  Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
                                ChainLinkSet* tails,
                                std::set<string>* ctrl_inputs) const {
    *prefix_length = 0;
    // Follow the chains starting at each concat input or split output as long
    // as all the following conditions hold:
    //   1. The ops in all chains are the same.
    //   2. The ops are unary elementwise ops.
    //   3. The op output has only a single consumer (concat only).
    ChainLinkSet cur_tails;
    TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
    if (cur_tails.size() < 2) {
      return OkStatus();
    }
    ctrl_inputs->clear();
    bool stop = false;
    while (!stop && !cur_tails.empty() &&
           OpsAreSafeToHoist(root_node, cur_tails)) {
      // We found one more link that can be hoisted.
      ++(*prefix_length);
      tails->swap(cur_tails);
      GatherControlInputs(ctrl_inputs, *tails);

      // Advance tail pointers to the next level.
      TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
    }
    return OkStatus();
  }

  // Hoists the chains to the other side of concat or split and attaches the
  // control inputs gathered from them to the concat or split node.
  Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                           std::set<string>* ctrl_inputs, NodeDef* root_node) {
    VLOG(3) << "Hoist unary op chain:"
            << " root=" << root_node->DebugString()
            << " prefix_length=" << prefix_length << " ctrl_inputs=["
            << absl::StrJoin(*ctrl_inputs, ", ") << "]";

    if (tails.empty()) {
      return OkStatus();
    }
    AddToOptimizationQueue(root_node);
    optimized_nodes_.insert(root_node->name());
    if (node_is_concat_) {
      AddControlInputs(ctrl_inputs, root_node);
      return HoistChainForConcat(prefix_length, tails, root_node);
    } else {
      return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
    }
  }

  void GatherControlInputs(std::set<string>* ctrl_inputs,
                           const ChainLinkSet& ops) const {
    for (const auto& link : ops) {
      const NodeDef* node = link.node;
      for (int i = node->input_size() - 1; i >= 0; --i) {
        const string& input = node->input(i);
        if (!IsControlInput(input)) break;
        ctrl_inputs->insert(input);
      }
    }
  }

  void AddControlInputs(std::set<string>* new_ctrl_inputs,
                        NodeDef* node) const {
    for (int i = node->input_size() - 1; i >= 0; --i) {
      const string& existing_input = node->input(i);
      if (!IsControlInput(existing_input)) break;
      new_ctrl_inputs->erase(existing_input);
    }
    for (const string& new_input : *new_ctrl_inputs) {
      ctx().node_map->AddOutput(NodeName(new_input), node->name());
      node->add_input(new_input);
    }
  }

  Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
    if (node_is_concat_) {
      // Handle concat nodes by looking backwards in the graph.
      TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
      const int n = node.attr().at("N").i();
      const int start = node.op() == "Concat" ? 1 : 0;
      const int end = start + n;
      if (end > node.input_size()) {
        return errors::FailedPrecondition("Got attr N=", n,
                                          " without enough inputs.");
      }
      // Set up tail pointers to point to the immediate inputs to Concat.
      for (int input_port = start; input_port < end; ++input_port) {
        if (IsControlInput(node.input(input_port))) {
          return errors::FailedPrecondition(
              "Got control input ", node.input(input_port),
              " where normal input was expected.");
        }
        NodeDef* tail;
        TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
        tails->insert(ChainLink(tail, input_port));
      }
      return OkStatus();
    } else {
      // Handle split nodes by looking forwards in the graph.
      const auto& outputs = ctx().node_map->GetOutputs(node.name());
      for (NodeDef* output : outputs) {
        if (output->input_size() == 0 || IsControlInput(output->input(0))) {
          continue;
        }
        TensorId tensor_id = ParseTensorName(output->input(0));
        if (tensor_id.node() == node.name()) {
          tails->insert(ChainLink(output, tensor_id.index()));
        } else {
          // This output node has a non-control input other than the split
          // node, abort.
          tails->clear();
          return OkStatus();
        }
      }
    }
    return OkStatus();
  }

  bool OpsAreSafeToHoist(const NodeDef& root_node,
                         const ChainLinkSet& ops) const {
    if (ops.empty()) return true;
    const NodeDef* op0 = ops.begin()->node;
    if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
    for (const auto& link : ops) {
      const NodeDef* op = link.node;
      if (op->device() != root_node.device() || op->op() != op0->op() ||
          IsInPreserveSet(*op)) {
        return false;
      }
      if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
        // TODO(rmlarsen): Allow outgoing control edges.
        return false;
      }
      // Do not hoist Relu if it can be fused with its predecessors. This is
      // important because the remapper runs after the arithmetic optimizer.
      if (IsRelu(*op) || IsRelu6(*op)) {
        NodeDef* operand = nullptr;
        if (!GetInputNode(op->input(0), &operand).ok()) {
          return false;
        }
        if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
          return false;
        }
      }
    }
    return true;
  }

  Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
                      bool* stop) const {
    *stop = true;
    new_tails->clear();
    for (const auto& link : tails) {
      const NodeDef* tail = link.node;
      if (node_is_concat_) {
        if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
          return OkStatus();
        }
        NodeDef* new_tail;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
        // Remember original port.
        new_tails->insert(ChainLink(new_tail, link.port_origin));
      } else {
        for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
          const TensorId tensor = ParseTensorName(new_tail->input(0));
          if (tensor.node() != tail->name()) {
            return OkStatus();
          }
          // Skip control outputs.
          if (tensor.index() >= 0) {
            // Remember original port.
            new_tails->insert(ChainLink(new_tail, link.port_origin));
          }
        }
      }
    }
    *stop = false;
    return OkStatus();
  }

  Status HoistChainForConcat(const int prefix_length,
                             const ChainLinkSet& tails,
                             NodeDef* concat_node) {
    const string& concat_name = concat_node->name();
    const int first_input = concat_node->op() == "Concat" ? 1 : 0;
    for (const auto& link : tails) {
      NodeDef* tail = CHECK_NOTNULL(link.node);
      const int concat_port = link.port_origin;
      CHECK_GE(concat_port, 0);
      CHECK_LT(concat_port, concat_node->input_size());
      const string concat_input = concat_node->input(concat_port);
      // Hook the node following tail directly into the concat node.
      const string tail_input = tail->input(0);
      concat_node->set_input(concat_port, tail_input);
      ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);

      if (concat_port == first_input) {
        // Update the consumers of concat to consume the end of the chain
        // instead.
        TF_RETURN_IF_ERROR(UpdateConsumers(concat_node, concat_input));
        // Reuse nodes in the first chain to process output of concat.
        tail->set_input(0, concat_name);
        ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
      }
    }
    return OkStatus();
  }

  Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
                            std::set<string>* ctrl_inputs,
                            NodeDef* split_node) {
    // Create a new chain before the split node to process the input tensor.
    const string& split_name = split_node->name();
    auto root_scope_and_name = ParseNodeScopeAndName(split_name);

    // We use the first tail node in the set as a template to get the list of
    // ops to apply (starting from the end).
    NodeDef* cur_tail = tails.begin()->node;
    NodeDef* cur_copy = AddCopyNode(
        OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
    cur_copy->clear_input();

    // Update the split to take its input from the tail of the new chain.
    const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
    const string orig_input = split_node->input(value_slot);
    split_node->set_input(value_slot, cur_copy->name());
    ctx().node_map->UpdateInput(split_node->name(), orig_input,
                                cur_copy->name());
    TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));

    // Now walk backwards creating the rest of the chain.
    while (cur_tail != split_node) {
      NodeDef* new_copy = AddCopyNode(
          OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
      new_copy->clear_input();
      cur_copy->add_input(new_copy->name());
      ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
      cur_copy = new_copy;
      TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
    }
    // Connect the original input to the head of the new chain.
    cur_copy->add_input(orig_input);
    ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
                                 cur_copy->name());
    // Make sure all the control inputs are satisfied before running the first
    // node in the new chain.
    AddControlInputs(ctrl_inputs, cur_copy);

    // Connect all consumers of the tail nodes directly to the
    // output port of Split from which the chain started.
    for (const auto& link : tails) {
      TF_RETURN_IF_ERROR(UpdateConsumers(
          link.node,
          link.port_origin == 0
              ? split_name
              : strings::StrCat(split_name, ":", link.port_origin)));
    }
    return OkStatus();
  }

  bool IsAlreadyOptimized(const NodeDef& node) const {
    return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
  }

 private:
  bool node_is_concat_;
  std::unordered_set<string> optimized_nodes_;
};

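// Remove one of a pair of stacked idempotent ops, i.e. rewrite
//   f(f(x)) => f(x)
// when both nodes run the same op on the same device.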
class RemoveIdempotentStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
                                 const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
  ~RemoveIdempotentStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return node->input_size() == 1 && IsIdempotent(*node) &&
           !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (input->op() == node->op() && input->device() == node->device()) {
      *simplified_node_name = node->input(0);
    }
    return OkStatus();
  }
};

// Performs the conversion:
//   Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
// TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
 public:
  explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
                                  const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
  ~SqrtDivToRsqrtMulStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    // Note: DivNoNan is excluded because rewriting div_no_nan(a, sqrt(b)) to
    // mul_no_nan(a, rsqrt(b)) would produce a * Inf instead of 0 for b == 0.
    return IsAnyDiv(*node) && !IsDivNoNan(*node) && !IsFloorDiv(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    // Optimize only if divisor is a Sqrt whose output is not being consumed
    // elsewhere.
    if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
        (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
      if (IsXdivy(*node)) {
        // xdivy(a, sqrt(b)) => mul_no_nan(rsqrt(b), a)
        node->set_op("MulNoNan");
        node->mutable_input()->SwapElements(0, 1);
      } else {
        // div(a, sqrt(b)) => mul(a, rsqrt(b))
        node->set_op("Mul");
      }
      y->set_op("Rsqrt");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    }
    return OkStatus();
  }
};

// Performs the following conversion for real types:
//   Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
 public:
  explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
  ~FuseSquaredDiffStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsSquare(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
    // Optimize only if base is a Sub whose output is not being consumed
    // elsewhere.
    if (IsSub(*b) && !IsInPreserveSet(*b) &&
        (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
      // For complex types, SquaredDifference computes conj(x-y)*(x-y), so
      // this rewrite is invalid.
      const DataType type = GetDataTypeFromAttr(*b, "T");
      if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128)) return OkStatus();
      node->set_op("Identity");
      b->set_op("SquaredDifference");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(b);
    }
    return OkStatus();
  }
};

// Performs the conversion:
//   Log(Softmax(x)) => LogSoftmax(x)
class LogSoftmaxStage : public ArithmeticOptimizerStage {
 public:
  explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
  ~LogSoftmaxStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* x;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    // Optimize only if arg is a Softmax whose output is not being consumed
    // elsewhere.
    if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
        (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
      // Log(Softmax(x)) => LogSoftmax(Identity(x))
      node->set_op("LogSoftmax");
      x->set_op("Identity");
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(x);
    }
    return OkStatus();
  }
};

// Bypass redundant reshape nodes:
//
//   Reshape                    Reshape  <-+
//      ^                                  |
//      |                                  |
//   Reshape       becomes      Reshape    |
//      ^                                  |
//      |                                  |
//    input                      input  ---+
//
// Additionally, Reshape and BroadcastTo nodes where the
// input and target shapes are equal are bypassed.
//
class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
 public:
  explicit RemoveRedundantReshapeOrBroadcastTo(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
                                 ctx_ext) {}
  ~RemoveRedundantReshapeOrBroadcastTo() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsReshape(*node) || IsBroadcastTo(*node);
  }

  // TODO(rmlarsen): Handle unary ops with multiple outputs.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. If the reshape is a no-op, forward its input to its consumers, unless
    // it anchors a control dependency since we want to make sure that control
    // dependency is triggered.
    if (!IsInPreserveSet(*node) && InputMatchesTargetShape(*node) &&
        !HasControlInputs(*node)) {
      *simplified_node_name = node->input(0);
      return OkStatus();
    }

    // 2. Bypass reshape followed by reshape, possibly separated by a simple
    // chain of unary elementwise ops that are not outputs.
    if (IsReshape(*node)) {
      bool skip = false;
      gtl::InlinedVector<const NodeDef*, 4> nodes_in_chain;
      const auto predicate_fn = [this, node, &skip,
                                 &nodes_in_chain](const NodeDef& input) {
        nodes_in_chain.push_back(&input);
        if ((input.name() != node->name() &&
             NumNonControlOutputs(input, *ctx().node_map) > 1) ||
            IsInPreserveSet(input) || ModifiesFrameInfo(input)) {
          skip = true;
          return false;
        }
        return IsUnaryElementWise(input);
      };

      // Walk up the input chain until we find a node that is not unary
      // element-wise. If it is another Reshape node, we can bypass it.
      NodeDef* tail =
          GetTailOfChain(*node, *ctx().node_map,
                         /*follow_control_input*/ false, predicate_fn);

      if (!skip && tail != nullptr && !IsInPreserveSet(*tail)) {
        NodeDef* reshape_to_bypass;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &reshape_to_bypass));
        if (reshape_to_bypass == nullptr ||
            (!IsReshape(*reshape_to_bypass) ||
             NumNonControlOutputs(*reshape_to_bypass, *ctx().node_map) > 1 ||
             IsInPreserveSet(*reshape_to_bypass))) {
          return OkStatus();
        }
        // Clear invalid shape inference results of the nodes in the chain.
        for (const NodeDef* node_in_chain : nodes_in_chain) {
          ctx().graph_properties->ClearInputProperties(node_in_chain->name());
          if (node_in_chain != node) {
            ctx().graph_properties->ClearOutputProperties(
                node_in_chain->name());
          }
        }
        // We now have
        //   reshape_to_bypass -> tail -> ... -> node
        // where tail may be equal to node.
        TF_RETURN_IF_ERROR(
            UpdateConsumers(reshape_to_bypass, reshape_to_bypass->input(0)));
        ForwardControlDependencies(tail, {reshape_to_bypass});
        // Change the bypassed reshape to NoOp.
        ReplaceWithNoOp(reshape_to_bypass, ctx());
        *simplified_node_name = node->name();
        return OkStatus();
      }
    }

    return OkStatus();
  }

 private:
  // Returns whether `reshape` is an identity op.
  bool InputMatchesTargetShape(const NodeDef& reshape) {
    const OpInfo::TensorProperties* reshape_props;
    const OpInfo::TensorProperties* input_props;
    if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
        !GetTensorProperties(reshape.input(0), &input_props).ok()) {
      return false;
    }

    return ShapesSymbolicallyEqual(input_props->shape(),
                                   reshape_props->shape());
  }
};

// Reorder casting and value-preserving ops if beneficial.
//
// Original motivation: A common pattern after the layout optimizer is
// casting a uint8 NHWC image to float before transposing it to NCHW. It
// is beneficial to reorder the cast and the transpose to make the transpose
// process a smaller amount of data. More generally, this optimization converts
//   Op(Cast(tensor, dst_type))
// to
//   Cast(Op(tensor), dst_type)
// when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
// op, i.e. an op that only reorders the elements in its first input.
// Similarly, this optimization converts
//   Cast(Op(tensor), dst_type)
// to
//   Op(Cast(tensor, dst_type))
// when sizeof(tensor.type) > sizeof(dst_type).
//
class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
 public:
  explicit ReorderCastLikeAndValuePreserving(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
                                 ctx_ext) {}
  ~ReorderCastLikeAndValuePreserving() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsValuePreserving(*node) || IsCastLike(*node)) &&
           !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
           !IsControlFlow(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
    NodeDef* producer;

    if (consumer->input_size() < 1) {
      return errors::FailedPrecondition("Node ", consumer->name(),
                                        " lacks inputs");
    }

    TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
    const bool producer_is_cast = IsCastLike(*producer);
    const bool can_optimize =
        !IsCheckNumerics(*producer) &&
        ((producer_is_cast && IsValuePreserving(*consumer)) ||
         (IsValuePreserving(*producer) && IsCastLike(*consumer)));
    if (!can_optimize || IsControlFlow(*producer) ||
        IsInPreserveSet(*producer) ||
        producer->device() != consumer->device()) {
      return OkStatus();
    }

    const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
    const OpDef* cast_like_op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
                                                         &cast_like_op_def));
    DataType cast_src_type;
    TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                        &cast_src_type));
    DataType cast_dst_type;
    TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
                                         &cast_dst_type));
    if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
      return OkStatus();
    } else if (producer_is_cast &&
               DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
      return OkStatus();
    } else if (!producer_is_cast &&
               DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
      return OkStatus();
    }

    // Check that nodes were not already optimized.
    const string optimized_producer_name = OptimizedNodeName(
        ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
    const string optimized_consumer_name = OptimizedNodeName(
        ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
    const bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_consumer_name) ||
        ctx().node_map->NodeExists(optimized_producer_name);
    if (is_already_optimized) {
      return OkStatus();
    }

    // Add copies of consumer and producer in reverse order.
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
    // Create new producer node.
    NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
    new_producer->set_input(0, producer->input(0));
    ctx().node_map->AddOutput(input->name(), new_producer->name());

    // Create new consumer node.
    NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
    new_consumer->set_input(0, new_producer->name());

    NodeDef* new_value_preserving =
        producer_is_cast ? new_producer : new_consumer;
    const DataType new_input_type =
        producer_is_cast ? cast_src_type : cast_dst_type;
    // Update the input type of the value-preserving node. The input and
    // output types of the cast-like nodes remain the same.
    TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
    // Make sure there is a kernel registered for the value preserving op
    // with the new input type.
    TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
    ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());

    AddToOptimizationQueue(new_producer);
    *simplified_node_name = new_consumer->name();

    return OkStatus();
  }

 private:
  // Sets the type of the first input to dtype.
  Status SetInputType(DataType dtype, NodeDef* node) {
    const OpDef* op_def = nullptr;
    TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
    const OpDef::ArgDef& input_arg = op_def->input_arg(0);
    const string& type_attr_name = input_arg.type_attr();
    if (type_attr_name.empty()) {
      if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
        return errors::InvalidArgument("Could not set input type of ",
                                       node->op(), " op to ",
                                       DataTypeString(dtype));
      } else {
        // Op has fixed input type that already matches dtype.
        return OkStatus();
      }
    }
    SetDataTypeToAttr(dtype, type_attr_name, node);
    return OkStatus();
  }

  // This optimization can be dangerous on devices other than CPU and
  // GPU. The transpose might not be implemented for image.type, or
  // might be slower with image.type than with cast_dst_type.
  bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
    using absl::StrContains;

    string task;
    string device;

    return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
           (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
  }

  bool IsFixedSizeType(DataType dtype) {
    return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
           !kQuantizedTypes.Contains(dtype);
  }
};

// Fold a multiply of a scalar into the following convolution. This folding
// can jump across nodes that merely reorder data (such as reshape and
// transpose). For example, we can optimize
//
//        Conv2D                          Conv2D
//       /      \                        /      \
//   Transpose  weights*    ->      Transpose    Mul
//      |                               |       /   \
//     Mul                              |   weights  scale
//    /   \                             |
//  input  scale**                    input
//
//  *) weights must be a const
// **) scale must be a const scalar
//
// When `weights` and `scale` are constant, `Mul` in the optimized graph can be
// constant-folded; also, weights tend to be smaller than the activations.
//
// TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
// Conv?DBackpropInput.
class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
 public:
  explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
  ~FoldMultiplyIntoConv() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConv2D(*node) || IsConv3D(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
#define TF_RETURN_IF_TRUE(...) \
  if ((__VA_ARGS__)) return OkStatus()

    NodeDef* conv = node;

    NodeDef* weights;
    TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));

    // Fold the multiply to conv only when the weights are constant, so the
    // multiply can be constant-folded.
    //
    // TODO(jingyue): When the weights aren't constant, this should also help
    // performance a bit and memory usage a lot, since the weights tend to be
    // smaller than the activations.
    TF_RETURN_IF_TRUE(!IsConstant(*weights));

    // Verify that this node was not already optimized.
    const string scaled_weights_node_name =
        OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
                          strings::StrCat("scaled", "_", conv->name()));

    TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));

    // Find the tail of value preserving chain entering the Conv node.
    NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
                                                  *ctx().nodes_to_preserve);

    NodeDef* source;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));

    // Check that the value preserving chain is the only consumer of the Mul
    // output.
    TF_RETURN_IF_TRUE(!IsAnyMul(*source));
    TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
    // And that Mul is not in the preserve set.
    TF_RETURN_IF_TRUE(IsInPreserveSet(*source));

    const NodeDef* mul = source;
    int input_idx = 0;
    int scale_idx = 1;
    NodeDef* scale;  // scalar multiplier for the input tensor
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
    TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
    if (!IsConstant(*scale) && IsConstant(*input)) {
      VLOG(3) << "Swapped inputs to mul";
      std::swap(scale_idx, input_idx);
      std::swap(scale, input);
    }
    TF_RETURN_IF_TRUE(!IsConstant(*scale));

    // Check that one of the inputs to mul is a constant scalar.
    const TensorProto& scale_tensor = scale->attr().at("value").tensor();
    bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
                             scale_tensor.tensor_shape().dim_size() == 0;
    TF_RETURN_IF_TRUE(!scale_is_a_scalar);

    // Check that 'scale * weights' can be const folded.
    TF_RETURN_IF_TRUE(!IsConstant(*scale));
    TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
    TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
    TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
                      weights->attr().at("dtype").type());

    // At this point all preconditions are met, and we safely do the rewrite.
    VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
            << " mul=" << mul->name() << " weights=" << weights->name();

    // Create new node `scaled_weights`.
    NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
    scaled_weights->set_op(source->op());
    scaled_weights->set_device(weights->device());
    (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
    AddToOptimizationQueue(scaled_weights);

    // Link in its inputs.
    scaled_weights->add_input(conv->input(1));
    ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
    scaled_weights->add_input(mul->input(scale_idx));
    ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
    ForwardControlDependencies(scaled_weights, {source});

    // Update `conv`'s weights to `scaled_weights`.
    conv->set_input(1, scaled_weights->name());
    ctx().node_map->UpdateInput(conv->name(), weights->name(),
                                scaled_weights->name());
    AddToOptimizationQueue(conv);

    // Update `tail` node to bypass `mul` because it's folded to the weights.
    tail->set_input(0, mul->input(input_idx));
    ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
    AddToOptimizationQueue(tail);
    *simplified_node_name = conv->name();

    return OkStatus();
#undef TF_RETURN_IF_TRUE
  }
};

// Fold Transpose into matrix multiplication.
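// For example, MatMul(Transpose(x), y) becomes MatMul(x, y) with the
// transpose_a attribute flipped, and BatchMatMul(ConjugateTranspose(x), y)
// becomes BatchMatMul(x, y) with the adj_x attribute flipped.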
class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
 public:
  explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
  ~FoldTransposeIntoMatMul() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(matmul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus();

    NodeDef* a;
    NodeDef* b;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));

    bool is_complex = false;
    if (node->op() != "SparseMatMul") {
      const DataType type = GetDataTypeFromAttr(*node, "T");
      is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
    }

    const std::set<string> foldable_transpose_ops =
        !is_complex
            ? std::set<string>{"ConjugateTranspose", "Transpose"}
            : (IsAnyBatchMatMul(*node) ? std::set<string>{"ConjugateTranspose"}
                                       : std::set<string>{"Transpose"});

    const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*a, ctx().node_map);
    const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
                               IsInnerMatrixTransposeNode(*b, ctx().node_map);
    if (!a_is_foldable && !b_is_foldable) return OkStatus();

    NodeDef* new_op = AddCopyNode(optimized_node_name, node);

    if (a_is_foldable) {
      const string attr_a = IsAnyBatchMatMul(*node) ? "adj_x" : "transpose_a";
      FlipBooleanAttr(attr_a, new_op);
      new_op->set_input(0, a->input(0));
      ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
    } else {
      ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
    }

    if (b_is_foldable) {
      const string attr_b = IsAnyBatchMatMul(*node) ? "adj_y" : "transpose_b";
      FlipBooleanAttr(attr_b, new_op);
      new_op->set_input(1, b->input(0));
      ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
    } else {
      ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
    }

    std::vector<const NodeDef*> deps_to_forward = {node};
    if (a_is_foldable) deps_to_forward.push_back(a);
    if (b_is_foldable) deps_to_forward.push_back(b);
    ForwardControlDependencies(new_op, deps_to_forward);
    *simplified_node_name = new_op->name();

    return OkStatus();
  }

 private:
  void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
    const bool old_value =
        !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
    (*node->mutable_attr())[attr_name].set_b(!old_value);
  }

  template <typename T>
  bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
    const T n = perm.size();
    if (n < 2) {
      return false;
    }
    for (T i = 0; i < n - 2; ++i) {
      if (perm[i] != i) {
        return false;
      }
    }
    return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
  }

  bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
                                  const NodeMap* node_map) {
    if (transpose_node.op() != "Transpose" &&
        transpose_node.op() != "ConjugateTranspose") {
      return false;
    }
    const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
    std::vector<int> perm32;
    if (ValuesFromConstNode(*perm_node, &perm32)) {
      return IsInnerMatrixTranspose(perm32);
    }
    std::vector<int64_t> perm64;
    if (ValuesFromConstNode(*perm_node, &perm64)) {
      return IsInnerMatrixTranspose(perm64);
    }
    return false;
  }
};

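// Fuse a Conj with an adjacent Transpose or ConjugateTranspose by flipping
// the type of the transpose op, e.g.
//   Conj(Transpose(x)) => ConjugateTranspose(x)
//   Conj(ConjugateTranspose(x)) => Transpose(x)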
class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
 public:
  explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
  ~FoldConjugateIntoTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(scope_and_name);
    if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus();

    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));

    const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
    const NodeDef* conj_op = node->op() == "Conj" ? node : input;

    if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
        IsConj(*conj_op)) {
      NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);

      // Flip the type of transpose op to absorb the conjugation.
      new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
                                                       : "Transpose");
      new_op->set_input(0, input->input(0));
      ctx().node_map->UpdateInput(new_op->name(), node->name(),
                                  input->input(0));
      ForwardControlDependencies(new_op, {node, input});
      *simplified_node_name = new_op->name();
    }

    return OkStatus();
  }
};

// Replace a Mul node whose two inputs are identical with a Square.
class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
  ~ReplaceMulWithSquare() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!node || node->input_size() < 2) {
      // Invalid node.
      return false;
    }

    return IsAnyMul(*node) && node->input(0) == node->input(1);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
    const string optimized_node_name = OptimizedNodeName(mul);
    if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus();

    const DataType type = GetDataTypeFromAttr(*node, "T");
    bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);

    if (!is_complex || NodeIsOnCpu(*node)) {
      NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
      new_square_node->set_op("Square");
      for (int i = 1; i < new_square_node->input_size(); ++i) {
        new_square_node->set_input(i - 1, new_square_node->input(i));
      }
      new_square_node->mutable_input()->RemoveLast();
      for (const string& input : new_square_node->input()) {
        ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
      }
      *simplified_node_name = new_square_node->name();
    }

    return OkStatus();
  }
};

// Replace a combination of Mul with broadcasting by Tile. E.g. replace
//
//   input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output
//   Ones (1x22x2x48x2x64)  -^
//
// with
//
//   input -> Tile(1x22x2x48x2x64) -> output
class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithBroadcastByTile(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx,
                                 ctx_ext) {}
  ~ReplaceMulWithBroadcastByTile() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsMul(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef *input, *ones;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
    if (IsInPreserveSet(*node) || IsInPreserveSet(*input) ||
        IsInPreserveSet(*ones)) {
      return OkStatus();
    }

    // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc
    if (IsConstant(*input) || !IsOnes(*ones)) return OkStatus();

    // Avoid optimizing the same node twice.
    const NodeScopeAndName scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile");
    const string const_node_name = OptimizedNodeName(scope_and_name, "Const");
    if (ctx().node_map->NodeExists(tile_node_name) ||
        ctx().node_map->NodeExists(const_node_name)) {
      return OkStatus();
    }

    const std::vector<OpInfo::TensorProperties>& props =
        ctx().graph_properties->GetInputProperties(node->name());
    if (props.size() != 2) return OkStatus();

    // Ignore ops where the shape doesn't change.
    const TensorShapeProto& input_shape = props[0].shape();
    const TensorShapeProto& ones_shape = props[1].shape();
    TensorShapeProto output_shape;
    if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) {
      return OkStatus();
    }
    if (ShapesSymbolicallyEqual(input_shape, output_shape)) {
      return OkStatus();
    }

    // All inputs must have same input/output dimensions.
    if (input_shape.dim_size() != output_shape.dim_size() ||
        ones_shape.dim_size() != output_shape.dim_size())
      return OkStatus();

    // At this point all preconditions are met. Can proceed with rewrite.
    VLOG(3) << "Simplify multiply with all ones input: node=" << node->name()
            << "@" << output_shape << " ones=" << ones->name() << "@"
            << ones_shape << " input=" << input->name() << "@" << input_shape;

    // 1. Create constant node with correct tile multiples.
    Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()}));
    for (int i = 0; i < output_shape.dim_size(); ++i) {
      int64_t size = output_shape.dim(i).size() / input_shape.dim(i).size();
      if (TF_PREDICT_FALSE(size >= INT_MAX)) {
        return Status(error::OUT_OF_RANGE, "int32 overflow");
      }
      multiples.flat<int32>()(i) = static_cast<int32>(size);
    }

    NodeDef* const_node = AddEmptyNode(const_node_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        const_node->name(), TensorValue(&multiples), const_node));
    const_node->set_device(node->device());
    ForwardControlDependencies(const_node, {ones});
    AddToOptimizationQueue(const_node);

    // 2. Replace multiply node with Tile(Const, input).
    const DataType type = GetDataTypeFromAttr(*node, "T");
    NodeDef* tile_node = AddEmptyNode(tile_node_name);
    tile_node->set_op("Tile");
    tile_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", tile_node);
    SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node);
    tile_node->add_input(input->name());
    tile_node->add_input(const_node->name());

    ForwardControlDependencies(tile_node, {node});
    *simplified_node_name = tile_node->name();

    return OkStatus();
  }

 protected:
  bool IsOnes(const NodeDef& node) const {
    if (!IsReallyConstant(node)) return false;
    if (node.attr().at("dtype").type() != DT_FLOAT) return false;

    Tensor tensor;
    if (!tensor.FromProto(node.attr().at("value").tensor())) {
      return false;
    }

    auto values = tensor.flat<float>();
    for (int i = 0; i < tensor.NumElements(); ++i) {
      if (values(i) != 1.0f) {
        return false;
      }
    }

    return true;
  }
};

// Image upsampling often produces an unnecessary reshape that is difficult to
// eliminate in other stages. This stage reduces the number of dimensions
// involved, allowing the reshape to be removed.
//
// For example, given
//   B,W,H,C -> Reshape(B,W,1,H,1,C) -> Tile(1,1,2,1,2,1) -> Reshape(B,2W,2H,C)
// this pass converts the sequence to
//   B,W,H,C -> Reshape(B,W,H,C) -> Tile(1,1,2,2) -> Reshape(B,2W,2H,C)
//
// The first reshape is now redundant and can be removed in a later pass.
//
// Note: This only optimizes the simple (but extremely common) case of 2D
// upsampling.
//
// TODO(kkiningh): Generalize to more complex upsampling patterns.
class ReduceUpsamplingDims : public ArithmeticOptimizerStage {
 public:
  explicit ReduceUpsamplingDims(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReduceUpsamplingDims", ctx, ctx_ext) {}
  ~ReduceUpsamplingDims() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsReshape(*node) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* tile;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &tile));
    if (!IsTile(*tile) || IsInPreserveSet(*tile)) {
      return OkStatus();
    }

    if (NumNonControlOutputs(*tile, *ctx().node_map) != 1) {
      // Optimization is only worthwhile when there is a single output from
      // Tile. Otherwise, we need to insert additional Reshape ops that can't
      // be easily removed.
      return OkStatus();
    }

    NodeDef* reshape;
    TF_RETURN_IF_ERROR(GetInputNode(tile->input(0), &reshape));
    if (!IsReshape(*reshape) || IsInPreserveSet(*reshape)) {
      return OkStatus();
    }

    NodeDef* multiples;
    TF_RETURN_IF_ERROR(GetInputNode(tile->input(1), &multiples));

    NodeDef* shape;
    TF_RETURN_IF_ERROR(GetInputNode(reshape->input(1), &shape));

    // Avoid optimizing the same nodes twice.
    const NodeScopeAndName scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string new_reshape_name =
        OptimizedNodeName(scope_and_name, "Reshape");
    const string new_tile_name = OptimizedNodeName(scope_and_name, "Tile");
    const string new_multiples_name =
        OptimizedNodeName(scope_and_name, "Multiples");
    const string new_shape_name = OptimizedNodeName(scope_and_name, "Shape");
    if (ctx().node_map->NodeExists(new_reshape_name) ||
        ctx().node_map->NodeExists(new_tile_name) ||
        ctx().node_map->NodeExists(new_shape_name) ||
        ctx().node_map->NodeExists(new_multiples_name)) {
      return OkStatus();
    }

2747 // Compuate updated multiples/shape values.
2748 AttrValue new_multiples_attr;
2749 if (!CreateUpdatedMultiplesProto(multiples,
2750 new_multiples_attr.mutable_tensor())) {
2751 return OkStatus();
2752 }
2753 AttrValue new_shape_attr;
2754 if (!CreateUpdatedShapeProto(shape, new_shape_attr.mutable_tensor())) {
2755 return OkStatus();
2756 }
2757
2758 // At this point the graph is validated and can be updated
2759 // Note: We can assume shape/multiples are DT_INT32 only at this point since
2760 // they're checked in CreateUpdated*Proto()

    // 1. Create the constant nodes used by the new Reshape/Tile nodes.
    NodeDef* new_multiples = AddEmptyNode(new_multiples_name);
    new_multiples->set_op("Const");
    SetDataTypeToAttr(DT_INT32, "dtype", new_multiples);
    new_multiples->mutable_attr()->insert({"value", new_multiples_attr});
    new_multiples->set_device(multiples->device());

    NodeDef* new_shape = AddEmptyNode(new_shape_name);
    new_shape->set_op("Const");
    SetDataTypeToAttr(DT_INT32, "dtype", new_shape);
    new_shape->mutable_attr()->insert({"value", new_shape_attr});
    new_shape->set_device(shape->device());

    // 2. Create the new Reshape/Tile nodes.
    NodeDef* new_reshape = AddEmptyNode(new_reshape_name);
    CopyReshapeWithInput(reshape, new_reshape, /*input=*/reshape->input(0),
                         /*shape=*/new_shape->name());
    NodeDef* new_tile = AddEmptyNode(new_tile_name);
    CopyTileWithInput(tile, new_tile, /*input=*/new_reshape->name(),
                      /*multiples=*/new_multiples->name());

    // 3. Update the consumer of the original Tile node and forward control
    // dependencies.
    node->set_input(0, new_tile->name());
    ctx().node_map->UpdateInput(node->name(), tile->name(), new_tile->name());

    ForwardControlDependencies(new_tile, {tile});
    ForwardControlDependencies(new_multiples, {multiples});
    ForwardControlDependencies(new_reshape, {reshape});
    ForwardControlDependencies(new_shape, {shape});

    *simplified_node_name = node->name();
    return OkStatus();
  }

 private:
  bool CreateUpdatedMultiplesProto(const NodeDef* node, TensorProto* proto) {
    Tensor multiples;
    if (!GetTensorFromConstNode(node->name(), &multiples)) {
      return false;
    }

    // Dimensions should be [X, Y, N, 1, M, 1].
    if (multiples.dtype() != DT_INT32 || multiples.NumElements() != 6) {
      return false;
    }

    const auto& multiples_values = multiples.flat<int32>();
    if (multiples_values(3) != 1 || multiples_values(5) != 1) {
      return false;
    }

    // Convert to [X, Y, N, M].
    Tensor new_multiples(DT_INT32, {4});
    new_multiples.flat<int32>()(0) = multiples_values(0);
    new_multiples.flat<int32>()(1) = multiples_values(1);
    new_multiples.flat<int32>()(2) = multiples_values(2);
    new_multiples.flat<int32>()(3) = multiples_values(4);

    new_multiples.AsProtoTensorContent(proto);
    return true;
  }

  bool CreateUpdatedShapeProto(const NodeDef* node, TensorProto* proto) {
    Tensor shape;
    if (!GetTensorFromConstNode(node->name(), &shape)) {
      return false;
    }

    // Dimensions should be [B, W, 1, H, 1, C].
    if (shape.dtype() != DT_INT32 || shape.NumElements() != 6) {
      return false;
    }

    const auto& shape_values = shape.flat<int32>();
    if (shape_values(2) != 1 || shape_values(4) != 1) {
      return false;
    }

    // Convert to [B, W, H, C].
    Tensor new_shape(DT_INT32, {4});
    new_shape.flat<int32>()(0) = shape_values(0);
    new_shape.flat<int32>()(1) = shape_values(1);
    new_shape.flat<int32>()(2) = shape_values(3);
    new_shape.flat<int32>()(3) = shape_values(5);

    new_shape.AsProtoTensorContent(proto);
    return true;
  }

  void CopyReshapeWithInput(const NodeDef* reshape, NodeDef* new_reshape,
                            const string& input, const string& shape) {
    new_reshape->set_op("Reshape");
    new_reshape->set_device(reshape->device());
    SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "T"), "T", new_reshape);
    SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "Tshape"), "Tshape",
                      new_reshape);

    new_reshape->add_input(input);
    ctx().node_map->AddOutput(NodeName(input), new_reshape->name());
    new_reshape->add_input(shape);
    ctx().node_map->AddOutput(NodeName(shape), new_reshape->name());

    AddToOptimizationQueue(new_reshape);
  }

  void CopyTileWithInput(const NodeDef* tile, NodeDef* new_tile,
                         const string& input, const string& multiples) {
    new_tile->set_op("Tile");
    new_tile->set_device(tile->device());
    SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "T"), "T", new_tile);
    SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "Tmultiples"), "Tmultiples",
                      new_tile);

    new_tile->add_input(input);
    ctx().node_map->AddOutput(NodeName(input), new_tile->name());
    new_tile->add_input(multiples);
    ctx().node_map->AddOutput(NodeName(multiples), new_tile->name());

    AddToOptimizationQueue(new_tile);
  }
};

// Replace a sequence of Pack nodes with identical inputs with Tile.
// For example, given a Tensor X with shape (I,J,K), let
//   P(x, n) = Pack([x, x], axis=n)
// Then
//   P(P(X, 2), 1)
//   = Tile(Reshape(Tile(Reshape(x,
//                  [I, J, 1, K]), [1, 1, 2, 1]),
//                  [I, 1, J, 2, K]), [1, 2, 1, 1, 1])
//   = Tile(Reshape(x,
//                  [I, 1, J, 1, K]), [1, 2, 1, 2, 1])
//   = Reshape(Tile(x, [1, 2, 2]), [I, 2, J, 2, K])
//
// The outermost reshape is often redundant and can be removed in another pass.
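//
// As a concrete (illustrative) instance: for X with shape (5,3,7) and
// P = Pack([X, X], axis=2), the node Pack([P, P], axis=1) is replaced by
//   Reshape(Tile(X, multiples=[1, 2, 2]), shape=[5, 2, 3, 2, 7])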
class ReplacePackWithTileReshape : public ArithmeticOptimizerStage {
 public:
  explicit ReplacePackWithTileReshape(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplacePackWithTileReshape", ctx, ctx_ext) {}
  ~ReplacePackWithTileReshape() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsPack(*node) && NumNonControlInputs(*node) > 1 &&
           !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. Traverse the chain of Pack ops to get the original input.
    NodeDef* input = node;
    std::vector<const NodeDef*> chain;
    while (IsPack(*input) && NumNonControlInputs(*node) > 1 &&
           !IsInPreserveSet(*input)) {
      // Only Pack operations with all identical inputs are supported.
      if (!AllRegularInputsEqual(*input)) {
        break;
      }
      chain.push_back(input);
      TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &input));
    }

    // There must be at least one Pack operation (with identical inputs) in
    // the chain to consider for replacement.
    if (chain.empty()) {
      return OkStatus();
    }

    // Avoid optimizing the same node twice.
    const NodeScopeAndName node_scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string new_const_name =
        OptimizedNodeName(node_scope_and_name, "Multiples");
    const string new_tile_name = OptimizedNodeName(node_scope_and_name, "Tile");
    const string new_shape_name =
        OptimizedNodeName(node_scope_and_name, "Shape");
    const string new_reshape_name =
        OptimizedNodeName(node_scope_and_name, "Reshape");
    if (ctx().node_map->NodeExists(new_const_name) ||
        ctx().node_map->NodeExists(new_tile_name) ||
        ctx().node_map->NodeExists(new_shape_name) ||
        ctx().node_map->NodeExists(new_reshape_name)) {
      return OkStatus();
    }

    // 2. Calculate the multiples and shape tensors using the chain.
    const OpInfo::TensorProperties* input_props;
    TF_RETURN_IF_ERROR(GetTensorProperties(input->name(), &input_props));
    const TensorShapeProto& input_shape = input_props->shape();
    if (!PartialTensorShape(input_shape).IsFullyDefined()) {
      return OkStatus();
    }
    Tensor multiples(DT_INT32, TensorShape({input_shape.dim_size()}));
    TF_RETURN_IF_ERROR(CalculateMultiplesFromChain(chain, &multiples));

    const OpInfo::TensorProperties* output_props;
    TF_RETURN_IF_ERROR(GetTensorProperties(node->name(), &output_props));
    const TensorShapeProto& output_shape = output_props->shape();
    if (!PartialTensorShape(output_shape).IsFullyDefined()) {
      return OkStatus();
    }
    Tensor output_shape_tensor(DT_INT32,
                               TensorShape({output_shape.dim_size()}));
    for (int i = 0; i < output_shape.dim_size(); ++i) {
      output_shape_tensor.flat<int32>()(i) = output_shape.dim(i).size();
    }

    // 3. Create a constant node with the correct multiples value.
    NodeDef* new_const_node = AddEmptyNode(new_const_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        new_const_node->name(), TensorValue(&multiples), new_const_node));
    new_const_node->set_device(node->device());
    MaybeAddControlInput(input->name(), new_const_node, ctx().optimized_graph,
                         ctx().node_map);
    AddToOptimizationQueue(new_const_node);

    // 4. Replace the Pack node with Tile(Const(N), input).
    DataType dtype = GetDataTypeFromAttr(*node, "T");
    NodeDef* new_tile_node = AddEmptyNode(new_tile_name);
    new_tile_node->set_op("Tile");
    new_tile_node->set_device(node->device());
    SetDataTypeToAttr(dtype, "T", new_tile_node);
    SetDataTypeToAttr(DT_INT32, "Tmultiples", new_tile_node);
    new_tile_node->add_input(input->name());
    ctx().node_map->AddOutput(input->name(), new_tile_node->name());
    new_tile_node->add_input(new_const_node->name());
    ctx().node_map->AddOutput(new_const_node->name(), new_tile_node->name());

    // Tile inherits all control dependencies from the original Pack chain.
    ForwardControlDependencies(new_tile_node, chain);
    AddToOptimizationQueue(new_tile_node);

    // 5. Add a new Reshape node to preserve the existing shape.
    NodeDef* new_shape_node = AddEmptyNode(new_shape_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        new_shape_node->name(), TensorValue(&output_shape_tensor),
        new_shape_node));
    new_shape_node->set_device(node->device());
    MaybeAddControlInput(input->name(), new_shape_node, ctx().optimized_graph,
                         ctx().node_map);
    AddToOptimizationQueue(new_shape_node);

    NodeDef* new_reshape_node = AddEmptyNode(new_reshape_name);
    new_reshape_node->set_op("Reshape");
    new_reshape_node->set_device(node->device());
    SetDataTypeToAttr(dtype, "T", new_reshape_node);
    SetDataTypeToAttr(DT_INT32, "Tshape", new_reshape_node);
    new_reshape_node->add_input(new_tile_node->name());
    ctx().node_map->AddOutput(new_tile_node->name(), new_reshape_node->name());
    new_reshape_node->add_input(new_shape_node->name());
    ctx().node_map->AddOutput(new_shape_node->name(), new_reshape_node->name());

    *simplified_node_name = new_reshape_node->name();
    return OkStatus();
  }

 protected:
  Status CalculateMultiplesFromChain(const std::vector<const NodeDef*>& chain,
                                     Tensor* multiples) {
    // Keep track of how the multiples correspond to each shape dimension.
    // For example, given Stack([x, x], axis=1) with rank(x) = 3, we start with
    //   multiples=[1, 1, 1], dims=[0, 1, 2].
    // After processing the stack op:
    //   multiples=[1, 2, 1], dims=[0, 1, 1, 2].
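    // (The duplicated entry in dims records that output dimensions 1 and 2
    // both map back to multiples entry 1. Illustratively, processing a second
    // Stack([x, x], axis=1) on top of this would give
    //   multiples=[1, 4, 1], dims=[0, 1, 1, 1, 2].)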
    std::vector<int32> dims(multiples->NumElements());
    std::iota(dims.begin(), dims.end(), 0);

    for (int i = 0; i < multiples->NumElements(); ++i) {
      multiples->flat<int32>()(i) = 1;
    }

    for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
      AttrSlice attrs(**it);
      int64_t axis, n;
      TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &axis));
      TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &n));

      if (axis >= dims.size()) {
        // We don't handle the case where Pack is performed on the last axis,
        // e.g. Pack([x, x], axis=3) where rank(x) == 3.
        return Status(error::OUT_OF_RANGE, "axis value out of range of dims");
      }

      int64_t m = multiples->flat<int32>()(dims[axis]) * n;
      if (TF_PREDICT_FALSE(m > INT_MAX)) {
        return Status(error::OUT_OF_RANGE, "int32 overflow");
      }
      multiples->flat<int32>()(dims[axis]) = static_cast<int32>(m);

      // Copy the index from immediately right of the inserted axis.
      dims.insert(dims.begin() + axis, dims[axis]);
    }

    return OkStatus();
  }
};

// Simplify aggregation (e.g. AddN) nodes:
//
// 1. Discard aggregate nodes with a single input and no control dependencies.
//
// 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
//    deduping or other rewrites) so we can get rid of the sum entirely.
//
//    The expression (using AddN as an example of an aggregate op):
//      AddN(x, x, x, ..., x)
//           <-- N terms -->
//    can be rewritten to:
//      Mul(Const(N), x)
//
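// For example (illustrative), AddN(x, x, x) with T=DT_FLOAT is replaced by
// Mul(Const(3.0f), x), i.e. a single scalar multiply instead of an N-way add.
//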
class SimplifyAggregation : public ArithmeticOptimizerStage {
 public:
  explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
  ~SimplifyAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && HasRegularInputs(*node) &&
           GetDataTypeFromAttr(*node, "T") !=
               DT_VARIANT;  // TODO(b/119787146): Enable for variants.
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. Discard aggregate nodes with a single input and no control deps.
    if (node->input_size() == 1) {
      *simplified_node_name = node->input(0);
      return OkStatus();
    }

    // 2. Rewrite aggregations of N >= 2 identical terms.

    // All non-control inputs must be identical.
    bool all_equal = true;
    int num_inputs = 1;
    for (int i = 1; i < node->input_size(); ++i) {
      if (IsControlInput(node->input(i))) break;
      ++num_inputs;
      if (node->input(i) != node->input(0)) {
        all_equal = false;
        break;
      }
    }
    if (!all_equal) return OkStatus();

    // And the node must not have been optimized earlier.
    const NodeScopeAndName node_scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string optimized_const_name =
        OptimizedNodeName(node_scope_and_name, "Const");
    const string optimized_mul_name =
        OptimizedNodeName(node_scope_and_name, "Mul");

    bool is_already_optimized =
        ctx().node_map->NodeExists(optimized_const_name) ||
        ctx().node_map->NodeExists(optimized_mul_name);

    if (is_already_optimized) return OkStatus();

    // At this point all preconditions are met, and we can safely do the
    // rewrite.
    VLOG(3) << "Simplify aggregation with identical inputs: node="
            << node->name() << " num_inputs=" << num_inputs;

    // 1. Create a constant node with value N.
    const auto type = GetDataTypeFromAttr(*node, "T");
    Tensor t(type, TensorShape({}));
    Status status = SetTensorValue(type, num_inputs, &t);
    if (!status.ok()) {
      return errors::Internal("Failed to create const node: ",
                              status.error_message());
    }

    TensorValue value(&t);
    NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
    status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
                                            new_const_node);
    if (!status.ok()) {
      return errors::Internal("Failed to create const node: ",
                              status.error_message());
    }
    new_const_node->set_device(node->device());
    MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
                         ctx().optimized_graph, ctx().node_map);
    AddToOptimizationQueue(new_const_node);

    // 2. Replace the aggregate node with Mul(Const(N), x).
    NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
    new_mul_node->set_op("Mul");
    new_mul_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", new_mul_node);
    new_mul_node->add_input(new_const_node->name());
    ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
    new_mul_node->add_input(node->input(0));
    ctx().node_map->AddOutput(node->input(0), new_mul_node->name());

    ForwardControlDependencies(new_mul_node, {node});
    *simplified_node_name = new_mul_node->name();

    return OkStatus();
  }
};

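// Replaces Pow(x, c) with a cheaper op when every element of the constant
// exponent c has the same known value (see the cases handled below):
//   Pow(x, 2)    => Square(x)
//   Pow(x, 3)    => Mul(x, Square(x))  (on CPU only)
//   Pow(x, 1)    => Identity(x)        (if no broadcasting is involved)
//   Pow(x, 0.5)  => Sqrt(x)
//   Pow(x, 0)    => Const(1)           (if the output shape is fully defined)
//   Pow(x, -0.5) => Rsqrt(x)
//   Pow(x, -1)   => Reciprocal(x)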
class ConvertPowStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertPowStage(const GraphOptimizerContext& ctx,
                           const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}

  bool IsSupported(const NodeDef* node) const override {
    return IsPow(*node) &&
           ctx().graph_properties->HasOutputProperties(node->name()) &&
           ctx().graph_properties->HasInputProperties(node->name());
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    Tensor pow;
    if (!GetTensorFromConstNode(node->input(1), &pow)) return OkStatus();
    complex128 prev, curr;
    for (int i = 0; i < pow.NumElements(); ++i) {
      if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
        // The input data type is not supported by Pow. Skip.
        return OkStatus();
      }
      if (i != 0 && curr != prev) {
        // pow has different values in different elements. Skip.
        return OkStatus();
      }
      prev = curr;
    }
    NodeDef *x, *y;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
    const auto& value_props =
        ctx().graph_properties->GetInputProperties(node->name())[0];
    const TensorShapeProto& output_shape =
        ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
    if (curr == complex128(2, 0)) {
      node->set_op("Square");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(3, 0)) {
      // TODO(courbet): Use 'Cube' when it's added to TF ops.
      if (NodeIsOnCpu(*node)) {
        // We create an inner square node: inner_square = Square(x).
        const NodeScopeAndName scope_and_name =
            ParseNodeScopeAndName(node->name());
        const string inner_square_name =
            OptimizedNodeName(scope_and_name, "_inner");
        NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
        if (inner_square_node == nullptr) {
          inner_square_node = AddCopyNode(inner_square_name, node);
          inner_square_node->set_op("Square");
          inner_square_node->mutable_input()->RemoveLast();
        }
        ctx().node_map->AddOutput(x->name(), inner_square_node->name());
        // We modify `node`: node = Mul(x, inner_square).
        node->set_op("Mul");
        node->set_input(1, inner_square_node->name());
        node->add_input(AsControlDependency(y->name()));

        AddToOptimizationQueue(node);
        AddToOptimizationQueue(inner_square_node);
        AddToOptimizationQueue(y);
      }
    } else if (curr == complex128(1, 0) &&
               ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
      // Pow could be used to broadcast, so make sure the shapes of the two
      // arguments are identical before replacing Pow with Identity.
      node->set_op("Identity");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(0.5, 0)) {
      node->set_op("Sqrt");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(0, 0) &&
               ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
               PartialTensorShape(output_shape).IsFullyDefined()) {
      const auto dtype = node->attr().at("T").type();
      Tensor ones(dtype, output_shape);
      for (int i = 0; i < ones.NumElements(); ++i) {
        TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
      }
      node->set_op("Const");
      (*node->mutable_attr())["dtype"].set_type(dtype);
      node->mutable_attr()->erase("T");
      ones.AsProtoTensorContent(
          (*node->mutable_attr())["value"].mutable_tensor());
      node->set_input(0, AsControlDependency(x->name()));
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(x);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(-0.5, 0)) {
      node->set_op("Rsqrt");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    } else if (curr == complex128(-1, 0)) {
      node->set_op("Reciprocal");
      node->set_input(1, AsControlDependency(y->name()));
      AddToOptimizationQueue(node);
      AddToOptimizationQueue(y);
    }
    return OkStatus();
  }

 private:
  Status SetElementToOne(int i, Tensor* t) {
    switch (t->dtype()) {
      case DT_INT32:
        t->flat<int32>()(i) = 1;
        return OkStatus();
      case DT_INT64:
        t->flat<int64_t>()(i) = 1L;
        return OkStatus();
      case DT_FLOAT:
        t->flat<float>()(i) = 1.0f;
        return OkStatus();
      case DT_DOUBLE:
        t->flat<double>()(i) = 1.0;
        return OkStatus();
      case DT_COMPLEX64:
        t->flat<complex64>()(i) = complex64(1);
        return OkStatus();
      case DT_COMPLEX128:
        t->flat<complex128>()(i) = complex128(1);
        return OkStatus();
      default:
        return errors::InvalidArgument("Invalid data type: ", t->dtype());
    }
  }
};

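// Rewrites Log(Add(x, 1)) (and Log(Add(1, x))) as Log1p(x), which is more
// accurate for small x. The rewrite only fires when the constant addend is
// exactly 1 elementwise and broadcasting with it leaves the shape of x
// unchanged.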
class ConvertLog1pStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
  ~ConvertLog1pStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (!IsAdd(*input)) {
      return OkStatus();
    }

    if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
      return OkStatus();
    }

    bool modified = false;
    TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
    if (!modified) {
      TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
    }
    if (modified) {
      *simplified_node_name = node->name();
    }
    return OkStatus();
  }

 private:
  Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
                             bool* modified) {
    const auto& t =
        ctx().graph_properties->GetInputProperties(add_node->name())[i];
    const auto& c =
        ctx().graph_properties->GetInputProperties(add_node->name())[j];
    for (int k = 0; k < c.shape().dim_size(); ++k) {
      // Skip if the shape of c is not fully determined.
      if (c.shape().dim(k).size() < 0) {
        return OkStatus();
      }
    }
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return OkStatus();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // Skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return OkStatus();
    }
    Tensor constant;
    if (GetTensorFromConstNode(add_node->input(j), &constant)) {
      complex128 element;
      // TODO(rmlarsen): Refactor the more general IsOnes from
      // constant_folding.cc and use it here. Perhaps also convert
      // log(x - (-1)), or (preferably) add a pass to canonicalize
      // Sub(x, -1) to Add(x, 1) and Neg(-1) to 1.
      for (int k = 0; k < constant.NumElements(); ++k) {
        if (!GetElementUnexhaustive(constant, k,
                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                     DT_COMPLEX64, DT_COMPLEX128},
                                    &element)) {
          // The input data type is not supported by Log1p. Skip.
          return OkStatus();
        }
        if (element != complex128(1)) {
          // The current element is not 1. Skip.
          return OkStatus();
        }
      }
      NodeDef *x, *y;
      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
      node->set_op("Log1p");
      node->set_input(0, add_node->input(i));
      node->add_input(AsControlDependency(y->name()));
      ForwardControlDependencies(node, {add_node});

      AddToOptimizationQueue(node);
      AddToOptimizationQueue(add_node);
      AddToOptimizationQueue(x);
      AddToOptimizationQueue(y);
      *modified = true;
    }
    return OkStatus();
  }
};

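// Rewrites Sub(Exp(x), 1) as Expm1(x), the numerically stable form of
// exp(x) - 1 near x = 0. As in ConvertLog1pStage, the constant must be
// exactly 1 elementwise and must not broadcast x to a different shape.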
class ConvertExpm1Stage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
  ~ConvertExpm1Stage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsSub(*node)) return false;

    NodeDef* input;
    if (!GetInputNode(node->input(0), &input).ok()) return false;

    return IsExp(*input);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
      return OkStatus();
    }
    const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
    const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return OkStatus();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // Skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return OkStatus();
    }
    Tensor constant;
    if (!GetTensorFromConstNode(node->input(1), &constant)) return OkStatus();
    // TODO(rmlarsen): Use the more general IsOnes helper here.
    complex128 element;
    for (int k = 0; k < constant.NumElements(); ++k) {
      if (!GetElementUnexhaustive(constant, k,
                                  {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                   DT_COMPLEX64, DT_COMPLEX128},
                                  &element)) {
        // The input data type is not supported by Expm1. Skip.
        return OkStatus();
      }
      if (element != complex128(1)) {
        // The current element is not 1. Skip.
        return OkStatus();
      }
    }
    NodeDef* exp;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
    NodeDef *exp_input, *ones;
    TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
    node->set_op("Expm1");
    node->set_input(0, exp->input(0));
    node->set_input(1, AsControlDependency(ones->name()));
    ForwardControlDependencies(node, {exp});

    AddToOptimizationQueue(node);
    AddToOptimizationQueue(exp);
    AddToOptimizationQueue(exp_input);
    AddToOptimizationQueue(ones);
    *simplified_node_name = node->name();
    return OkStatus();
  }
};

// Performs conversions like:
//   Max(Sqrt(x)) => Sqrt(Max(x))
// Checks for a max/min reduction over an element-wise monotonic function,
// such as Sqrt, Sigmoid, Tanh, etc.
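// For monotonically decreasing functions the reduction flips, e.g.
//   Max(Neg(x)) => Neg(Min(x))
// and for ArgMax/ArgMin the inner function drops out entirely, e.g.
//   ArgMax(Sqrt(x)) => ArgMax(x).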
class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
 public:
  explicit OptimizeMaxOrMinOfMonotonicStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
                                 ctx_ext) {}
  ~OptimizeMaxOrMinOfMonotonicStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    // Running on (Unsorted)SegmentMax/Min can cause issues on empty segments.
    return IsMax(*node) || IsMin(*node) || IsAnyMaxPool(*node) ||
           IsArgMax(*node) || IsArgMin(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) {
      return OkStatus();
    }

    NodeDef* inner_function;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));

    NodeDef* inner_function_input = nullptr;
    if (inner_function->input_size() > 0) {
      TF_RETURN_IF_ERROR(
          GetInputNode(inner_function->input(0), &inner_function_input));
    }

    // Optimize only if:
    // 0. inner_function is not in the preserve set,
    // 1. inner_function's op is element-wise monotonic,
    // 2. inner_function's output is not being consumed elsewhere,
    // 3. inner_function is monotonically increasing if reduction_node is a
    //    pooling operation, since we don't have MinPool operations,
    // 4. inner_function is not a Relu node with an input from FusedBatchNorm
    //    or BiasAdd (this pattern will be fused later by the remapper).
    auto can_be_fused_by_remapper = [](const NodeDef& consumer,
                                       const NodeDef& producer) -> bool {
      if (IsRelu(consumer) || IsRelu6(consumer)) {
        if (IsFusedBatchNorm(producer) || IsBiasAdd(producer)) {
          return true;
        }
      }
      return false;
    };
    bool is_non_decreasing = false;
    if (!IsInPreserveSet(*inner_function) &&
        IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
        ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
        (is_non_decreasing || !IsAnyMaxPool(*reduction_node)) &&
        !can_be_fused_by_remapper(*inner_function, *inner_function_input)) {
      // Swap the first inputs of the inner function op & the reduction op.
      NodeDef* inner_input;
      TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
      reduction_node->set_input(0, inner_input->name());
      ctx().node_map->UpdateInput(reduction_node->name(),
                                  inner_function->name(), inner_input->name());
      inner_function->set_input(0, reduction_node->name());
      TF_RETURN_IF_ERROR(
          UpdateConsumers(reduction_node, inner_function->name()));
      ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
                                  reduction_node->name());
      if (!is_non_decreasing) {
        // Flip Min<->Max if the function is non-increasing, e.g.
        // Max(Neg(x)) = Neg(Min(x)).
        const string opposite = FlipMinMax(*reduction_node);
        reduction_node->set_op(opposite);
      }

      if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
        // ArgMax(Sqrt(x)) = ArgMax(x)
        inner_function->set_op("Identity");
      }

      AddToOptimizationQueue(reduction_node);
      AddToOptimizationQueue(inner_function);
      AddToOptimizationQueue(inner_input);
    }
    return OkStatus();
  }

 private:
  string FlipMinMax(const NodeDef& node) {
    const string& op = node.op();
    if (IsAnyMax(node) || IsArgMax(node)) {
      return str_util::StringReplace(op, "Max", "Min", false);
    } else {
      return str_util::StringReplace(op, "Min", "Max", false);
    }
  }
};

// Replace a chain of type- and shape-preserving unary ops with a
// '_UnaryOpsComposition' node.
// TODO(ezhulenev): It should be a part of the remapper optimizer because it
// doesn't have much to do with arithmetic (together with the
// FoldMultiplyIntoConv stage?).
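//
// For example (illustrative), a CPU-resident chain y = Sqrt(Tanh(Relu(x)))
// whose intermediates have no other consumers becomes a single node:
//   _UnaryOpsComposition(x, op_names=["Relu", "Tanh", "Sqrt"])
// so the chain can be evaluated without dispatching three separate kernels.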
class UnaryOpsComposition : public ArithmeticOptimizerStage {
 public:
  explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
    // WARN: This should be consistent with unary_ops_composition.cc.
    // clang-format off
    supported_ops_ = {// Ops defined via Eigen scalar ops.
                      {"Abs", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Acos", {DT_FLOAT, DT_DOUBLE}},
                      {"Acosh", {DT_FLOAT, DT_DOUBLE}},
                      {"Asin", {DT_FLOAT, DT_DOUBLE}},
                      {"Asinh", {DT_FLOAT, DT_DOUBLE}},
                      {"Atan", {DT_FLOAT, DT_DOUBLE}},
                      {"Atanh", {DT_FLOAT, DT_DOUBLE}},
                      {"Ceil", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cos", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cosh", {DT_FLOAT, DT_DOUBLE}},
                      {"Expm1", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Exp", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Floor", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Inv", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log1p", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Neg", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rint", {DT_FLOAT, DT_DOUBLE}},
                      {"Round", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rsqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sigmoid", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sin", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sinh", {DT_FLOAT, DT_DOUBLE}},
                      {"Sqrt", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Square", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Tan", {DT_FLOAT, DT_DOUBLE}},
                      {"Tanh", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      // Additional ops that are not part of Eigen.
                      {"Elu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu6", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Selu", {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
    // clang-format on
  }
  ~UnaryOpsComposition() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return CanOptimize(*node) &&
           // Check that this node was not already a root of a fused chain. If
           // graph optimization runs twice without pruning in between,
           // fused_nodes_ will not have this information.
           !ctx().node_map->NodeExists(OptimizedNodeName(*node));
  }

  Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
    DataType dtype = root->attr().at("T").type();

    // Keep a trace of all supported input nodes that can be fused together.
    std::vector<string> op_nodes = {root->name()};
    std::vector<string> op_names = {root->op()};

    // Check if we should follow input(0) while building the op composition.
    const auto predicate_fn = [&](const NodeDef& input) {
      if (input.name() == root->name()) return true;

      bool follow_input_node =
          dtype == GetDataTypeFromAttr(input, "T") &&
          NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
          CanOptimize(input);

      if (follow_input_node) {
        op_nodes.push_back(input.name());
        op_names.push_back(input.op());
      }

      return follow_input_node;
    };

    NodeDef* last_op = GetTailOfChain(
        *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);

    // We were not able to find a chain that can be replaced.
    if (op_names.size() == 1) return OkStatus();

    // Do not add fused nodes to any other chain.
    std::for_each(op_nodes.begin(), op_nodes.end(),
                  [this](const string& name) { AddToFusedNodes(name); });

    // Reverse the trace to get the correct composition computation order.
    std::reverse(op_names.begin(), op_names.end());

    VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
            << absl::StrJoin(op_names, ", ") << "]";

    NodeDef* composition_node = ctx().optimized_graph->add_node();
    composition_node->set_name(OptimizedNodeName(*root));
    composition_node->set_op("_UnaryOpsComposition");
    composition_node->add_input(last_op->input(0));
    composition_node->set_device(root->device());

    auto attr = composition_node->mutable_attr();
    SetAttrValue(dtype, &(*attr)["T"]);
    SetAttrValue(op_names, &(*attr)["op_names"]);

    ctx().node_map->AddNode(composition_node->name(), composition_node);
    ctx().node_map->AddOutput(NodeName(last_op->input(0)),
                              composition_node->name());

    *simplified_node_name = composition_node->name();

    return OkStatus();
  }

 private:
  bool CanOptimize(const NodeDef& node) const {
    DataType dtype = GetDataTypeFromAttr(node, "T");
    if (!IsSupported(node.op(), dtype)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    if (!NodeIsOnCpu(node)) {
      return false;
    }
    if (NodeIsAlreadyFused(node)) {
      return false;
    }
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  bool NodeIsAlreadyFused(const NodeDef& node) const {
    return fused_nodes_.count(node.name()) > 0;
  }

  string OptimizedNodeName(const NodeDef& node) const {
    return strings::StrCat(node.name(), "/unary_ops_composition");
  }

  void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }

  // Check if an op is supported by _UnaryOpsComposition for the given type.
  bool IsSupported(const string& op_name, DataType dtype) const {
    const auto it = supported_ops_.find(op_name);
    return it != supported_ops_.end() && it->second.count(dtype) > 0;
  }

  std::unordered_map<string, std::set<DataType>> supported_ops_;
  std::unordered_set<string> fused_nodes_;
};

// Replace operations of the form:
//   x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
// with
//   a_i
// when the strided slice index `i` is applied in the k'th axis.
//
// Similarly, replace operations of the form:
//   x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
//   expand_dims(a_i, axis=k)
// where the slice operator can be StridedSlice or Slice.
//
// TODO(ebrevdo): Extend to also replace operations of the form
//   concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
// with
//   a_i,
// when
//   s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
// and slicing is in the k'th axis.
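//
// For example (illustrative), with x = stack((a, b, c), axis=1):
//   x[:, 1]   => b                       (StridedSlice with shrink_axis_mask)
//   x[:, 1:2] => expand_dims(b, axis=1)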
class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
 public:
  explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
                                 ctx_ext) {}
  ~RemoveStackSliceSameAxis() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsStridedSlice(*node) || IsSlice(*node)) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // *node is a StridedSlice or Slice NodeDef.
    NodeDef* pack;

    // Get the input and see if it's a Pack op.
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
    if (!IsPack(*pack)) return OkStatus();

    bool return_early;
    PartialTensorShape pack_output_shape;
    int pack_axis;
    TF_RETURN_IF_ERROR(
        CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
    if (return_early) return OkStatus();

    int64_t slice_start_value;
    bool found;
    bool must_expand_dims;
    TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
                                    &slice_start_value, &found,
                                    &must_expand_dims));
    if (!found) return OkStatus();

    return RewriteGraph(node, pack, slice_start_value, pack_axis,
                        must_expand_dims, simplified_node_name);
  }

 protected:
  Status CheckInputs(const NodeDef* node, const NodeDef* pack,
                     PartialTensorShape* pack_output_shape, int* pack_axis,
                     bool* return_early) {
    *return_early = true;
    TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));

    *pack_axis = pack->attr().at("axis").i();
    auto slice_properties =
        ctx().graph_properties->GetInputProperties(node->name());
    if (slice_properties.empty() ||
        slice_properties[0].shape().unknown_rank()) {
      return OkStatus();
    }
    *pack_output_shape = slice_properties[0].shape();
    const int pack_output_rank = pack_output_shape->dims();
    if (*pack_axis < 0) {
      *pack_axis += pack_output_rank;
    }
    if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
      return errors::InvalidArgument(
          "Pack node (", pack->name(),
          ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
    }
    *return_early = false;
    return OkStatus();
  }

  Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
                      const PartialTensorShape& pack_output_shape,
                      int pack_axis, int64_t* slice_start_value, bool* found,
                      bool* must_expand_dims) {
    *found = false;
    if (IsSlice(*node)) {
      *must_expand_dims = true;
      return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
                                slice_start_value, found);
    } else {
      return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
                                 slice_start_value, found, must_expand_dims);
    }
  }

  Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
                            const PartialTensorShape& pack_output_shape,
                            int pack_axis, int64_t* slice_start_value,
                            bool* found) {
    NodeDef* slice_begin;
    NodeDef* slice_size;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
    for (const auto* n : {slice_begin, slice_size}) {
      if (!IsReallyConstant(*n)) return OkStatus();
    }

    Tensor slice_begin_t;
    Tensor slice_size_t;
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
    if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
      return OkStatus();
    }

    auto copy_tensor_values_to_vector =
        [node](const Tensor& t, gtl::InlinedVector<int64_t, 4>* vec) {
          if (t.dtype() == DT_INT32) {
            auto t_flat = t.flat<int32>();
            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
          } else if (t.dtype() == DT_INT64) {
            auto t_flat = t.flat<int64_t>();
            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
          } else {
            return errors::InvalidArgument("Node ", node->name(),
                                           " has invalid type for Index attr: ",
                                           DataTypeString(t.dtype()));
          }
          return OkStatus();
        };

    gtl::InlinedVector<int64_t, 4> slice_begin_vec;
    gtl::InlinedVector<int64_t, 4> slice_size_vec;
    TF_RETURN_IF_ERROR(
        copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
    TF_RETURN_IF_ERROR(
        copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));

    if (slice_begin_vec.size() != slice_size_vec.size()) {
      return errors::InvalidArgument("Node ", node->name(),
                                     " has mismatched lengths for begin (",
                                     slice_begin_vec.size(), ") and size (",
                                     slice_size_vec.size(), ") vectors.");
    }
    int slice_begin_vec_size = slice_begin_vec.size();
    if (!pack_output_shape.unknown_rank() &&
        slice_begin_vec_size != pack_output_shape.dims()) {
      return OkStatus();
    }
    if (pack_axis >= slice_begin_vec_size) {
      return errors::InvalidArgument(
          "Input to node ", node->name(), " had pack_axis ", pack_axis,
          " but rank was ", slice_begin_vec_size, ".");
    }

    *slice_start_value = slice_begin_vec[pack_axis];
    if (slice_size_vec[pack_axis] != 1) {
      // Not slicing a single value out.
      return OkStatus();
    }

    for (int i = 0; i < slice_begin_vec_size; ++i) {
      if (i != pack_axis) {
        if (slice_begin_vec[i] != 0 ||
            !(slice_size_vec[i] == -1 ||
              slice_size_vec[i] == pack_output_shape.dim_size(i))) {
          // Not slicing on the same axis as the Pack op.
          return OkStatus();
        }
      }
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", pack_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    *found = true;  // slice_start_value is valid.
    return OkStatus();
  }

  Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
                             const PartialTensorShape& pack_output_shape,
                             int pack_axis, int64_t* slice_start_value,
                             bool* found, bool* must_expand_dims) {
    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
                                "new_axis_mask", "shrink_axis_mask"}));

    const int begin_mask = node->attr().at("begin_mask").i();
    const int end_mask = node->attr().at("end_mask").i();
    const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
    const int new_axis_mask = node->attr().at("new_axis_mask").i();
    const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();

    // Check that the StridedSlice is one of these at pack_axis:
    //   [..., i, ...]
    //   [..., i:i+1, ...]
    //   [..., :1, ...]
    //   [..., -1:, ...]
    //   [..., s_{pack_axis}-1:, ...]
    NodeDef* slice_begin;
    NodeDef* slice_end;
    NodeDef* slice_strides;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));

    for (const auto* n : {slice_begin, slice_end, slice_strides}) {
      if (!IsReallyConstant(*n)) return OkStatus();
    }

    Tensor slice_begin_t;
    Tensor slice_end_t;
    Tensor slice_strides_t;

    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
    if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
      return OkStatus();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
    if (!slice_strides_t.FromProto(
            slice_strides->attr().at("value").tensor())) {
      return OkStatus();
    }
    TensorShape processing_shape;
    TensorShape final_shape;
    bool is_identity;
    bool is_simple_slice;
    bool slice_dim0;
    gtl::InlinedVector<int64_t, 4> slice_begin_vec;
    gtl::InlinedVector<int64_t, 4> slice_end_vec;
    gtl::InlinedVector<int64_t, 4> slice_strides_vec;
    TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
        &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
        &processing_shape, &final_shape, &is_identity, &is_simple_slice,
        &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));

    if (!is_simple_slice) return OkStatus();

    int begin_index = -1;
    int64_t begin_value = 0;
    for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
      const int64_t v = slice_begin_vec[i];
      if (v != 0) {
        if (begin_index != -1) {
          // At least two begin values are nonzero.
          return OkStatus();
        }
        begin_index = i;
        begin_value = v;
      }
    }

    int end_index = -1;
    int64_t end_value = 0;
    for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
      const int64_t v = slice_end_vec[i];
      if (v != pack_output_shape.dim_size(i)) {
        if (end_index != -1) {
          // At least two end values differ from the dimension size.
          return OkStatus();
        }
        end_index = i;
        end_value = v;
      }
    }

    if (begin_index == -1 && end_index == -1) return OkStatus();
    if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
      // Somehow received different axes for begin/end slicing.
      return OkStatus();
    }
    const int slice_axis = (begin_index == -1) ? end_index : begin_index;
    if (slice_axis != pack_axis) {
      // Not slicing on the same axis as the Pack op.
      return OkStatus();
    }
    *slice_start_value = (begin_index == -1) ? 0 : begin_value;
    const int64_t slice_end_value =
        (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
    if (slice_end_value != *slice_start_value + 1) {
      // Not slicing a single value out.
      return OkStatus();
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", slice_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    if (shrink_axis_mask == 0) {
      *must_expand_dims = true;
    } else if (shrink_axis_mask == (1 << slice_axis)) {
      *must_expand_dims = false;
    } else {
      // Shrinking on a different axis from the one that we are slicing on.
      return OkStatus();
    }

    *found = true;  // slice_start_value is valid.
    return OkStatus();
  }

  Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
                      int64_t slice_start_value, int pack_axis,
                      bool must_expand_dims, string* simplified_node_name) {
    const string& input_slice = pack->input(slice_start_value);

    const OpInfo::TensorProperties* input_slice_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
                                           &input_slice_properties));
    PartialTensorShape input_slice_shape(input_slice_properties->shape());

    const OpInfo::TensorProperties* output_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(
        strings::StrCat(node->name(), ":", 0), &output_properties));
    PartialTensorShape output_shape(output_properties->shape());
    NodeDef* output =
        AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
    if (!must_expand_dims) {
      output->set_op("Identity");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      output->add_input(input_slice);
    } else {
      NodeDef* axis = AddEmptyNode(
          OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
      axis->set_op("Const");
      axis->set_device(node->device());
      // We need to add a control edge from input slice to guarantee that axis
      // constant will be executed in the same frame as `input_slice`, otherwise
      // ExpandDims might have mismatched input frames.
      axis->add_input(absl::StrCat("^", ParseTensorName(input_slice).node()));
      auto axis_attr = axis->mutable_attr();
      SetDataTypeToAttr(DT_INT32, "dtype", axis);
      auto* axis_t = (*axis_attr)["value"].mutable_tensor();
      axis_t->set_dtype(DT_INT32);
      axis_t->add_int_val(pack_axis);
      AddToOptimizationQueue(axis);
      output->set_op("ExpandDims");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      SetDataTypeToAttr(DT_INT32, "Tdim", output);
      output->add_input(input_slice);
      output->add_input(axis->name());
    }

    // Copy dependencies over.
    ForwardControlDependencies(output, {node, pack});
    AddToOptimizationQueue(output);
    *simplified_node_name = output->name();

    return OkStatus();
  }
};

// Eliminates unnecessary copies during sparse embedding lookup operations.
//
// For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
// generates code of the form:
//
//   embeddings = <a 2D Tensor>
//   sparse_ids = <a tf.int64 SparseTensor>
//   segment_ids = sparse_ids.indices[:, 0]
//   ids, idx = tf.unique(sparse_ids.values)
//   gathered_rows = tf.gather(params, ids)
//   result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
//
// In this case, all of the work in `tf.unique()` and `tf.gather()`
// can be avoided by passing the full embeddings to
// `tf.sparse.segment_<combiner>()` and performing the same amount of
// computation (but fewer copies and allocations) as follows:
//
//   embeddings = <a 2D Tensor>
//   sparse_ids = <a tf.int64 SparseTensor>
//   segment_ids = sparse_ids.indices[:, 0]
//   result = tf.sparse.segment_<combiner>(
//       embeddings, sparse_ids.values, segment_ids)
class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
 public:
  explicit SimplifyEmbeddingLookupStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
  }
  ~SimplifyEmbeddingLookupStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return OkStatus();

    // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
    // axis.
    NodeDef* gather_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
    if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
        gather_node->device() != reduction_node->device())
      return OkStatus();
    if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
      return OkStatus();

    // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
    // axis.
    NodeDef* unique_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
    if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
        unique_node->device() != gather_node->device())
      return OkStatus();
    if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
      return OkStatus();

    DataType unique_element_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));

    // Input 1 (indices) of the reduction node must be output 1 of the unique
    // node.
    const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
    if (idx_tensor != TensorId(unique_node->name(), 1)) return OkStatus();

    // Input 1 (indices) of the reduction node is replaced by input 0 (x) of
    // the unique node, i.e. the original, non-uniquified ids. Save the old
    // input name first so the node map is updated with the correct edge.
    const string old_indices_input = reduction_node->input(1);
    reduction_node->set_input(1, unique_node->input(0));
    ctx().node_map->UpdateInput(reduction_node->name(), old_indices_input,
                                unique_node->input(0));
    SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);

    // Input 0 (data) of the reduction node is replaced by input 0 (params) of
    // the gather node.
    const OpInfo::TensorProperties* gather_input_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(gather_node->input(0), &gather_input_properties));
    const string old_data_input = reduction_node->input(0);
    if (gather_input_properties->dtype() == DT_RESOURCE) {
      // If the gather reads from a resource variable (ResourceGather), we need
      // to add a ReadVariableOp to materialize the dense tensor.
      NodeDef* variable_node = nullptr;
      TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(0), &variable_node));
      NodeDef* read_var_node = ctx().optimized_graph->add_node();
      read_var_node->set_name(OptimizedNodeName(
          ParseNodeScopeAndName(reduction_node->name()), "ReadVar"));
      read_var_node->set_op("ReadVariableOp");
      read_var_node->add_input(gather_node->input(0));
      read_var_node->set_device(variable_node->device());

      // The variable and the gather node should have the same dtype, but it
      // might not be set on both nodes.
      auto attr = read_var_node->mutable_attr();
      if (variable_node->attr().count("dtype")) {
        SetAttrValue(variable_node->attr().at("dtype").type(),
                     &(*attr)["dtype"]);
      }
      if (gather_node->attr().count("dtype")) {
        SetAttrValue(gather_node->attr().at("dtype").type(), &(*attr)["dtype"]);
      }
      // Copy the _class attr from the gather node, if it exists, in case of
      // colocation constraints with the variable.
      if (gather_node->attr().count("_class")) {
        (*attr)["_class"] = gather_node->attr().at("_class");
      }
      if (variable_node->attr().count("shape")) {
        SetAttrValue(variable_node->attr().at("shape").shape(),
                     &(*attr)["_output_shapes"]);
      }

      ctx().node_map->AddNode(read_var_node->name(), read_var_node);
      reduction_node->set_input(0, read_var_node->name());
      ctx().node_map->UpdateInput(reduction_node->name(), old_data_input,
                                  read_var_node->name());
    } else {
      reduction_node->set_input(0, gather_node->input(0));
      ctx().node_map->UpdateInput(reduction_node->name(), old_data_input,
                                  gather_node->input(0));
    }
    *simplified_node_name = reduction_node->name();
    return OkStatus();
  }

 private:
  bool IsAxis0(const NodeDef& node, int axis_input) {
    Tensor axis_tensor;
    if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
      return false;
    if (axis_tensor.NumElements() != 1) return false;
    if (axis_tensor.dtype() == DT_INT32) {
      return axis_tensor.flat<int32>()(0) == 0;
    } else if (axis_tensor.dtype() == DT_INT64) {
      return axis_tensor.flat<int64_t>()(0) == 0;
    } else {
      return false;
    }
  }
};

// Eliminates unnecessary casts before sparse segment reduction operations.
//
// Existing graphs and library code would often insert a cast from DT_INT64 to
// DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
// Support for DT_INT64 indices and/or segment_ids now exists, so we can
// pass the input directly without a cast.
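//
// A minimal sketch of the rewrite (tensor names are illustrative):
//
//   out = tf.sparse.segment_sum(data,
//                               tf.cast(indices64, tf.int32),
//                               tf.cast(segment_ids64, tf.int32))
//
// becomes
//
//   out = tf.sparse.segment_sum(data, indices64, segment_ids64)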
class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveCastIntoSegmentReductionStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
                                 ctx_ext) {}
  ~RemoveCastIntoSegmentReductionStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return OkStatus();

    bool optimized = false;

    // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
    // DT_INT64.
    std::array<std::pair<int, string>, 2> input_details = {
        std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};

    for (const auto& input : input_details) {
      int input_index = input.first;
      const string& type_attr_name = input.second;
      NodeDef* cast_node = nullptr;
      TF_RETURN_IF_ERROR(
          GetInputNode(reduction_node->input(input_index), &cast_node));
      DataType original_index_type;
      if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
        // Bypass the cast: rewire the reduction node to consume the cast's
        // input directly and record the pre-cast index type.
        const string old_input = reduction_node->input(input_index);
        reduction_node->set_input(input_index, cast_node->input(0));
        ctx().node_map->UpdateInput(reduction_node->name(), old_input,
                                    cast_node->input(0));
        SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
        optimized = true;
      }
    }

    if (optimized) *simplified_node_name = reduction_node->name();
    return OkStatus();
  }

 private:
  bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
    if (!IsCast(node)) return false;
    if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
    return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
  }
};

}  // namespace

Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
  SetVector<NodeDef*> nodes_to_simplify;
  nodes_to_simplify.Reserve(optimized_graph_->node_size());
  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
    nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
  }

  const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                  graph_properties_.get(), node_map_.get(),
                                  &feed_nodes_, opt_level_);
  const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);

  // Stop the pipeline after the first stage that returns a non-empty
  // simplified tensor name.
  const auto stop = [](const string& result) { return !result.empty(); };
  GraphOptimizerStagePipeline<string> pipeline(stop);
  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;

  if (options_.combine_add_to_addn && can_use_shapes)
    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
  if (options_.fold_conjugate_into_transpose)
    pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
  if (options_.fold_multiply_into_conv)
    pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
  if (options_.fold_transpose_into_matmul)
    pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
  if (is_aggressive && options_.hoist_common_factor_out_of_aggregation &&
      can_use_shapes)
    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
  if (options_.minimize_broadcasts && can_use_shapes)
    pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
  if (options_.remove_identity_transpose && can_use_shapes)
    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
  if (options_.remove_involution)
    pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
  if (options_.remove_redundant_bitcast)
    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_cast)
    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
  if (options_.replace_pack_with_tile_reshape)
    pipeline.AddStage<ReplacePackWithTileReshape>(ctx, ctx_ext);
  if (options_.replace_mul_with_tile && can_use_shapes)
    pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext);
  if (options_.reduce_upsampling_dims)
    pipeline.AddStage<ReduceUpsamplingDims>(ctx, ctx_ext);
  if (options_.remove_redundant_reshape && can_use_shapes)
    pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
  if (options_.remove_negation)
    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
  if (options_.replace_mul_with_square)
    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
  if (options_.remove_logical_not)
    pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
  if (options_.reorder_cast_like_and_value_preserving)
    pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
  if (options_.simplify_aggregation)
    pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
  if (options_.hoist_cwise_unary_chains)
    pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
  if (options_.convert_sqrt_div_to_rsqrt_mul)
    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
  if (options_.remove_idempotent)
    pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
  if (options_.convert_log1p)
    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
  if (options_.convert_log_softmax)
    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
  if (options_.optimize_max_or_min_of_monotonic)
    pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
  if (options_.convert_expm1)
    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
  if (options_.unary_ops_composition)
    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
  if (options_.remove_stack_slice_same_axis)
    pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
  if (options_.simplify_embedding_lookup)
    pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
  if (options_.remove_cast_into_segment_reduction)
    pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
  if (options_.fuse_squared_diff)
    pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);

  VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
          << absl::StrJoin(pipeline.StageNames(), ", ");

  while (!nodes_to_simplify.Empty()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = nodes_to_simplify.PopBack();

    string simplified_tensor = "";
    bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);

    // If the node was not optimized by any of the stages, go to the next one.
    if (!optimized) continue;

    // Re-wire consumers of the old node to the new one.
    if (NodeName(simplified_tensor) != node->name()) {
      // Always consider simplified_tensor for further optimizations.
      NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
      if (simplified_node != nullptr) {
        nodes_to_simplify.PushBack(simplified_node);
      }
      // When `node` is simplified to another node rather than in-place, some
      // consumers of `node` may already have been redirected to
      // `simplified_tensor` by the stage itself. Rewire any remaining
      // references and re-push the consumers into `nodes_to_simplify` for
      // further optimizations.
      const std::vector<NodeDef*> consumers =
          node_map_->GetOutputsOrderedByNodeName(node->name());
      for (NodeDef* consumer : consumers) {
        // Rewrite each of `consumer`'s inputs that refers to `node` so that it
        // points at `simplified_tensor` instead.
        for (int i = 0; i < consumer->input_size(); ++i) {
          int operand_pos;
          string operand_node_name =
              ParseNodeName(consumer->input(i), &operand_pos);
          if (operand_node_name == node->name()) {
            *consumer->mutable_input(i) =
                (operand_pos < 0
                     ? AsControlDependency(NodeName(simplified_tensor))
                     : simplified_tensor);
          }
        }
        node_map_->UpdateInput(consumer->name(), node->name(),
                               simplified_tensor);
        nodes_to_simplify.PushBack(consumer);
      }
    }
  }
  return OkStatus();
}

Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                     const GrapplerItem& item,
                                     GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  fetch_nodes_known_ = !item.fetch.empty();
  GrapplerItem optimized_item(item);
  optimized_graph_ = &optimized_item.graph;

  node_map_.reset(new NodeMap(optimized_graph_));
  for (const auto& feed : item.feed) {
    feed_nodes_.insert(NodeName(feed.first));
  }

  // Disable restricted graph rewrites.
  options_.unary_ops_composition &=
      item.optimization_options().allow_non_differentiable_rewrites;

  // Perform a topological sort on the graph in order to help DedupComputations
  // and AddOpsRewrite optimize larger subgraphs starting from the roots with
  // more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  graph_properties_.reset(new GraphProperties(optimized_item));
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  const Status status =
      graph_properties_->InferStatically(assume_valid_feeds,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false);
  const bool can_use_shapes = status.ok();
  if (!can_use_shapes) {
    VLOG(1) << "Shape inference failed: " << status.error_message();
  }

  // Perform the optimizations.
  TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
  *optimized_graph = std::move(*optimized_graph_);
  return OkStatus();
}
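
// A minimal usage sketch, not part of this file: `item` is assumed to hold a
// valid GraphDef and its fetch nodes.
//
//   GrapplerItem item = ...;
//   ArithmeticOptimizer optimizer;  // default optimization level
//   GraphDef optimized;
//   TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &optimized));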

}  // namespace grappler
}  // namespace tensorflow