• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
17 
18 #include <algorithm>
19 #include <deque>
20 #include <limits>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/strings/str_join.h"
28 #include "tensorflow/core/framework/attr_value.pb.h"
29 #include "tensorflow/core/framework/attr_value_util.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/node_def_util.h"
32 #include "tensorflow/core/framework/op.h"
33 #include "tensorflow/core/framework/tensor.pb.h"
34 #include "tensorflow/core/framework/tensor_shape.pb.h"
35 #include "tensorflow/core/framework/types.h"
36 #include "tensorflow/core/framework/types.pb.h"
37 #include "tensorflow/core/grappler/costs/graph_properties.h"
38 #include "tensorflow/core/grappler/graph_topology_view.h"
39 #include "tensorflow/core/grappler/grappler_item.h"
40 #include "tensorflow/core/grappler/op_types.h"
41 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
42 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
43 #include "tensorflow/core/grappler/utils.h"
44 #include "tensorflow/core/grappler/utils/canonicalizer.h"
45 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
46 #include "tensorflow/core/grappler/utils/topological_sort.h"
47 #include "tensorflow/core/grappler/utils/traversal.h"
48 #include "tensorflow/core/lib/core/errors.h"
49 #include "tensorflow/core/lib/core/stringpiece.h"
50 #include "tensorflow/core/lib/hash/hash.h"
51 #include "tensorflow/core/lib/strings/str_util.h"
52 #include "tensorflow/core/lib/strings/strcat.h"
53 #include "tensorflow/core/platform/errors.h"
54 #include "tensorflow/core/platform/macros.h"
55 #include "tensorflow/core/platform/tensor_coding.h"
56 #include "tensorflow/core/protobuf/error_codes.pb.h"
57 #include "tensorflow/core/util/device_name_utils.h"
58 #include "tensorflow/core/util/saved_tensor_slice_util.h"
59 #include "tensorflow/core/util/strided_slice_op.h"
60 
61 using tensorflow::strings::StrCat;
62 
63 namespace tensorflow {
64 namespace grappler {
65 namespace {
66 
// Mark nodes created or optimized by a stage with a tag. Stages test these
// tags (see CanOptimize below) so that a node produced by a rewrite is never
// picked up and rewritten again.
constexpr char kAddOpsRewriteTag[] =
    "_grappler_ArithmeticOptimizer_AddOpsRewriteStage";
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler_ArithmeticOptimizer_MinimizeBroadcasts";
72 
// Extract values from a Const op to `values`. Returns true if succeeds.
// Two encodings of the value proto are handled: the typed repeated
// <type>_val field (accepted only when the 1-D shape proves the data is not
// compressed) and the packed `tensor_content` byte string.
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
  if (node.op() != "Const") {
    return false;
  }

  // The node must carry data of exactly the requested element type T.
  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
    return false;
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = node.attr().at("value").tensor();
  // NOTE: the const_cast'ed proto is only ever read through `tensor_values`
  // below; nothing in this function writes to it.
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, theoretically the representation of the data
    // could be compressed. So, before copying values to the returned vector,
    // make sure no compression happens.
    const TensorShapeProto& shape = tensor.tensor_shape();
    if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
      values->insert(values->end(), tensor_values->begin(),
                     tensor_values->end());
      return true;
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  if (tensor_content_size > 0) {
    // A byte size that is not a multiple of sizeof(T) means a malformed
    // proto; this is treated as a hard (process-fatal) check failure.
    CHECK_EQ(0, tensor_content_size % sizeof(T))
        << "tensor_content_size (" << tensor_content_size
        << ") is not a multiple of " << sizeof(T);
    values->resize(tensor_content_size / sizeof(T));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
    return true;
  }

  return false;
}
116 
MaybeAddControlInput(const string & new_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)117 bool MaybeAddControlInput(const string& new_input, NodeDef* node,
118                           GraphDef* graph, NodeMap* node_map) {
119   bool already_exists = false;
120   for (const string& input : node->input()) {
121     if (input == new_input || AsControlDependency(input) == new_input) {
122       already_exists = true;
123       break;
124     }
125   }
126   if (!already_exists) {
127     const string ctrl_dep =
128         ConstantFolding::AddControlDependency(new_input, graph, node_map);
129     node->add_input(ctrl_dep);
130     node_map->AddOutput(NodeName(new_input), node->name());
131   }
132   return !already_exists;
133 }
134 
SetDataTypeToAttr(DataType dtype,const string & attr_name,NodeDef * node)135 void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
136   (*node->mutable_attr())[attr_name].set_type(dtype);
137 }
138 
// Follows a chain of value-preserving ops downward from `node` and returns
// the last node of the chain. The walk only continues through nodes that are
// not preserved, are value-preserving, and have exactly one non-control
// output (i.e. the chain does not branch).
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto follow = [&](const NodeDef& n) {
    if (nodes_to_preserve.count(n.name()) > 0) return false;
    if (!IsValuePreserving(n)) return false;
    return NumNonControlOutputs(n, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false, follow);
}
149 
// Same as GetTailOfValuePreservingChain, but follows idempotent ops instead
// of value-preserving ones: walks through non-preserved, idempotent,
// non-branching nodes and returns the tail of the chain.
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto follow = [&](const NodeDef& n) {
    if (nodes_to_preserve.count(n.name()) > 0) return false;
    if (!IsIdempotent(n)) return false;
    return NumNonControlOutputs(n, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false, follow);
}
160 
161 // GetElementUnexhaustive tries to get the value of an element in a tensor and
162 // turn it into complex128 type. It only check for a limited number of data
163 // types, so it's unexhaustive.
GetElementUnexhaustive(const Tensor & t,int i,const std::set<int> & dtypes,complex128 * element)164 bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
165                             complex128* element) {
166   if (dtypes.find(t.dtype()) == dtypes.end()) return false;
167   switch (t.dtype()) {
168     case DT_BFLOAT16:
169       *element = complex128(t.flat<bfloat16>()(i));
170       return true;
171     case DT_HALF:
172       *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
173       return true;
174     case DT_INT32:
175       *element = complex128(t.flat<int32>()(i));
176       return true;
177     case DT_INT64:
178       *element = complex128(t.flat<int64_t>()(i));
179       return true;
180     case DT_FLOAT:
181       *element = complex128(t.flat<float>()(i));
182       return true;
183     case DT_DOUBLE:
184       *element = complex128(t.flat<double>()(i));
185       return true;
186     case DT_COMPLEX64:
187       *element = complex128(t.flat<complex64>()(i));
188       return true;
189     case DT_COMPLEX128:
190       *element = t.flat<complex128>()(i);
191       return true;
192     default:
193       return false;
194   }
195 }
196 
NodeIsOnCpu(const NodeDef & node)197 bool NodeIsOnCpu(const NodeDef& node) {
198   string task;
199   string device;
200   return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
201          absl::StrContains(device, DEVICE_CPU);
202 }
203 
204 // True if all regular (non-control) inputs reference the same node or if there
205 // are no non-control inputs
AllRegularInputsEqual(const NodeDef & node)206 bool AllRegularInputsEqual(const NodeDef& node) {
207   if (!HasRegularInputs(node)) return true;
208   for (int i = 1; i < node.input_size(); ++i) {
209     if (IsControlInput(node.input(i))) {
210       break;
211     }
212     if (node.input(0) != node.input(i)) {
213       return false;
214     }
215   }
216   return true;
217 }
218 
// Replace a node with NoOp and reset shape inference results for it.
void ReplaceWithNoOp(NodeDef* node, const GraphOptimizerContext& ctx) {
  // Drop fan-in bookkeeping and any cached shape-inference results for the
  // node before its op type changes.
  ctx.node_map->RemoveInputs(node->name());
  ctx.graph_properties->ClearInputProperties(node->name());
  ctx.graph_properties->ClearOutputProperties(node->name());
  // A NoOp carries no op-specific attributes and no data inputs.
  EraseRegularNodeAttributes(node);
  node->set_op("NoOp");
  node->clear_input();
}
228 
// Graph optimizer context extension specific to ArithmeticOptimizer.
// Wraps the shared work queue that stages push follow-up candidate nodes
// onto (see ArithmeticOptimizerStage::AddToOptimizationQueue).
struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  // Not owned; this struct never allocates or frees the queue.
  SetVector<NodeDef*>* nodes_to_simplify;
};
235 
// Base class for single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc...
class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 public:
  // `ctx_ext` is taken by value; it only carries a pointer to the shared
  // queue of nodes to simplify.
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
  ~ArithmeticOptimizerStage() override = default;

 protected:
  // Simplification graph rewrite can create additional nodes that are inputs
  // to final simplified node, they can be also added to the arithmetic
  // optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
    ctx_ext_.nodes_to_simplify->PushBack(node);
  }

  // Update consumers of node to take new_input as input instead. Control
  // edges are rewritten to the control form ("^node") of new_input's node;
  // every consumer that changed is deduped and re-queued for optimization.
  // Returns InvalidArgument if a data input would be replaced by a control
  // input (consumers updated before the failing one keep their new inputs).
  Status UpdateConsumers(NodeDef* node, const string& new_input) {
    const auto consumers = ctx().node_map->GetOutputs(node->name());
    if (consumers.empty()) return OkStatus();
    const TensorId new_tensor = ParseTensorName(new_input);
    for (NodeDef* consumer : consumers) {
      // Skip the producer of new_input itself to avoid creating a self-loop.
      if (consumer->name() == new_tensor.node()) continue;
      bool updated = false;
      for (int i = 0; i < consumer->input_size(); ++i) {
        const TensorId input_tensor = ParseTensorName(consumer->input(i));
        if (input_tensor.node() == node->name()) {
          if (new_tensor.index() < 0 && input_tensor.index() >= 0) {
            // Overwriting a data input with a control input will make the graph
            // invalid.
            return errors::InvalidArgument(
                "Cannot override data input ", input_tensor.ToString(),
                " with control input ", new_tensor.ToString());
          }
          // Preserve the edge kind: control inputs (index < 0) stay control.
          consumer->set_input(i, input_tensor.index() < 0
                                     ? absl::StrCat("^", new_tensor.node())
                                     : new_input);
          ctx().node_map->UpdateInput(consumer->name(), node->name(),
                                      new_input);
          updated = true;
        }
      }
      if (updated) {
        DedupControlInputs(consumer);
        AddToOptimizationQueue(consumer);
      }
    }
    return OkStatus();
  }

  // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
  // optimizations will be migrated to stages
  // Copies the trailing control inputs of each source node onto target_node,
  // then dedupes. Control inputs are stored after all regular inputs, so the
  // reverse scan stops at the first non-control input.
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
    for (const auto& src : src_nodes) {
      for (int i = src->input_size() - 1; i >= 0; --i) {
        if (IsControlInput(src->input(i))) {
          *target_node->add_input() = src->input(i);
          ctx().node_map->AddOutput(NodeName(src->input(i)),
                                    target_node->name());
        } else {
          break;
        }
      }
    }
    DedupControlInputs(target_node);
  }

  // True for Const nodes that are not overridden by a feed at runtime.
  bool IsReallyConstant(const NodeDef& node) const {
    if (!IsConstant(node)) {
      return false;
    }
    // If the node is fed it's not constant anymore.
    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
  }

  // True if the node must survive optimization (e.g. it is fetched).
  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  // True if any input of `node` is a control edge.
  bool IsDrivenByControlDependency(const NodeDef& node) const {
    return std::any_of(
        node.input().begin(), node.input().end(),
        [](const string& input) { return IsControlInput(input); });
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  // True if any consumer references `node` through a control edge
  // (a parsed tensor index < 0 denotes a control input).
  bool DrivesControlDependency(const NodeDef& node) const {
    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
      for (int i = 0; i < output->input_size(); ++i) {
        const TensorId tensor = ParseTensorName(output->input(i));
        if (tensor.node() == node.name() && tensor.index() < 0) {
          return true;
        }
      }
    }
    return false;
  }

  // Loads the value of a truly-constant node into `tensor`. Returns false if
  // the node is missing, not really constant, has no "value" attr, or the
  // proto cannot be parsed into a Tensor.
  bool GetTensorFromConstNode(const string& node_name_or_input,
                              Tensor* tensor) {
    const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
    return node != nullptr && IsReallyConstant(*node) &&
           CheckAttrExists(*node, "value").ok() &&
           tensor->FromProto(node->attr().at("value").tensor());
  }

 private:
  // Extended context required for ArithmeticOptimizer.
  const ArithmeticOptimizerContext ctx_ext_;
};
352 
// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
// group of nodes from the optimized graph.
//
// * AddOpsRewrite:
//   Rewrite a group of Add/AddN with compact Add/AddN tree
//
// * MinimizeBroadcasts:
//   Rewrite a group of binary associative ops, reordering
//   inputs, to minimize the cost of broadcast
class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
  ~ArithmeticNodesGroupOptimizerStage() override = default;

  // Input name with a statically inferred shape from GraphProperties
  struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
    string input;
    TensorShapeProto shape;
  };

  // Subgraph (subtree) of nodes, that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). Subgraph is
  // obtained by graph traversal, starting from a root node.
  struct OptimizedNodesGroup {
    NodeDef* root_node;
    TensorShapeProto root_shape;
    // Optimized nodes that will be updated or removed by rewrite
    std::vector<NodeDef*> optimized_nodes;
    // Inputs to optimized nodes
    std::vector<InputAndShape> inputs;
  };

  // Builds a group rooted at `node`; if any nodes were absorbed, delegates
  // the actual graph rewrite to the subclass via RewriteOptimizedNodesGroup.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    OptimizedNodesGroup group;
    TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));

    if (!group.optimized_nodes.empty()) {
      *simplified_node_name = RewriteOptimizedNodesGroup(group);
    }

    return OkStatus();
  }

 protected:
  // Modify the optimized graph after nodes group was successfully identified
  virtual string RewriteOptimizedNodesGroup(
      const OptimizedNodesGroup& group) = 0;

  // Check if input can become a part of current optimized nodes group.
  virtual bool IsAbsorbableByOptimizedNodesGroup(
      const OptimizedNodesGroup& group, const NodeDef& node) const = 0;

  // Depth-first traversal starting from `input`: absorbable nodes join
  // group->optimized_nodes and their own inputs are explored in turn;
  // non-absorbable inputs become leaves recorded in group->inputs.
  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
    // The deque is used as a stack (push/pop at the front).
    std::deque<const string*> input_tensors;
    input_tensors.push_front(&input);

    while (!input_tensors.empty()) {
      const string* input_tensor = input_tensors.front();
      input_tensors.pop_front();

      // Get a node for the input tensor.
      NodeDef* input_node;
      TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));

      if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
        group->optimized_nodes.push_back(input_node);
        // Pushing in reverse keeps the original left-to-right input order
        // at the front of the stack.
        for (int i = input_node->input_size() - 1; i >= 0; --i) {
          const string& absorbed_node_input = input_node->input(i);
          // TODO(ezhulenev): support control inputs
          if (IsControlInput(absorbed_node_input)) continue;
          input_tensors.push_front(&absorbed_node_input);
        }
      } else {
        // If input node can't be absorbed, add it to OptimizedNodesGroup input.
        const OpInfo::TensorProperties* properties;
        TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
        group->inputs.emplace_back(*input_tensor, properties->shape());
      }
    }

    return OkStatus();
  }

  // Initializes `group` with `root_node` and absorbs as much of its input
  // subgraph as IsAbsorbableByOptimizedNodesGroup allows.
  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
    const OpInfo::TensorProperties* root_node_output_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(root_node->name(), &root_node_output_properties));

    group->root_node = root_node;
    group->root_shape = root_node_output_properties->shape();

    group->optimized_nodes.reserve(root_node->input_size());
    for (int i = 0; i < root_node->input_size(); ++i) {
      const string& input_i = root_node->input(i);
      // TODO(ezhulenev): add support for control inputs
      if (IsControlInput(input_i)) continue;
      TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
    }

    return OkStatus();
  }

  // Check if all inputs can be broadcasted to the same shape
  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
    auto is_broadcastable = [this, &properties](const string& input) {
      const OpInfo::TensorProperties* input_props;
      Status has_input_properties = GetTensorProperties(input, &input_props);
      return has_input_properties.ok() &&
             ShapesBroadcastable(properties, *input_props);
    };
    return std::all_of(node.input().begin(), node.input().end(),
                       is_broadcastable);
  }

  // Compact string key for a shape, e.g. "rank:2:dim:3:4"; used to bucket
  // inputs whose shapes are symbolically equal.
  string ShapeSignature(const TensorShapeProto& shape) const {
    string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
    for (int i = 0; i < shape.dim_size(); ++i)
      strings::StrAppend(&signature, ":", shape.dim(i).size());
    return signature;
  }

  // Attaches boolean attribute `tag` to `node` so a stage can recognize
  // nodes it already processed.
  void MarkWithTag(const StringPiece tag, NodeDef* node) {
    AddNodeAttr(tag, true, node);
  }

  // Tags the root node and every absorbed node of the group.
  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
    AddNodeAttr(tag, true, group.root_node);
    for (NodeDef* optimized_node : group.optimized_nodes) {
      AddNodeAttr(tag, true, optimized_node);
    }
  }

  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
    return group.root_node->device() == node.device();
  }

  // NOTE(review): shadows the identical helper in ArithmeticOptimizerStage;
  // looks redundant — candidate for removal.
  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
    return HasNodeAttr(node, tag);
  }

  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
    return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
  }
};
516 
517 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
518 // original inputs of absorbed nodes.
519 //
520 // 1) All nodes must have the same device placement.
521 //
522 // 2) If All nodes in a Add/AddN subgraph have symbolically equal shape, tree is
523 //    optimized to a single AddN node.
524 //
525 //                AddN_1
526 //             /    |    \
527 //          Add_1   z   Add_2       -> AddN(x, y, z, w, q, e)
528 //          /  \        /  \
529 //         x    y      w    Add_3
530 //                          / \
531 //                         q   e
532 //
533 // 3) If some nodes have different shape (it needs to be broadcastable to the
534 //    shape of a "root), tree is optimized to AddNs for symbolically equal
535 //    shapes, and a tree of Add ops, that minimize broadcasts.
536 //
537 //                AddN_1                                 Add
538 //             /    |    \                              /  \
539 //          Add_1   z   Add_2       ->               Add    w
540 //          /  \        /  \                        /   \
541 //         x    y      w    Add_3      AddN(x, y, q, e)  z
542 //                          / \
543 //                         q   e
544 class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
545  public:
  // Registers this stage under the name "AddOpsRewrite".
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
  ~AddOpsRewriteStage() override = default;
550 
551   // Check if a node can become a root of AddOpsGroup
IsSupported(const NodeDef * node) const552   bool IsSupported(const NodeDef* node) const override {
553     if (!CanOptimize(*node)) return false;
554 
555     // shape must be symbolically defined and all inputs compatible with it
556     const OpInfo::TensorProperties* properties;
557     Status has_properties = GetTensorProperties(node->name(), &properties);
558     return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
559            HasAllInputsBroadcastableToShape(*node, *properties);
560   }
561 
562  protected:
563   // Check if a node can be absorbed by current OptimizedNodesGroup
IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup & group,const NodeDef & node) const564   bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
565                                          const NodeDef& node) const override {
566     if (!CanOptimize(node)) return false;
567 
568     if (!IsOnTheSameDevice(group, node)) {
569       return false;
570     }
571     // with a single output data consumer (presumably if we reach this node from
572     // previously absorbed or a root node, it means that this node is not used
573     // as an input to any other op, outside of the group)
574     if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
575       return false;
576     }
577     // All input shapes must be broadcastable to the node shape
578     const OpInfo::TensorProperties* properties;
579     Status has_properties = GetTensorProperties(node.name(), &properties);
580     return has_properties.ok() &&
581            HasAllInputsBroadcastableToShape(node, *properties);
582   }
583 
584   // Node requirements both for a root node and an absorbed node
CanOptimize(const NodeDef & node) const585   bool CanOptimize(const NodeDef& node) const {
586     // TODO(ezhulenev): check if AccumulateNV2 can be supported too
587     if (!IsAdd(node) && !IsAddN(node)) {
588       return false;
589     }
590     if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
591       return false;
592     }
593     // TODO(ezhulenev): relax this condition for root node
594     return !(IsDrivenByControlDependency(node) ||
595              DrivesControlDependency(node));
596   }
597 
  // Rewrite a group of add ops into a single AddN if all input shapes are
  // symbolically equal. If not, create AddN for equal shapes first, and then
  // build an Add tree, minimizing the cost of broadcasts.
  // Returns the name of the node that replaces the group's root.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size()
            << " num_inputs=" << group.inputs.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kAddOpsRewriteTag);

    // All new nodes will be placed under the scope of a root node.
    auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());

    // Find what shapes are present in the inputs of absorbed nodes.
    std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
    for (const auto& input : group.inputs) {
      shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
    }

    using SigKV = decltype(shape_sig_to_inputs)::value_type;
    VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
            << " unique shapes: "
            << absl::StrJoin(shape_sig_to_inputs, ", ",
                             [](string* out, SigKV p) {
                               strings::StrAppend(out, p.first);
                             });

    // Collect all the shapes from representative elements.
    std::vector<TensorShapeProto> shapes;
    shapes.reserve(shape_sig_to_inputs.size());
    for (const auto& el : shape_sig_to_inputs)
      shapes.push_back(el.second[0].shape);

    // If all inputs have the same shape, rewrite whole group with a single AddN
    if (shapes.size() == 1) {
      string node_name = UniqueOptimizedNodeName(root_scope_and_name);
      AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
                                        group.inputs);
      return node_name;
    }

    // For inputs of different shapes:
    // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
    // 2. Build a tree of Add nodes, minimizing cost of broadcast
    std::sort(shapes.begin(), shapes.end(),
              [](const TensorShapeProto& left, const TensorShapeProto& right) {
                return CompareSymbolicallyShapedTensorSizes(left, right);
              });

    // optimized name for leaf AddN nodes
    auto leaf_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Leaf_", i));
    };
    // optimized name for internal nodes of a tree built up from AddN leaves
    auto internal_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Internal_", i));
    };

    // Add/AddN nodes that must be added to the tree
    std::deque<InputAndShape> add_ops;

    // Prepare leaf AddN nodes for inputs of equal shape
    for (int i = 0, end = shapes.size(); i < end; ++i) {
      const auto node_name = leaf_node_name(i);
      const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
      add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
                                                          node_name, inputs));
    }

    // Build up a tree of Add ops: repeatedly fold the two front elements into
    // a new Add node; the final Add reuses the root's optimized node name so
    // consumers can be redirected to it.
    int internal_nodes = 0;
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();
      string name = add_ops.empty()
                        ? UniqueOptimizedNodeName(root_scope_and_name)
                        : internal_node_name(internal_nodes++);
      InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
      add_ops.push_front(add);
    } while (add_ops.size() > 1);

    InputAndShape optimized_root_node = add_ops.front();
    return optimized_root_node.input;
  }
688 
689   // Add 'AddN' node to aggregate inputs of symbolically equal shape
AddInputsOfSymbolicallyEqualShape(const NodeDef & root_node,const string & node_name,const std::vector<InputAndShape> & inputs)690   InputAndShape AddInputsOfSymbolicallyEqualShape(
691       const NodeDef& root_node, const string& node_name,
692       const std::vector<InputAndShape>& inputs) {
693     CHECK(!inputs.empty()) << "Inputs must be non-empty";
694 
695     // Do not create redundant AddN nodes
696     if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
697       return inputs[0];
698     }
699 
700     // get shape from representative element
701     auto shape = inputs[0].shape;
702 
703     // copy attributes from a root node
704     DataType dtype = root_node.attr().at("T").type();
705 
706     // add new AddN node
707     NodeDef* node = AddEmptyNode(node_name);
708     node->set_op("AddN");
709     node->set_device(root_node.device());
710     (*node->mutable_attr())["T"].set_type(dtype);
711     (*node->mutable_attr())["N"].set_i(inputs.size());
712 
713     for (const auto& inputAndShape : inputs) {
714       ctx().node_map->AddOutput(inputAndShape.input, node_name);
715       node->add_input(inputAndShape.input);
716     }
717 
718     MarkWithTag(kAddOpsRewriteTag, node);
719     return InputAndShape(node_name, shape);
720   }
721 
722   // Add a single 'Add' node to sum two inputs
AddAggregatedInputs(const NodeDef & root_node,const string & node_name,const InputAndShape & left,const InputAndShape & right)723   InputAndShape AddAggregatedInputs(const NodeDef& root_node,
724                                     const string& node_name,
725                                     const InputAndShape& left,
726                                     const InputAndShape& right) {
727     // copy attributes from a root node
728     DataType dtype = root_node.attr().at("T").type();
729 
730     // add new Add node
731     NodeDef* node = AddEmptyNode(node_name);
732     node->set_op((dtype == DT_STRING || dtype == DT_STRING_REF) ? "Add"
733                                                                 : "AddV2");
734     node->set_device(root_node.device());
735     (*node->mutable_attr())["T"].set_type(dtype);
736     node->add_input(left.input);
737     node->add_input(right.input);
738 
739     ctx().node_map->AddOutput(left.input, node_name);
740     ctx().node_map->AddOutput(right.input, node_name);
741 
742     MarkWithTag(kAddOpsRewriteTag, node);
743     return InputAndShape(
744         node_name, TensorShapeProto());  // shape is not important at this point
745   }
746 };
747 
748 // Use the distributive property of multiplication and division over addition,
749 // along with commutativity of the former, to hoist common factors/denominators
750 // out of aggregate nodes where ALL the inputs are Mul/Div nodes.
751 // This pattern occurs frequently in regularization terms for the gradients
752 // during training.
753 //
754 // For example, we can rewrite an expression of the form:
755 //   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
756 // to the following:
757 //   Mul(x, AddN(y1, y2, y3, ... yn))
758 // For division, we can rewrite
759 //   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
760 // to:
761 //   Div(AddN(y1, y2, y3, ... yn), x)
762 class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
763  public:
HoistCommonFactorOutOfAggregation(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)764   explicit HoistCommonFactorOutOfAggregation(
765       const GraphOptimizerContext& ctx,
766       const ArithmeticOptimizerContext& ctx_ext)
767       : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
768   ~HoistCommonFactorOutOfAggregation() override = default;
769 
IsSupported(const NodeDef * node) const770   bool IsSupported(const NodeDef* node) const override {
771     return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
772            !IsRewritten(node);
773   }
774 
TrySimplify(NodeDef * node,string * simplified_node_name)775   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
776     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
777 
778     bool common_factor_is_denominator = false;
779     std::set<string> common_factors;
780     std::vector<string> ctrl_deps;
781     TF_RETURN_IF_ERROR(GetCommonFactors(
782         node, &common_factors, &common_factor_is_denominator, &ctrl_deps));
783 
784     if (common_factors.size() == 1) {
785       const string& common_factor = *common_factors.begin();
786 
787       // Gather up the non-shared factors
788       bool shapes_match = true;
789       std::vector<string> unique_factors;
790       TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
791                                           common_factor_is_denominator,
792                                           &shapes_match, &unique_factors));
793 
794       if (shapes_match) {
795         NodeDef* input_0;
796         TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
797 
798         // Use a copy of the first node for the outer multiplication/division.
799         NodeDef* new_outer_node = AddCopyNode(
800             OuterNodeName(node, common_factor_is_denominator), input_0);
801         // And a copy of aggregation node as one of the inner operands
802         NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
803 
804         new_outer_node->set_device(node->device());
805         if (common_factor_is_denominator) {
806           new_outer_node->set_input(0, new_add_node->name());
807           new_outer_node->set_input(1, common_factor);
808         } else {
809           new_outer_node->set_input(0, common_factor);
810           new_outer_node->set_input(1, new_add_node->name());
811         }
812 
813         ctx().node_map->AddOutput(common_factor, new_outer_node->name());
814         ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());
815 
816         // Hoist non-shared factors up into the new AddN node.
817         for (int i = 0, end = unique_factors.size(); i < end; ++i) {
818           const string& unique_factor_i = unique_factors[i];
819           new_add_node->set_input(i, unique_factor_i);
820           ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
821         }
822 
823         // Add control deps on add node
824         for (const string& ctrl_dep : ctrl_deps) {
825           *new_add_node->add_input() = ctrl_dep;
826           ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
827         }
828 
829         // optimize new inner aggregation node
830         AddToOptimizationQueue(new_add_node);
831         // do not optimize the same node twice
832         rewritten_nodes_.insert(node->name());
833         *simplified_node_name = new_outer_node->name();
834       }
835     }
836     return OkStatus();
837   }
838 
839  private:
840   // Get a name for new outer node
OuterNodeName(const NodeDef * node,bool is_div) const841   string OuterNodeName(const NodeDef* node, bool is_div) const {
842     auto scope_and_name = ParseNodeScopeAndName(node->name());
843     return is_div ? OptimizedNodeName(scope_and_name, "Div")
844                   : OptimizedNodeName(scope_and_name, "Mul");
845   }
846 
847   // Get a name new inner Add node
InnerAddNodeName(const NodeDef * node) const848   string InnerAddNodeName(const NodeDef* node) const {
849     auto scope_and_name = ParseNodeScopeAndName(node->name());
850     return OptimizedNodeName(scope_and_name, "AddV2");
851   }
852 
853   // Determine the set of common factors if the input nodes are all Mul or
854   // Div nodes.
GetCommonFactors(const NodeDef * node,std::set<string> * common_factors,bool * common_factor_is_denominator,std::vector<string> * ctrl_deps) const855   Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
856                           bool* common_factor_is_denominator,
857                           std::vector<string>* ctrl_deps) const {
858     CHECK(common_factors->empty());
859     CHECK_NOTNULL(common_factor_is_denominator);
860     *common_factor_is_denominator = false;
861 
862     bool has_mul = false;
863     bool has_div = false;
864     for (int i = 0; i < node->input_size(); ++i) {
865       if (i > 0 && common_factors->empty()) break;
866       if (IsControlInput(node->input(i))) {
867         ctrl_deps->push_back(node->input(i));
868         continue;
869       }
870       NodeDef* input;
871       TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
872 
873       if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
874           (IsAnyDiv(*input) && has_mul)) {
875         // Break if input is neither a Mul or Div, or if there are both Mul &
876         // Div Ops.
877         common_factors->clear();
878         break;
879       } else if (IsAnyDiv(*input)) {
880         has_div = true;
881         // In case of possible common dividers, we avoid hoisting out if any
882         // input is not float/double, since integer division is not distributive
883         // over addition.
884         const OpInfo::TensorProperties* properties0;
885         const OpInfo::TensorProperties* properties1;
886         TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
887         TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
888         if (properties0->dtype() != DT_FLOAT &&
889             properties0->dtype() != DT_DOUBLE &&
890             properties1->dtype() != DT_FLOAT &&
891             properties1->dtype() != DT_DOUBLE) {
892           common_factors->clear();
893           break;
894         }
895       } else if (IsMul(*input)) {
896         has_mul = true;
897       }
898 
899       // We only focus on common factors from denominators if any Op is a
900       // Div.
901       std::set<string> factors_i =
902           has_mul ? std::set<string>{input->input(0), input->input(1)}
903                   : std::set<string>{input->input(1)};
904       if (i == 0) {
905         std::swap(*common_factors, factors_i);
906       } else {
907         std::set<string> intersection;
908         std::set_intersection(
909             factors_i.begin(), factors_i.end(), common_factors->begin(),
910             common_factors->end(),
911             std::inserter(intersection, intersection.begin()));
912         std::swap(*common_factors, intersection);
913       }
914       for (int i = 2; i < input->input_size(); ++i) {
915         ctrl_deps->push_back(input->input(i));
916       }
917     }
918 
919     *common_factor_is_denominator = has_div;
920     return OkStatus();
921   }
922 
923   // Gather up the non-shared factors (the y's in the example).
924   // Unless the aggregation is Add, we have to make sure that all the y's
925   // have the same shape since the other aggregation ops do not support
926   // broadcasting.
GetUniqueFactors(const NodeDef * node,const string & common_factor,const bool common_factor_is_denominator,bool * shapes_match,std::vector<string> * unique_factors) const927   Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
928                           const bool common_factor_is_denominator,
929                           bool* shapes_match,
930                           std::vector<string>* unique_factors) const {
931     *shapes_match = true;
932     unique_factors->reserve(node->input_size());
933 
934     for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
935       const string& input = node->input(i);
936       if (IsControlInput(input)) {
937         break;
938       }
939       NodeDef* inner_node;
940       TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
941       const int unique_factor_index =
942           common_factor_is_denominator
943               ? 0
944               : (inner_node->input(0) == common_factor ? 1 : 0);
945       unique_factors->push_back(inner_node->input(unique_factor_index));
946       if (i > 0 && !IsAdd(*node)) {
947         const OpInfo::TensorProperties* lhs;
948         const OpInfo::TensorProperties* rhs;
949         TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
950         TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
951         *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs);
952       }
953     }
954     return OkStatus();
955   }
956 
IsRewritten(const NodeDef * node) const957   bool IsRewritten(const NodeDef* node) const {
958     // if graph rewrite happens in multiple passes without graph pruning between
959     // them, it's possible that rewritten node already exists in a graph
960     return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
961            ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
962            ctx().node_map->NodeExists(OuterNodeName(node, true)) ||
963            ctx().node_map->NodeExists(InnerAddNodeName(node));
964   }
965 
966   // keep names of the nodes that were optimized by this stage
967   std::unordered_set<string> rewritten_nodes_;
968 };
969 
970 // Binary associative ops can be re-ordered to minimize the number of broadcasts
971 // and the size of a temporary tensors.
972 //
973 // Example: [a, c] - scalars, [b, d] - matrices
974 //   @ - binary associative op (Add or Mul)
975 //   @* - broadcast
976 //
977 //           @                      @*
978 //        /     \                /      \
979 //      @*       @*      ->     @        @
980 //    /   \    /   \          /   \    /   \
981 //   a     b  c     d        a     c  b     d
982 class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
983  public:
MinimizeBroadcasts(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)984   explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
985                               const ArithmeticOptimizerContext& ctx_ext)
986       : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
987   }
988   ~MinimizeBroadcasts() override = default;
989 
IsSupported(const NodeDef * node) const990   bool IsSupported(const NodeDef* node) const override {
991     if (!IsBinaryAssociative(*node)) return false;
992 
993     if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
994       return false;
995 
996     // has a symbolically defined shape with broadcastable inputs
997     const OpInfo::TensorProperties* properties;
998     Status has_properties = GetTensorProperties(node->name(), &properties);
999     return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
1000            HasAllInputsBroadcastableToShape(*node, *properties);
1001   }
1002 
1003  protected:
IsBinaryAssociative(const NodeDef & node) const1004   bool IsBinaryAssociative(const NodeDef& node) const {
1005     return IsMul(node) || IsAdd(node);
1006   }
1007 
IsSameOp(const OptimizedNodesGroup & group,const NodeDef & node) const1008   bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
1009     return group.root_node->op() == node.op();
1010   }
1011 
1012   // Check if a node can be absorbed by current OptimizedNodesGroup
IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup & group,const NodeDef & node) const1013   bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
1014                                          const NodeDef& node) const override {
1015     if (!IsSameOp(group, node)) {
1016       return false;
1017     }
1018     if (IsInPreserveSet(node)) {
1019       return false;
1020     }
1021     // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
1022     if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
1023       return false;
1024     }
1025     if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
1026       return false;
1027     }
1028     if (!IsOnTheSameDevice(group, node)) {
1029       return false;
1030     }
1031     // Optimized nodes updated in place, and that would break the graph, if the
1032     // node has multiple output consumers
1033     if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
1034       return false;
1035     }
1036     // All input shapes must be broadcastable to the node shape
1037     const OpInfo::TensorProperties* properties;
1038     Status has_properties = GetTensorProperties(node.name(), &properties);
1039     return has_properties.ok() &&
1040            HasAllInputsBroadcastableToShape(node, *properties);
1041   }
1042 
CountUniqueShapes(const std::vector<InputAndShape> & inputs)1043   std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
1044     std::set<string> sigs;
1045     for (const auto& ias : inputs) {
1046       sigs.insert(ShapeSignature(ias.shape));
1047     }
1048     return sigs.size();
1049   }
1050 
RewriteOptimizedNodesGroup(const OptimizedNodesGroup & group)1051   string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
1052     VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
1053             << " op=" << group.root_node->op()
1054             << " num_optimized_nodes=" << group.optimized_nodes.size();
1055 
1056     // Do not optimize any of the nodes that are part of this group.
1057     MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);
1058 
1059     if (CountUniqueShapes(group.inputs) <= 1) {
1060       VLOG(3) << "Skip min-bcast group with single unique shape";
1061       // nothing to optimize when all shapes are the same
1062       return group.root_node->name();
1063     }
1064 
1065     auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
1066     auto num_inputs = group.inputs.size();
1067     CHECK_EQ(num_nodes, num_inputs - 1)
1068         << "Can't build a tree with " << num_inputs << " inputs, using "
1069         << num_nodes << "binary op nodes.";
1070 
1071     std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
1072     std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
1073                                          group.optimized_nodes.end());
1074 
1075     // sort inputs by it's shape from smallest to largest
1076     std::stable_sort(add_ops.begin(), add_ops.end(),
1077                      [](const InputAndShape& lhs, const InputAndShape& rhs) {
1078                        return CompareSymbolicallyShapedTensorSizes(lhs.shape,
1079                                                                    rhs.shape);
1080                      });
1081 
1082     // If there is an odd number of inputs, last one is the largest, and we want
1083     // to attach it to the root node, to build a well balanced tree.
1084     std::deque<InputAndShape> add_ops_leftover;
1085     if (add_ops.size() % 2 != 0) {
1086       add_ops_leftover.push_back(add_ops.back());
1087       add_ops.pop_back();
1088     }
1089 
1090     // At this point it's guaranteed that add_ops have even number of inputs.
1091     do {
1092       const InputAndShape lhs = add_ops.front();
1093       add_ops.pop_front();
1094       const InputAndShape rhs = add_ops.front();
1095       add_ops.pop_front();
1096 
1097       NodeDef* node;
1098       if (!optimized_nodes.empty()) {
1099         // re-purpose optimized nodes to build a new tree
1100         node = optimized_nodes.back();
1101         optimized_nodes.pop_back();
1102       } else {
1103         // or use root node if none optimized nodes left
1104         node = group.root_node;
1105       }
1106       InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);
1107 
1108       // Pushing updated node to the back of a deque will create a wide and
1109       // short tree, pushing to the front will create a tall tree. We prefer to
1110       // get a wide tree, it minimizes the potential number of temporary tensors
1111       // required to keep in memory, though sometimes we can go up to prevent
1112       // propagating a broadcast from leaves to the root. Example:
1113       //
1114       // inputs: [s, s, s, M] (s - scalar, M - matrix)
1115       // @* - op with broadcast
1116       //
1117       //  (only push_back)           @*     (push_front first op)
1118       //                            /  \
1119       //       @*                  @    M
1120       //     /   \                / \
1121       //    @     @*      ->     @   s
1122       //   / \   / \            / \
1123       //  s   s s   M          s   s
1124       if (add_ops.size() >= 2 &&
1125           CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
1126                                                add_ops.at(1).shape)) {
1127         add_ops.push_front(updated_node);
1128       } else {
1129         add_ops.push_back(updated_node);
1130       }
1131     } while (add_ops.size() > 1);
1132     CHECK_EQ(1, add_ops.size());
1133 
1134     // attach the largest tensor to the root op
1135     if (!add_ops_leftover.empty()) {
1136       const InputAndShape lhs = add_ops.front();
1137       add_ops.pop_front();
1138       const InputAndShape rhs = add_ops_leftover.front();
1139       InputAndShape updated_node =
1140           UpdateInputs(lhs.input, rhs.input, group.root_node);
1141       add_ops.push_back(updated_node);
1142     }
1143 
1144     return add_ops.front().input;
1145   }
1146 
UpdateInputs(const string & input_0,const string & input_1,NodeDef * node)1147   InputAndShape UpdateInputs(const string& input_0, const string& input_1,
1148                              NodeDef* node) {
1149     string old_input_0 = node->input(0);
1150     string old_input_1 = node->input(1);
1151 
1152     // Update inputs only if they changed
1153     if (old_input_0 != input_0 || old_input_1 != input_1) {
1154       node->set_input(0, input_0);
1155       node->set_input(1, input_1);
1156       // Invalidate node properties (shape)
1157       ctx().graph_properties->ClearOutputProperties(node->name());
1158       ctx().graph_properties->ClearInputProperties(node->name());
1159       // Update the node map
1160       ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
1161       ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
1162       ctx().node_map->AddOutput(NodeName(input_0), node->name());
1163       ctx().node_map->AddOutput(NodeName(input_1), node->name());
1164       // Add updated node to optimization queue
1165       AddToOptimizationQueue(node);
1166     }
1167 
1168     TensorShapeProto shape;  // shape is not important at this point
1169     return InputAndShape(node->name(), shape);
1170   }
1171 };
1172 
// Removes inverse transpose nodes: cancels adjacent (or chain-connected)
// pairs of Transpose/ConjugateTranspose with inverse permutations, and
// bypasses transposes whose permutation is the identity.
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
  ~RemoveIdentityTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
    // Walk down through idempotent ops to find the tail of the chain; the
    // candidate matching transpose is the tail's first input.
    NodeDef* tail = node;
    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
                                    *ctx().nodes_to_preserve);
    NodeDef* first_transpose;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));

    // The permutation (input 1) must be a constant to reason about it.
    NodeDef* node_perm;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
    if (!IsConstant(*node_perm)) {
      return OkStatus();
    }
    std::vector<int64_t> node_perm_values;
    TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
    if (first_transpose->op() == node->op()) {
      // Remove pairs of transposes that cancel each other.
      NodeDef* first_transpose_perm;
      TF_RETURN_IF_ERROR(
          GetInputNode(first_transpose->input(1), &first_transpose_perm));
      if (!IsConstant(*first_transpose_perm)) {
        return OkStatus();
      }
      std::vector<int64_t> first_transpose_perm_values;
      TF_RETURN_IF_ERROR(
          GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
      if (AreInversePermutations(node_perm_values,
                                 first_transpose_perm_values)) {
        if (tail == node) {
          // Bypass adjacent pair.
          *simplified_node_name = first_transpose->input(0);
        } else {
          // Bypass pair connected through chain: splice the first transpose
          // out of the chain, then forward consumers past `node`.
          tail->set_input(0, first_transpose->input(0));
          ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
                                      first_transpose->input(0));
          ForwardControlDependencies(tail, {first_transpose});
          *simplified_node_name = node->input(0);
        }
      }
    } else {
      // Remove simple identity transposes.
      if (IsIdentityPermutation(node_perm_values)) {
        if (IsConjugateTranspose(*node)) {
          // An identity ConjugateTranspose still conjugates, so replace it
          // with a plain Conj node (dropping the permutation input/attr).
          const NodeScopeAndName transpose =
              ParseNodeScopeAndName(node->name());
          const string optimized_node_name = OptimizedNodeName(transpose);
          NodeDef* new_op = AddCopyNode(optimized_node_name, node);
          new_op->set_op("Conj");
          new_op->mutable_input()->RemoveLast();
          new_op->mutable_attr()->erase("Tperm");
          ForwardControlDependencies(new_op, {node});
          *simplified_node_name = new_op->name();
        } else {
          *simplified_node_name = node->input(0);
        }
      }
    }
    return OkStatus();
  }

 private:
  // Extracts the permutation values from a constant node, widening int32
  // values to int64 if necessary. Fails if the constant cannot be read as
  // either type.
  Status GetPermutation(const NodeDef& node_perm,
                        std::vector<int64_t>* perm64) const {
    std::vector<int> perm32;
    if (ValuesFromConstNode(node_perm, &perm32)) {
      perm64->reserve(perm32.size());
      for (int val : perm32) {
        perm64->push_back(static_cast<int64_t>(val));
      }
      return OkStatus();
    }
    if (ValuesFromConstNode(node_perm, perm64)) {
      return OkStatus();
    }
    return errors::InvalidArgument("Couldn't extract permutation from ",
                                   node_perm.name());
  }

  // Returns true iff applying permutation `b` then `a` yields the identity,
  // i.e. a[b[i]] == i for all i.
  bool AreInversePermutations(const std::vector<int64_t>& a,
                              const std::vector<int64_t>& b) {
    if (a.size() != b.size()) {
      return false;
    }
    for (int i = 0, end = a.size(); i < end; ++i) {
      if (a[b[i]] != i) {
        return false;
      }
    }
    return true;
  }

  // Returns true iff perm[i] == i for all i (an empty perm is the identity).
  bool IsIdentityPermutation(const std::vector<int64_t>& perm) {
    for (int64_t i = 0, end = perm.size(); i < end; ++i) {
      if (i != perm[i]) {
        return false;
      }
    }
    return true;
  }
};
1286 
1287 // An involution is an element-wise function f(x) that is its own inverse,
1288 // i.e. f(f(x)) = x. If we can find a chain of ops
1289 //   f->op1->op2->...opn->f
1290 // where op1 through opn preserve the values of their inputs, we can remove
1291 // the two instances of the involution from the graph, since they cancel
1292 // each other.
1293 class RemoveInvolution : public ArithmeticOptimizerStage {
1294  public:
RemoveInvolution(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1295   explicit RemoveInvolution(const GraphOptimizerContext& ctx,
1296                             const ArithmeticOptimizerContext& ctx_ext)
1297       : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
1298   ~RemoveInvolution() override = default;
1299 
IsSupported(const NodeDef * node) const1300   bool IsSupported(const NodeDef* node) const override {
1301     return IsInvolution(*node);
1302   }
1303 
TrySimplify(NodeDef * node,string * simplified_node_name)1304   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1305     NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
1306                                                   *ctx().nodes_to_preserve);
1307 
1308     NodeDef* involution;
1309     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));
1310 
1311     if (involution->op() == node->op()) {
1312       // Skip both *node and *involution since they cancel each other.
1313       if (tail == node) {
1314         // The two nodes to eliminate are adjacent.
1315         *simplified_node_name = involution->input(0);
1316       } else {
1317         tail->set_input(0, involution->input(0));
1318         ctx().node_map->UpdateInput(tail->name(), involution->name(),
1319                                     involution->input(0));
1320         *simplified_node_name = node->input(0);
1321       }
1322     }
1323 
1324     return OkStatus();
1325   }
1326 };
1327 
1328 // Remove redundant Bitcasts.
1329 // 1) Remove Bitcast whose source type and destination type are equal
1330 // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1331 class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
1332  public:
RemoveRedundantBitcastStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1333   explicit RemoveRedundantBitcastStage(
1334       const GraphOptimizerContext& ctx,
1335       const ArithmeticOptimizerContext& ctx_ext)
1336       : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
1337   ~RemoveRedundantBitcastStage() override = default;
1338 
IsSupported(const NodeDef * node) const1339   bool IsSupported(const NodeDef* node) const override {
1340     return IsBitcast(*node);
1341   }
1342 
TrySimplify(NodeDef * node,string * simplified_node_name)1343   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1344     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1345 
1346     // Bypass Bitcast whose source type and destination type are equal.
1347     AttrSlice attrs(*node);
1348     DataType input_type;
1349     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
1350     DataType output_type;
1351     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
1352     if ((input_type == output_type) && !IsInPreserveSet(*node)) {
1353       *simplified_node_name = node->input(0);
1354       return OkStatus();
1355     }
1356 
1357     NodeDef* bitcast;
1358     TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
1359     NodeDef* operand;
1360     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));
1361 
1362     if (IsBitcast(*operand) && !IsInPreserveSet(*operand)) {
1363       AttrSlice operand_attrs(*operand);
1364       DataType operand_input_type;
1365       TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
1366       // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1367       bitcast->set_input(0, operand->input(0));
1368       SetDataTypeToAttr(operand_input_type, "T", bitcast);
1369       ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
1370                                   operand->input(0));
1371       AddToOptimizationQueue(bitcast);
1372       *simplified_node_name = bitcast->name();
1373     }
1374 
1375     return OkStatus();
1376   }
1377 };
1378 
1379 // Remove Casts whose source type and destination type are equal.
1380 class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
1381  public:
RemoveRedundantCastStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1382   explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
1383                                     const ArithmeticOptimizerContext& ctx_ext)
1384       : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
1385   ~RemoveRedundantCastStage() override = default;
1386 
IsSupported(const NodeDef * node) const1387   bool IsSupported(const NodeDef* node) const override {
1388     return IsCast(*node) && !IsInPreserveSet(*node);
1389   }
1390 
TrySimplify(NodeDef * node,string * simplified_node_name)1391   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1392     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1393 
1394     // Bypass Cast whose source type and destination type are equal.
1395     AttrSlice attrs(*node);
1396     DataType input_type;
1397     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
1398     DataType output_type;
1399     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
1400     if (input_type == output_type) {
1401       *simplified_node_name = node->input(0);
1402     }
1403     return OkStatus();
1404   }
1405 };
1406 
1407 class RemoveNegationStage : public ArithmeticOptimizerStage {
1408  public:
RemoveNegationStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1409   explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
1410                                const ArithmeticOptimizerContext& ctx_ext)
1411       : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
1412   ~RemoveNegationStage() override = default;
1413 
IsSupported(const NodeDef * node) const1414   bool IsSupported(const NodeDef* node) const override {
1415     return (IsAdd(*node) || IsSub(*node)) && !IsInPreserveSet(*node);
1416   }
1417 
TrySimplify(NodeDef * node,string * simplified_node_name)1418   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1419     NodeDef* x;
1420     NodeDef* y;
1421     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1422     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1423     bool updated = false;
1424     if (IsNeg(*y)) {
1425       // a - (-b) = a + b or  a + (-b) = a - b
1426       ForwardControlDependencies(node, {y});
1427       ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
1428       node->set_op(IsAdd(*node) ? "Sub" : "AddV2");
1429       node->set_input(1, y->input(0));
1430       updated = true;
1431     } else if (IsAdd(*node) && IsNeg(*x)) {
1432       // (-a) + b = b - a
1433       ForwardControlDependencies(node, {x});
1434       ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
1435       node->set_op("Sub");
1436       node->mutable_input()->SwapElements(0, 1);
1437       node->set_input(1, x->input(0));
1438       updated = true;
1439     }
1440     if (updated) {
1441       AddToOptimizationQueue(node);
1442     }
1443     return OkStatus();
1444   }
1445 };
1446 
1447 class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
1448  public:
RemoveLogicalNotStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1449   explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
1450                                  const ArithmeticOptimizerContext& ctx_ext)
1451       : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
1452   ~RemoveLogicalNotStage() override = default;
1453 
IsSupported(const NodeDef * node) const1454   bool IsSupported(const NodeDef* node) const override {
1455     return IsLogicalNot(*node) && !IsInPreserveSet(*node);
1456   }
1457 
TrySimplify(NodeDef * node,string * simplified_node_name)1458   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1459     const string node_name = node->name();
1460     NodeDef* input;
1461     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1462     if (IsInPreserveSet(*input) ||
1463         NumNonControlOutputs(*input, *ctx().node_map) > 1) {
1464       return OkStatus();
1465     }
1466     string new_op;
1467     if (IsEqual(*input)) {
1468       new_op = "NotEqual";
1469     } else if (IsNotEqual(*input)) {
1470       new_op = "Equal";
1471     } else if (IsLess(*input)) {
1472       new_op = "GreaterEqual";
1473     } else if (IsLessEqual(*input)) {
1474       new_op = "Greater";
1475     } else if (IsGreater(*input)) {
1476       new_op = "LessEqual";
1477     } else if (IsGreaterEqual(*input)) {
1478       new_op = "Less";
1479     }
1480     if (!new_op.empty()) {
1481       input->set_op(new_op);
1482       *simplified_node_name = input->name();
1483     }
1484     return OkStatus();
1485   }
1486 };
1487 
// This optimization hoists the common prefix of unary ops of the inputs to
// concat out of the concat, for example:
//    Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
// becomes
//    Exp(Sin(Concat([x, y, z]))).
// Similarly, it will hoist the common postfix of unary ops into Split or
// SplitV nodes, for example:
//    [Exp(Sin(y)) for y in Split(x)]
// becomes
//    [y for y in Split(Exp(Sin(x))]
//
// TODO(rmlarsen): Support casting. We would have to change the type attribute
// on the concat/split node.
// TODO(rmlarsen): Handle Enter/Exit.
class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
 public:
  explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("", ctx, ctx_ext) {}

  ~HoistCWiseUnaryChainsStage() override = default;

  // One per-port chain being traced: the current node on the chain plus the
  // concat input port (or split output port) the chain started from.
  struct ChainLink {
    ChainLink() = default;
    ChainLink(NodeDef* _node, int _port_origin)
        : node(_node), port_origin(_port_origin) {}
    NodeDef* node;    // Node in a chain.
    int port_origin;  // Port on concat/split node from which this chain
                      // originates.

    // Order primarily by originating port, then by node name, so iteration
    // over a set of links is deterministic.
    bool operator<(const ChainLink& other) const {
      if (port_origin < other.port_origin) {
        return true;
      } else if (port_origin > other.port_origin) {
        return false;
      } else {
        return node->name() < other.node->name();
      }
    }
  };

  // We use an ordinary set sorted on port and node name, so the order, and
  // hence the node name used for the hoisted chain, will be deterministic.
  using ChainLinkSet = std::set<ChainLink>;

  // Accepts Concat nodes with N > 1 distinct inputs, and Split/SplitV nodes
  // with num_split > 1 whose every output is consumed exactly once and which
  // have no control outputs (see TODOs below for why those are excluded).
  bool IsSupported(const NodeDef* node) const override {
    if (IsInPreserveSet(*node)) return false;
    if (IsConcat(*node) && node->attr().count("N") != 0) {
      const int n = node->attr().at("N").i();
      // Duplicate inputs would make two chains share nodes — presumably why
      // uniqueness is required here. TODO(review): confirm.
      return n > 1 && FirstNInputsAreUnique(*node, n);
    } else if ((IsSplit(*node) || IsSplitV(*node)) &&
               node->attr().count("num_split") != 0) {
      const int num_split = node->attr().at("num_split").i();
      if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
        // TODO(rmlarsen): Remove this constraint when we have optimizations
        // in place for merging slices into splits.
        return false;
      }
      if (NumControlOutputs(*node, *ctx().node_map) > 0) {
        // TODO(ezhulenev): Unary ops after Split might have a control path to
        // the Split node, and we currently do not properly handle cycles.
        return false;
      }
      return num_split > 1 && !IsAlreadyOptimized(*node);
    }
    return false;
  }

  // Finds the longest common unary-op chain across all ports of the
  // concat/split and, if non-empty, hoists it to the other side.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    node_is_concat_ = IsConcat(*node);
    int prefix_length;
    std::set<string> ctrl_inputs;
    ChainLinkSet tails;
    TF_RETURN_IF_ERROR(
        FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
    if (prefix_length > 0 && !tails.empty()) {
      TF_RETURN_IF_ERROR(
          HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
    }
    return OkStatus();
  }

 private:
  // Returns true iff the first n data inputs of `node` are pairwise distinct
  // tensor names. For "Concat" the axis is input 0, so data inputs start at 1.
  bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
    if (n > node.input_size()) return false;
    absl::flat_hash_set<string> unique_inputs;
    const int start = node.op() == "Concat" ? 1 : 0;
    const int end = start + n;
    for (int i = start; i < end; ++i) {
      unique_inputs.insert(node.input(i));
    }
    int unique_input_size = unique_inputs.size();
    return unique_input_size == n;
  }

  // Returns the length of the common unary chain of ops that can be
  // hoisted to the other side of concat or split.
  Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
                                ChainLinkSet* tails,
                                std::set<string>* ctrl_inputs) const {
    *prefix_length = 0;
    // Follow the chains starting at each concat input or split output as long
    // as all the following conditions hold:
    //   1. The ops in all chains are the same.
    //   2. The ops are unary elementwise op.
    //   3. The op output has only a single consumer (concat only).
    ChainLinkSet cur_tails;
    TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
    if (cur_tails.size() < 2) {
      return OkStatus();
    }
    ctrl_inputs->clear();
    bool stop = false;
    while (!stop && !cur_tails.empty() &&
           OpsAreSafeToHoist(root_node, cur_tails)) {
      // We found one more link that can be hoisted.
      ++(*prefix_length);
      tails->swap(cur_tails);
      GatherControlInputs(ctrl_inputs, *tails);

      // Advance tail pointers to the next level.
      TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
    }
    return OkStatus();
  }

  // Hoists the chains to the other side of concat or split and attaches the
  // control inputs gathered from them to the concat or split node.
  Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
                           std::set<string>* ctrl_inputs, NodeDef* root_node) {
    VLOG(3) << "Hoist unary op chain:"
            << " root=" << root_node->DebugString()
            << " prefix_length=" << prefix_length << " ctrl_inputs=["
            << absl::StrJoin(*ctrl_inputs, ", ") << "]";

    if (tails.empty()) {
      return OkStatus();
    }
    // Mark the root as optimized first so we never revisit it (IsSupported
    // checks IsAlreadyOptimized for splits).
    AddToOptimizationQueue(root_node);
    optimized_nodes_.insert(root_node->name());
    if (node_is_concat_) {
      AddControlInputs(ctrl_inputs, root_node);
      return HoistChainForConcat(prefix_length, tails, root_node);
    } else {
      return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
    }
  }

  // Collects the control inputs of every node in `ops` into `ctrl_inputs`.
  // Control inputs are trailing entries of NodeDef.input, so scan from the
  // back and stop at the first non-control input.
  void GatherControlInputs(std::set<string>* ctrl_inputs,
                           const ChainLinkSet& ops) const {
    for (const auto& link : ops) {
      const NodeDef* node = link.node;
      for (int i = node->input_size() - 1; i >= 0; --i) {
        const string& input = node->input(i);
        if (!IsControlInput(input)) break;
        ctrl_inputs->insert(input);
      }
    }
  }

  // Appends the control inputs in `new_ctrl_inputs` to `node`, skipping any
  // it already has. Note: erases the already-present entries from
  // `new_ctrl_inputs` as a side effect.
  void AddControlInputs(std::set<string>* new_ctrl_inputs,
                        NodeDef* node) const {
    for (int i = node->input_size() - 1; i >= 0; --i) {
      const string& existing_input = node->input(i);
      if (!IsControlInput(existing_input)) break;
      new_ctrl_inputs->erase(existing_input);
    }
    for (const string& new_input : *new_ctrl_inputs) {
      ctx().node_map->AddOutput(NodeName(new_input), node->name());
      node->add_input(new_input);
    }
  }

  // Seeds `tails` with the first link of each chain: for concat, the direct
  // producers of its data inputs; for split, the direct consumers of its
  // outputs. Leaves `tails` empty (without error) when hoisting is not
  // applicable.
  Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
    if (node_is_concat_) {
      // Handle concat nodes by looking backwards in the graph.
      TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
      const int n = node.attr().at("N").i();
      const int start = node.op() == "Concat" ? 1 : 0;
      const int end = start + n;
      if (end > node.input_size()) {
        return errors::FailedPrecondition("Got attr N=", n,
                                          " without enough inputs.");
      }
      // Set up tail pointers to point to the immediate inputs to Concat.
      for (int input_port = start; input_port < end; ++input_port) {
        if (IsControlInput(node.input(input_port))) {
          return errors::FailedPrecondition(
              "Got control input ", node.input(input_port),
              " where normal input was expected.");
        }
        NodeDef* tail;
        TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
        tails->insert(ChainLink(tail, input_port));
      }
      return OkStatus();
    } else {
      // Handle split nodes by looking forwards in the graph.
      const auto& outputs = ctx().node_map->GetOutputs(node.name());
      for (NodeDef* output : outputs) {
        if (output->input_size() == 0 || IsControlInput(output->input(0))) {
          continue;
        }
        TensorId tensor_id = ParseTensorName(output->input(0));
        if (tensor_id.node() == node.name()) {
          tails->insert(ChainLink(output, tensor_id.index()));
        } else {
          // This output node has a non-control input other than the split node,
          // abort.
          tails->clear();
          return OkStatus();
        }
      }
    }
    return OkStatus();
  }

  // Returns true iff the current level of all chains consists of identical,
  // unary element-wise ops on the same device as the root, none preserved,
  // each feeding exactly one consumer, and not a Relu/Relu6 that the remapper
  // could later fuse with FusedBatchNorm/BiasAdd.
  bool OpsAreSafeToHoist(const NodeDef& root_node,
                         const ChainLinkSet& ops) const {
    if (ops.empty()) return true;
    const NodeDef* op0 = ops.begin()->node;
    if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
    for (const auto& link : ops) {
      const NodeDef* op = link.node;
      if (op->device() != root_node.device() || op->op() != op0->op() ||
          IsInPreserveSet(*op)) {
        return false;
      }
      if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
        // TODO(rmlarsen): Allow outgoing control edges.
        return false;
      }
      // Do not hoist Relu if it can be fused with its predecessors. This is
      // important because remapping runs after arithmetic.
      if (IsRelu(*op) || IsRelu6(*op)) {
        NodeDef* operand = nullptr;
        if (!GetInputNode(op->input(0), &operand).ok()) {
          return false;
        }
        if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
          return false;
        }
      }
    }
    return true;
  }

  // Moves every chain one step: backwards (to producers) for concat,
  // forwards (to consumers) for split. Sets *stop = true and returns early
  // if any chain cannot be extended.
  Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
                      bool* stop) const {
    *stop = true;
    new_tails->clear();
    for (const auto& link : tails) {
      const NodeDef* tail = link.node;
      if (node_is_concat_) {
        if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
          return OkStatus();
        }
        NodeDef* new_tail;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
        // Remember original port.
        new_tails->insert(ChainLink(new_tail, link.port_origin));
      } else {
        for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
          const TensorId tensor = ParseTensorName(new_tail->input(0));
          if (tensor.node() != tail->name()) {
            return OkStatus();
          }
          // Skip control outputs.
          if (tensor.index() >= 0) {
            // Remember original port.
            new_tails->insert(ChainLink(new_tail, link.port_origin));
          }
        }
      }
    }
    *stop = false;
    return OkStatus();
  }

  // Rewires each concat input port to skip its chain, then reuses the nodes
  // of the first chain (port `first_input`) to post-process the concat
  // output; the other chains become dead and are left for pruning.
  Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
                             NodeDef* concat_node) {
    const string& concat_name = concat_node->name();
    const int first_input = concat_node->op() == "Concat" ? 1 : 0;
    for (const auto& link : tails) {
      NodeDef* tail = CHECK_NOTNULL(link.node);
      const int concat_port = link.port_origin;
      CHECK_GE(concat_port, 0);
      CHECK_LT(concat_port, concat_node->input_size());
      const string concat_input = concat_node->input(concat_port);
      // Hook the node following tail directly into the concat node.
      const string tail_input = tail->input(0);
      concat_node->set_input(concat_port, tail_input);
      ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);

      if (concat_port == first_input) {
        // Update the consumers of concat to consume the end of the chain
        // instead.
        TF_RETURN_IF_ERROR(UpdateConsumers(concat_node, concat_input));
        // Reuse nodes in the first chain to process output of concat.
        tail->set_input(0, concat_name);
        ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
      }
    }
    return OkStatus();
  }

  // Builds a fresh copy of the chain in front of the split's value input and
  // reconnects each chain tail's consumers straight to the corresponding
  // split output port; the original per-port chains become dead.
  Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
                            std::set<string>* ctrl_inputs,
                            NodeDef* split_node) {
    // Create a new chain before the split node to process the input tensor.
    const string& split_name = split_node->name();
    auto root_scope_and_name = ParseNodeScopeAndName(split_name);

    // We use the first tail node in the set as a template to get the list of
    // ops to apply (starting from the end).
    NodeDef* cur_tail = tails.begin()->node;
    NodeDef* cur_copy = AddCopyNode(
        OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
    cur_copy->clear_input();

    // Update the split to take its input from the tail of the new chain.
    // SplitV carries the value in slot 0; Split carries the axis there.
    const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
    const string orig_input = split_node->input(value_slot);
    split_node->set_input(value_slot, cur_copy->name());
    ctx().node_map->UpdateInput(split_node->name(), orig_input,
                                cur_copy->name());
    TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));

    // Now walk backwards creating the rest of the chain.
    while (cur_tail != split_node) {
      NodeDef* new_copy = AddCopyNode(
          OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
      new_copy->clear_input();
      cur_copy->add_input(new_copy->name());
      ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
      cur_copy = new_copy;
      TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
    }
    // Connect the original input to the head of the new chain.
    cur_copy->add_input(orig_input);
    ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
                                 cur_copy->name());
    // Make sure all the control inputs are satisfied before running the first
    // node in the new chain.
    AddControlInputs(ctrl_inputs, cur_copy);

    // Connect all consumers of the tail nodes directly to the
    // output port of Split from which the chain started.
    for (const auto& link : tails) {
      TF_RETURN_IF_ERROR(UpdateConsumers(
          link.node, link.port_origin == 0
                         ? split_name
                         : strings::StrCat(split_name, ":", link.port_origin)));
    }
    return OkStatus();
  }

  // True iff this stage has already rewritten `node` in this pass.
  bool IsAlreadyOptimized(const NodeDef& node) const {
    return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
  }

 private:
  // True while simplifying a Concat root, false for Split/SplitV; set at the
  // top of TrySimplify and read by the helpers.
  bool node_is_concat_;
  // Names of split roots already rewritten, to avoid re-processing.
  std::unordered_set<string> optimized_nodes_;
};
1853 
1854 class RemoveIdempotentStage : public ArithmeticOptimizerStage {
1855  public:
RemoveIdempotentStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1856   explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
1857                                  const ArithmeticOptimizerContext& ctx_ext)
1858       : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
1859   ~RemoveIdempotentStage() override = default;
1860 
IsSupported(const NodeDef * node) const1861   bool IsSupported(const NodeDef* node) const override {
1862     return node->input_size() == 1 && IsIdempotent(*node) &&
1863            !IsInPreserveSet(*node);
1864   }
1865 
TrySimplify(NodeDef * node,string * simplified_node_name)1866   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1867     NodeDef* input;
1868     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1869     if (input->op() == node->op() && input->device() == node->device()) {
1870       *simplified_node_name = node->input(0);
1871     }
1872     return OkStatus();
1873   }
1874 };
1875 
1876 // Performs the conversion:
1877 // Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
1878 // TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
1879 class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
1880  public:
SqrtDivToRsqrtMulStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1881   explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
1882                                   const ArithmeticOptimizerContext& ctx_ext)
1883       : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
1884   ~SqrtDivToRsqrtMulStage() override = default;
1885 
IsSupported(const NodeDef * node) const1886   bool IsSupported(const NodeDef* node) const override {
1887     // Note: div_no_nan(a, sqrt(b)) => mul_no_nan(a, rsqrt(b))
1888     // for b == 0 would result in a / Inf instead of 0.
1889     return IsAnyDiv(*node) && !IsDivNoNan(*node) && !IsFloorDiv(*node);
1890   }
1891 
TrySimplify(NodeDef * node,string * simplified_node_name)1892   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1893     NodeDef* y;
1894     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1895     // Optimize only if divisor is a Sqrt whose output is not being consumed
1896     // elsewhere.
1897     if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
1898         (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
1899       if (IsXdivy(*node)) {
1900         // xdivy(a, sqrt(b)) => mul_no_nan(rsqrt(b), a)
1901         node->set_op("MulNoNan");
1902         node->mutable_input()->SwapElements(0, 1);
1903       } else {
1904         // div(a, sqrt(b)) => mul(a, rsqrt(b))
1905         node->set_op("Mul");
1906       }
1907       y->set_op("Rsqrt");
1908       AddToOptimizationQueue(node);
1909       AddToOptimizationQueue(y);
1910     }
1911     return OkStatus();
1912   }
1913 };
1914 
1915 // Performs the following conversion for real types:
1916 //   Square(Sub(x, y)) => Identity(SquaredDifference(x, y) )
1917 class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
1918  public:
FuseSquaredDiffStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1919   explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
1920                                 const ArithmeticOptimizerContext& ctx_ext)
1921       : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
1922   ~FuseSquaredDiffStage() override = default;
1923 
IsSupported(const NodeDef * node) const1924   bool IsSupported(const NodeDef* node) const override {
1925     return IsSquare(*node);
1926   }
1927 
TrySimplify(NodeDef * node,string * simplified_node_name)1928   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1929     NodeDef* b;
1930     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
1931     // Optimize only if base is a Sub whose output is not being consumed
1932     // elsewhere.
1933     if (IsSub(*b) && !IsInPreserveSet(*b) &&
1934         (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
1935       // For complex, SquaredDiff computes conj(x-y)*(x-y), so this rewrite is
1936       // invalid.
1937       const DataType type = GetDataTypeFromAttr(*b, "T");
1938       if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128)) return OkStatus();
1939       node->set_op("Identity");
1940       b->set_op("SquaredDifference");
1941       AddToOptimizationQueue(node);
1942       AddToOptimizationQueue(b);
1943     }
1944     return OkStatus();
1945   }
1946 };
1947 
1948 // Performs the conversion:
1949 // Log(Softmax(x)) => LogSoftmax(x)
1950 class LogSoftmaxStage : public ArithmeticOptimizerStage {
1951  public:
LogSoftmaxStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1952   explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
1953                            const ArithmeticOptimizerContext& ctx_ext)
1954       : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
1955   ~LogSoftmaxStage() override = default;
1956 
IsSupported(const NodeDef * node) const1957   bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
1958 
TrySimplify(NodeDef * node,string * simplified_node_name)1959   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1960     NodeDef* x;
1961     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1962     // Optimize only if arg is a Softmax whose output is not being consumed
1963     // elsewhere.
1964     if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
1965         (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
1966       // Log(Softmax(x)) => LogSoftmax(Identity(x))
1967       node->set_op("LogSoftmax");
1968       x->set_op("Identity");
1969       AddToOptimizationQueue(node);
1970       AddToOptimizationQueue(x);
1971     }
1972     return OkStatus();
1973   }
1974 };
1975 
1976 // Bypass redundant reshape nodes:
1977 //
1978 //   Reshape                    Reshape  <-+
1979 //      ^                                  |
1980 //      |                                  |
1981 //   Reshape       becomes      Reshape    |
1982 //      ^                                  |
1983 //      |                                  |
1984 //    input                      input  ---+
1985 //
1986 // Additionally,  Reshape and BroadcastTo nodes where the
1987 // input and target shapes are equal are bypassed.
1988 //
1989 class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
1990  public:
RemoveRedundantReshapeOrBroadcastTo(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)1991   explicit RemoveRedundantReshapeOrBroadcastTo(
1992       const GraphOptimizerContext& ctx,
1993       const ArithmeticOptimizerContext& ctx_ext)
1994       : ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
1995                                  ctx_ext) {}
1996   ~RemoveRedundantReshapeOrBroadcastTo() override = default;
1997 
IsSupported(const NodeDef * node) const1998   bool IsSupported(const NodeDef* node) const override {
1999     return IsReshape(*node) || IsBroadcastTo(*node);
2000   }
2001 
  // TODO(rmlarsen): Handle unary ops with multiple outputs.
  // Removes the node when it is a shape no-op, or bypasses an upstream
  // Reshape that is made redundant by this one.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. If the reshape is a no-op, forward its input to its consumers, unless
    // it anchors a control dependency since we want to make sure that control
    // dependency is triggered.
    if (!IsInPreserveSet(*node) && InputMatchesTargetShape(*node) &&
        !HasControlInputs(*node)) {
      *simplified_node_name = node->input(0);
      return OkStatus();
    }

    // 2. Bypass reshape followed by reshape, possibly separated by a simple
    // chain of unary elementwise ops that are not outputs.
    if (IsReshape(*node)) {
      bool skip = false;
      gtl::InlinedVector<const NodeDef*, 4> nodes_in_chain;
      // Chain-membership predicate: records every visited node, and aborts
      // (skip = true) on preserved/frame-modifying nodes or intermediate
      // nodes with additional consumers (the starting node itself may have
      // more than one consumer).
      const auto predicate_fn = [this, node, &skip,
                                 &nodes_in_chain](const NodeDef& input) {
        nodes_in_chain.push_back(&input);
        if ((input.name() != node->name() &&
             NumNonControlOutputs(input, *ctx().node_map) > 1) ||
            IsInPreserveSet(input) || ModifiesFrameInfo(input)) {
          skip = true;
          return false;
        }
        return IsUnaryElementWise(input);
      };

      // Walk up the input chain until we find a node that is not unary
      // element-wise. If it is another Reshape node, we can bypass it.
      NodeDef* tail =
          GetTailOfChain(*node, *ctx().node_map,
                         /*follow_control_input*/ false, predicate_fn);

      if (!skip && tail != nullptr && !IsInPreserveSet(*tail)) {
        NodeDef* reshape_to_bypass;
        TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &reshape_to_bypass));
        // Only bypass a Reshape whose sole non-control consumer is this
        // chain and which is not preserved.
        if (reshape_to_bypass == nullptr ||
            (!IsReshape(*reshape_to_bypass) ||
             NumNonControlOutputs(*reshape_to_bypass, *ctx().node_map) > 1 ||
             IsInPreserveSet(*reshape_to_bypass))) {
          return OkStatus();
        }
        // Clearing invalid shape inference results of nodes in chain.
        // (The chain's inputs change, so cached properties are stale; the
        // starting node keeps its output properties since its output shape
        // is unchanged.)
        for (const NodeDef* node_in_chain : nodes_in_chain) {
          ctx().graph_properties->ClearInputProperties(node_in_chain->name());
          if (node_in_chain != node) {
            ctx().graph_properties->ClearOutputProperties(
                node_in_chain->name());
          }
        }
        // We now have
        //    reshape_to_bypass -> tail -> ... -> node
        // where tail maybe equal to node.
        TF_RETURN_IF_ERROR(
            UpdateConsumers(reshape_to_bypass, reshape_to_bypass->input(0)));
        ForwardControlDependencies(tail, {reshape_to_bypass});
        // Change the bypassed reshape to NoOp.
        ReplaceWithNoOp(reshape_to_bypass, ctx());
        *simplified_node_name = node->name();
        return OkStatus();
      }
    }

    return OkStatus();
  }
2068 
2069  private:
2070   // Returns whether `reshape` is an identity op.
InputMatchesTargetShape(const NodeDef & reshape)2071   bool InputMatchesTargetShape(const NodeDef& reshape) {
2072     const OpInfo::TensorProperties* reshape_props;
2073     const OpInfo::TensorProperties* input_props;
2074     if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
2075         !GetTensorProperties(reshape.input(0), &input_props).ok()) {
2076       return false;
2077     }
2078 
2079     return ShapesSymbolicallyEqual(input_props->shape(),
2080                                    reshape_props->shape());
2081   }
2082 };
2083 
2084 // Reorder casting and value-preserving ops if beneficial.
2085 //
2086 // Original motivation: A common pattern after the layout optimizer is
2087 // casting an uint8 NHWC image to float before transposing it to NCHW. It
2088 // is beneficial to reorder the cast and the transpose to make the transpose
2089 // process smaller amount of data. More generally, this optimization converts
2090 //   Op(Cast(tensor, dst_type))
2091 // to
2092 //   Cast(Op(tensor), dst_type)
2093 // when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
2094 // Op, i.e. an op that only reorders the elements in its first input. Similarly,
2095 // this optimization converts
2096 //   Cast(Op(tensor), dst_type)
2097 // to
2098 //   Op(Cast(tensor, dst_type))
2099 // when sizeof(tensor.type) > sizeof(dst_type)
2100 //
2101 class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
2102  public:
ReorderCastLikeAndValuePreserving(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2103   explicit ReorderCastLikeAndValuePreserving(
2104       const GraphOptimizerContext& ctx,
2105       const ArithmeticOptimizerContext& ctx_ext)
2106       : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
2107                                  ctx_ext) {}
2108   ~ReorderCastLikeAndValuePreserving() override = default;
2109 
IsSupported(const NodeDef * node) const2110   bool IsSupported(const NodeDef* node) const override {
2111     return (IsValuePreserving(*node) || IsCastLike(*node)) &&
2112            !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
2113            !IsControlFlow(*node) && !IsInPreserveSet(*node);
2114   }
2115 
TrySimplify(NodeDef * consumer,string * simplified_node_name)2116   Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
2117     NodeDef* producer;
2118 
2119     if (consumer->input_size() < 1) {
2120       return errors::FailedPrecondition("Node ", simplified_node_name,
2121                                         " lacks inputs");
2122     }
2123 
2124     TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
2125     const bool producer_is_cast = IsCastLike(*producer);
2126     const bool can_optimize =
2127         !IsCheckNumerics(*producer) &&
2128         ((producer_is_cast && IsValuePreserving(*consumer)) ||
2129          (IsValuePreserving(*producer) && IsCastLike(*consumer)));
2130     if (!can_optimize || IsControlFlow(*producer) ||
2131         IsInPreserveSet(*producer) ||
2132         producer->device() != consumer->device()) {
2133       return OkStatus();
2134     }
2135 
2136     const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
2137     const OpDef* cast_like_op_def = nullptr;
2138     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
2139                                                          &cast_like_op_def));
2140     DataType cast_src_type;
2141     TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
2142                                         &cast_src_type));
2143     DataType cast_dst_type;
2144     TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
2145                                          &cast_dst_type));
2146     if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
2147       return OkStatus();
2148     } else if (producer_is_cast &&
2149                DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
2150       return OkStatus();
2151     } else if (!producer_is_cast &&
2152                DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
2153       return OkStatus();
2154     }
2155 
2156     // Check that nodes were not already optimized.
2157     const string optimized_producer_name = OptimizedNodeName(
2158         ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
2159     const string optimized_consumer_name = OptimizedNodeName(
2160         ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
2161     const bool is_already_optimized =
2162         ctx().node_map->NodeExists(optimized_consumer_name) ||
2163         ctx().node_map->NodeExists(optimized_producer_name);
2164     if (is_already_optimized) {
2165       return OkStatus();
2166     }
2167 
2168     // Add copies of consumer and producer in reverse order.
2169     NodeDef* input;
2170     TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
2171     // Create new producer node.
2172     NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
2173     new_producer->set_input(0, producer->input(0));
2174     ctx().node_map->AddOutput(input->name(), new_producer->name());
2175 
2176     // Create new consumer node.
2177     NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
2178     new_consumer->set_input(0, new_producer->name());
2179 
2180     NodeDef* new_value_preserving =
2181         producer_is_cast ? new_producer : new_consumer;
2182     const DataType new_input_type =
2183         producer_is_cast ? cast_src_type : cast_dst_type;
2184     // Update the input type of the value-preserving node. The input and
2185     // output types of the cast-like nodes remain the same.
2186     TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
2187     // Make sure there is a kernel registered for the value preserving op
2188     // with the new input type.
2189     TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
2190     ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());
2191 
2192     AddToOptimizationQueue(new_producer);
2193     *simplified_node_name = new_consumer->name();
2194 
2195     return OkStatus();
2196   }
2197 
2198  private:
2199   // Sets the type of the first input to dtype.
SetInputType(DataType dtype,NodeDef * node)2200   Status SetInputType(DataType dtype, NodeDef* node) {
2201     const OpDef* op_def = nullptr;
2202     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
2203     const OpDef::ArgDef& input_arg = op_def->input_arg(0);
2204     const string& type_attr_name = input_arg.type_attr();
2205     if (type_attr_name.empty()) {
2206       if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
2207         return errors::InvalidArgument("Could not set input type of ",
2208                                        node->op(), " op to ",
2209                                        DataTypeString(dtype));
2210       } else {
2211         // Op has fixed input type that already matches dtype.
2212         return OkStatus();
2213       }
2214     }
2215     SetDataTypeToAttr(dtype, type_attr_name, node);
2216     return OkStatus();
2217   }
2218   // This optimization can be dangerous on devices other than CPU and
2219   // GPU. The transpose might not be implemented for image.type, or
2220   // might be slower with image.type than with cast_dst_type.
NodeIsOnCpuOrGpu(const NodeDef * node) const2221   bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
2222     using absl::StrContains;
2223 
2224     string task;
2225     string device;
2226 
2227     return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
2228            (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
2229   }
2230 
IsFixedSizeType(DataType dtype)2231   bool IsFixedSizeType(DataType dtype) {
2232     return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
2233            !kQuantizedTypes.Contains(dtype);
2234   }
2235 };
2236 
2237 // Fold a multiply of a scalar into the following convolution. This folding
2238 // can jump across nodes that merely reorders data (such as reshape and
2239 // transpose). For example, we can optimize
2240 //
2241 //
2242 //         Conv2D                             Conv2D
2243 //        /      \                           /      \
2244 //    Transpose  weights*       ->     Transpose    Mul
2245 //       |                                |        /   \
2246 //      Mul                               |    weights  scale
2247 //     /   \                              |
2248 //   input  scale**                     input
2249 //
2250 //  *) weights must be a const
2251 // **) scale must be a const scalar
2252 //
2253 // When `weights` and `scale` are constant, `Mul` in the optimized graph can be
2254 // constant-folded, also weights tend to be smaller than the activations.
2255 //
2256 // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
2257 // Conv?DBackpropInput.
2258 class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
2259  public:
FoldMultiplyIntoConv(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2260   explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
2261                                 const ArithmeticOptimizerContext& ctx_ext)
2262       : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
2263   ~FoldMultiplyIntoConv() override = default;
2264 
IsSupported(const NodeDef * node) const2265   bool IsSupported(const NodeDef* node) const override {
2266     return IsConv2D(*node) || IsConv3D(*node);
2267   }
2268 
TrySimplify(NodeDef * node,string * simplified_node_name)2269   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2270 #define TF_RETURN_IF_TRUE(...) \
2271   if ((__VA_ARGS__)) return OkStatus()
2272 
2273     NodeDef* conv = node;
2274 
2275     NodeDef* weights;
2276     TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));
2277 
2278     // Fold the multiply to conv only when the weights are constant, so the
2279     // multiply can be constant-folded.
2280     //
2281     // TODO(jingyue): When the weights aren't constant, this should also help
2282     // performance a bit and memory usage a lot, since the weights tend to be
2283     // smaller than the activations.
2284     TF_RETURN_IF_TRUE(!IsConstant(*weights));
2285 
2286     // Verify that this node was not already optimized.
2287     const string scaled_weights_node_name =
2288         OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
2289                           strings::StrCat("scaled", "_", conv->name()));
2290 
2291     TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));
2292 
2293     // Find the tail of value preserving chain entering the Conv node.
2294     NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
2295                                                   *ctx().nodes_to_preserve);
2296 
2297     NodeDef* source;
2298     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));
2299 
2300     // Check that value preserving chain is the only consumer of the Mul output.
2301     TF_RETURN_IF_TRUE(!IsAnyMul(*source));
2302     TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
2303     // And that Mul is not in the preserve set.
2304     TF_RETURN_IF_TRUE(IsInPreserveSet(*source));
2305 
2306     const NodeDef* mul = source;
2307     int input_idx = 0;
2308     int scale_idx = 1;
2309     NodeDef* scale;  // scalar multiplier for the input tensor
2310     NodeDef* input;
2311     TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
2312     TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
2313     if (!IsConstant(*scale) && IsConstant(*input)) {
2314       VLOG(3) << "Swapped inputs to mul";
2315       std::swap(scale_idx, input_idx);
2316       std::swap(scale, input);
2317     }
2318     TF_RETURN_IF_TRUE(!IsConstant(*scale));
2319 
2320     // Check that one of the inputs to mul is a constant scalar.
2321     const TensorProto& scale_tensor = scale->attr().at("value").tensor();
2322     bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
2323                              scale_tensor.tensor_shape().dim_size() == 0;
2324     TF_RETURN_IF_TRUE(!scale_is_a_scalar);
2325 
2326     // Check that 'scale * weight' can be const folded.
2327     TF_RETURN_IF_TRUE(!IsConstant(*scale));
2328     TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
2329     TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
2330     TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
2331                       weights->attr().at("dtype").type());
2332 
2333     // At this point all preconditions are met, and we safely do the rewrite.
2334     VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
2335             << " mul=" << mul->name() << " weights=" << weights->name();
2336 
2337     // Create new node `scaled_weights`.
2338     NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
2339     scaled_weights->set_op(source->op());
2340     scaled_weights->set_device(weights->device());
2341     (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
2342     AddToOptimizationQueue(scaled_weights);
2343 
2344     // Link in its inputs.
2345     scaled_weights->add_input(conv->input(1));
2346     ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
2347     scaled_weights->add_input(mul->input(scale_idx));
2348     ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
2349     ForwardControlDependencies(scaled_weights, {source});
2350 
2351     // Update `conv`'s weights to `scaled_weights`.
2352     conv->set_input(1, scaled_weights->name());
2353     ctx().node_map->UpdateInput(conv->name(), weights->name(),
2354                                 scaled_weights->name());
2355     AddToOptimizationQueue(conv);
2356 
2357     // Update `tail` node to bypass `mul` because it's folded to the weights.
2358     tail->set_input(0, mul->input(input_idx));
2359     ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
2360     AddToOptimizationQueue(tail);
2361     *simplified_node_name = conv->name();
2362 
2363     return OkStatus();
2364 #undef TF_RETURN_IF_TRUE
2365   }
2366 };
2367 
2368 // Fold Transpose into matrix multiplication.
2369 class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
2370  public:
FoldTransposeIntoMatMul(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2371   explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
2372                                    const ArithmeticOptimizerContext& ctx_ext)
2373       : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
2374   ~FoldTransposeIntoMatMul() override = default;
2375 
IsSupported(const NodeDef * node) const2376   bool IsSupported(const NodeDef* node) const override {
2377     return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
2378   }
2379 
TrySimplify(NodeDef * node,string * simplified_node_name)2380   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2381     const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2382     const string optimized_node_name = OptimizedNodeName(matmul);
2383     if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus();
2384 
2385     NodeDef* a;
2386     NodeDef* b;
2387     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
2388     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));
2389 
2390     bool is_complex = false;
2391     if (node->op() != "SparseMatMul") {
2392       const DataType type = GetDataTypeFromAttr(*node, "T");
2393       is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2394     }
2395 
2396     const std::set<string> foldable_transpose_ops =
2397         !is_complex
2398             ? std::set<string>{"ConjugateTranspose", "Transpose"}
2399             : (IsAnyBatchMatMul(*node) ? std::set<string>{"ConjugateTranspose"}
2400                                        : std::set<string>{"Transpose"});
2401 
2402     const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
2403                                IsInnerMatrixTransposeNode(*a, ctx().node_map);
2404     const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
2405                                IsInnerMatrixTransposeNode(*b, ctx().node_map);
2406     if (!a_is_foldable && !b_is_foldable) return OkStatus();
2407 
2408     NodeDef* new_op = AddCopyNode(optimized_node_name, node);
2409 
2410     if (a_is_foldable) {
2411       const string attr_a = IsAnyBatchMatMul(*node) ? "adj_x" : "transpose_a";
2412       FlipBooleanAttr(attr_a, new_op);
2413       new_op->set_input(0, a->input(0));
2414       ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
2415     } else {
2416       ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
2417     }
2418 
2419     if (b_is_foldable) {
2420       const string attr_b = IsAnyBatchMatMul(*node) ? "adj_y" : "transpose_b";
2421       FlipBooleanAttr(attr_b, new_op);
2422       new_op->set_input(1, b->input(0));
2423       ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
2424     } else {
2425       ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
2426     }
2427 
2428     std::vector<const NodeDef*> deps_to_forward = {node};
2429     if (a_is_foldable) deps_to_forward.push_back(a);
2430     if (b_is_foldable) deps_to_forward.push_back(b);
2431     ForwardControlDependencies(new_op, deps_to_forward);
2432     *simplified_node_name = new_op->name();
2433 
2434     return OkStatus();
2435   }
2436 
2437  private:
FlipBooleanAttr(const string & attr_name,NodeDef * node)2438   void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
2439     const bool old_value =
2440         !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
2441     (*node->mutable_attr())[attr_name].set_b(!old_value);
2442   }
2443 
2444   template <typename T>
IsInnerMatrixTranspose(const std::vector<T> & perm)2445   bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
2446     const T n = perm.size();
2447     if (n < 2) {
2448       return false;
2449     }
2450     for (T i = 0; i < n - 2; ++i) {
2451       if (perm[i] != i) {
2452         return false;
2453       }
2454     }
2455     return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
2456   }
2457 
IsInnerMatrixTransposeNode(const NodeDef & transpose_node,const NodeMap * node_map)2458   bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
2459                                   const NodeMap* node_map) {
2460     if (transpose_node.op() != "Transpose" &&
2461         transpose_node.op() != "ConjugateTranspose") {
2462       return false;
2463     }
2464     const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
2465     std::vector<int> perm32;
2466     if (ValuesFromConstNode(*perm_node, &perm32)) {
2467       return IsInnerMatrixTranspose(perm32);
2468     }
2469     std::vector<int64_t> perm64;
2470     if (ValuesFromConstNode(*perm_node, &perm64)) {
2471       return IsInnerMatrixTranspose(perm64);
2472     }
2473     return false;
2474   }
2475 };
2476 
2477 class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
2478  public:
FoldConjugateIntoTranspose(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2479   explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
2480                                       const ArithmeticOptimizerContext& ctx_ext)
2481       : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
2482   ~FoldConjugateIntoTranspose() override = default;
2483 
IsSupported(const NodeDef * node) const2484   bool IsSupported(const NodeDef* node) const override {
2485     return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
2486   }
2487 
TrySimplify(NodeDef * node,string * simplified_node_name)2488   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2489     const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2490     const string optimized_node_name = OptimizedNodeName(matmul);
2491     if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus();
2492 
2493     NodeDef* input;
2494     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2495 
2496     const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
2497     const NodeDef* conj_op = node->op() == "Conj" ? node : input;
2498 
2499     if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
2500         IsConj(*conj_op)) {
2501       NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);
2502 
2503       // Flip the type of transpose op to absorb the conjugation.
2504       new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
2505                                                        : "Transpose");
2506       new_op->set_input(0, input->input(0));
2507       ctx().node_map->UpdateInput(new_op->name(), node->name(),
2508                                   input->input(0));
2509       ForwardControlDependencies(new_op, {node, input});
2510       *simplified_node_name = new_op->name();
2511     }
2512 
2513     return OkStatus();
2514   }
2515 };
2516 
2517 // Replace Mul node with identical inputs with a Square.
2518 class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
2519  public:
ReplaceMulWithSquare(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)2520   explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
2521                                 const ArithmeticOptimizerContext& ctx_ext)
2522       : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
2523   ~ReplaceMulWithSquare() override = default;
2524 
IsSupported(const NodeDef * node) const2525   bool IsSupported(const NodeDef* node) const override {
2526     if (!node || node->input_size() < 2) {
2527       // Invalid node
2528       return false;
2529     }
2530 
2531     return IsAnyMul(*node) && node->input(0) == node->input(1);
2532   }
2533 
TrySimplify(NodeDef * node,string * simplified_node_name)2534   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2535     const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
2536     const string optimized_node_name = OptimizedNodeName(mul);
2537     if (ctx().node_map->NodeExists(optimized_node_name)) return OkStatus();
2538 
2539     const DataType type = GetDataTypeFromAttr(*node, "T");
2540     bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2541 
2542     if (!is_complex || NodeIsOnCpu(*node)) {
2543       NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
2544       new_square_node->set_op("Square");
2545       for (int i = 1; i < new_square_node->input_size(); ++i) {
2546         new_square_node->set_input(i - 1, new_square_node->input(i));
2547       }
2548       new_square_node->mutable_input()->RemoveLast();
2549       for (const string& input : new_square_node->input()) {
2550         ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
2551       }
2552       *simplified_node_name = new_square_node->name();
2553     }
2554 
2555     return OkStatus();
2556   }
2557 };
2558 
2559 // Replace a combination of Mul with broadcasting by Tile. E.g. replace
2560 //
2561 // input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output
2562 // Ones (1x22x2x48x2x64) -^
2563 //
2564 // with
2565 //
2566 // input -> Tile(1x22x2x48x2x64) -> output
class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage {
 public:
  explicit ReplaceMulWithBroadcastByTile(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx,
                                 ctx_ext) {}
  ~ReplaceMulWithBroadcastByTile() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsMul(*node) && !IsInPreserveSet(*node);
  }

  // Rewrites Mul(input, ones) that only broadcasts `input` into an equivalent
  // Tile(input, multiples), where the multiples are computed from the inferred
  // input/output shapes.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef *input, *ones;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
    if (IsInPreserveSet(*node) || IsInPreserveSet(*input) ||
        IsInPreserveSet(*ones)) {
      return OkStatus();
    }

    // Only the pattern Mul(non-const, all-ones-const) is handled; the
    // reversed argument order is not matched here.
    // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc
    if (IsConstant(*input) || !IsOnes(*ones)) return OkStatus();

    // Avoid optimizing the same node twice
    const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
    const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile");
    const string const_node_name = OptimizedNodeName(scope_and_name, "Const");
    if (ctx().node_map->NodeExists(tile_node_name) ||
        ctx().node_map->NodeExists(const_node_name)) {
      return OkStatus();
    }

    // Both input shapes must be available from shape inference.
    const std::vector<OpInfo::TensorProperties>& props =
        ctx().graph_properties->GetInputProperties(node->name());
    if (props.size() != 2) return OkStatus();

    // Ignore ops where the shape doesn't change
    const TensorShapeProto& input_shape = props[0].shape();
    const TensorShapeProto& ones_shape = props[1].shape();
    TensorShapeProto output_shape;
    if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) {
      return OkStatus();
    }
    if (ShapesSymbolicallyEqual(input_shape, output_shape)) {
      return OkStatus();
    }

    // All inputs must have same input/output dimensions
    if (input_shape.dim_size() != output_shape.dim_size() ||
        ones_shape.dim_size() != output_shape.dim_size())
      return OkStatus();

    // At this point all preconditions are met. Can proceed with rewrite.
    VLOG(3) << "Simplify multiply with all ones input: node=" << node->name()
            << "@" << output_shape << " ones=" << ones->name() << "@"
            << ones_shape << " input=" << input->name() << "@" << input_shape;

    // 1. Create constant node with correct tile multiples
    // NOTE(review): the per-dimension multiple is output/input; this assumes
    // every input dim size is nonzero and divides the output dim size —
    // presumably guaranteed by ShapeAfterBroadcast for static shapes, but
    // verify for 0-sized or unknown (-1) dimensions.
    Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()}));
    for (int i = 0; i < output_shape.dim_size(); ++i) {
      int64_t size = output_shape.dim(i).size() / input_shape.dim(i).size();
      if (TF_PREDICT_FALSE(size >= INT_MAX)) {
        return Status(error::OUT_OF_RANGE, "int32 overflow");
      }
      multiples.flat<int32>()(i) = static_cast<int32>(size);
    }

    NodeDef* const_node = AddEmptyNode(const_node_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        const_node->name(), TensorValue(&multiples), const_node));
    const_node->set_device(node->device());
    ForwardControlDependencies(const_node, {ones});
    AddToOptimizationQueue(const_node);

    // 2. Replace multiply node with Tile(Const, input);
    const DataType type = GetDataTypeFromAttr(*node, "T");
    NodeDef* tile_node = AddEmptyNode(tile_node_name);
    tile_node->set_op("Tile");
    tile_node->set_device(node->device());
    SetDataTypeToAttr(type, "T", tile_node);
    SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node);
    tile_node->add_input(input->name());
    tile_node->add_input(const_node->name());

    ForwardControlDependencies(tile_node, {node});
    *simplified_node_name = tile_node->name();

    return OkStatus();
  }

 protected:
  // Returns true iff `node` is a real constant whose dtype is DT_FLOAT and
  // whose every element equals 1.0f. Other dtypes are conservatively
  // rejected.
  bool IsOnes(const NodeDef& node) const {
    if (!IsReallyConstant(node)) return false;
    if (node.attr().at("dtype").type() != DT_FLOAT) return false;

    Tensor tensor;
    if (!tensor.FromProto(node.attr().at("value").tensor())) {
      return false;
    }

    auto values = tensor.flat<float>();
    for (int i = 0; i < tensor.NumElements(); ++i) {
      if (values(i) != 1.0f) {
        return false;
      }
    }

    return true;
  }
};
2679 
2680 // Image upsampling often produces an unnecessary reshape that is difficult to
2681 // eliminate in other stages. This stage reduces the number of dimensions
2682 // involved allowing the reshape to be removed.
2683 //
2684 // For example, given
2685 //   B,W,H,C -> Reshape(B,W,1,H,1,C) -> Tile(1,1,2,1,2,1) -> Reshape(B,2W,2H,C)
2686 // this pass converts the sequence to
2687 //   B,W,H,C -> Reshape(B,W,H,C) -> Tile(1,1,2,2) -> Reshape(B,2W,2H,C)
2688 //
2689 // The first reshape is now redundant and can be removed in a later pass.
2690 //
2691 // Note: This only optimizes the simple (but extremely common) case of 2D
2692 // upsampling.
2693 //
2694 // TODO(kkiningh): Generalize to more complex upsampling patterns.
2695 class ReduceUpsamplingDims : public ArithmeticOptimizerStage {
2696  public:
  // Constructs the stage. The stage name is used for logging and when
  // generating unique names for optimized nodes.
  explicit ReduceUpsamplingDims(const GraphOptimizerContext& ctx,
                                const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReduceUpsamplingDims", ctx, ctx_ext) {}
  ~ReduceUpsamplingDims() override = default;
2701 
IsSupported(const NodeDef * node) const2702   bool IsSupported(const NodeDef* node) const override {
2703     return IsReshape(*node) && !IsInPreserveSet(*node);
2704   }
2705 
TrySimplify(NodeDef * node,string * simplified_node_name)2706   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2707     NodeDef* tile;
2708     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &tile));
2709     if (!IsTile(*tile) || IsInPreserveSet(*tile)) {
2710       return OkStatus();
2711     }
2712 
2713     if (NumNonControlOutputs(*tile, *ctx().node_map) != 1) {
2714       // Optimization is only worthwile when there is a single output from Tile.
2715       // Otherwise, we need to insert additional Reshape ops that can't be
2716       // easily removed.
2717       return OkStatus();
2718     }
2719 
2720     NodeDef* reshape;
2721     TF_RETURN_IF_ERROR(GetInputNode(tile->input(0), &reshape));
2722     if (!IsReshape(*reshape) || IsInPreserveSet(*reshape)) {
2723       return OkStatus();
2724     }
2725 
2726     NodeDef* multiples;
2727     TF_RETURN_IF_ERROR(GetInputNode(tile->input(1), &multiples));
2728 
2729     NodeDef* shape;
2730     TF_RETURN_IF_ERROR(GetInputNode(reshape->input(1), &shape));
2731 
2732     // Avoid optimizing the same nodes twice
2733     const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
2734     const string new_reshape_name =
2735         OptimizedNodeName(scope_and_name, "Reshape");
2736     const string new_tile_name = OptimizedNodeName(scope_and_name, "Tile");
2737     const string new_multiples_name =
2738         OptimizedNodeName(scope_and_name, "Multiples");
2739     const string new_shape_name = OptimizedNodeName(scope_and_name, "Shape");
2740     if (ctx().node_map->NodeExists(new_reshape_name) ||
2741         ctx().node_map->NodeExists(new_tile_name) ||
2742         ctx().node_map->NodeExists(new_shape_name) ||
2743         ctx().node_map->NodeExists(new_multiples_name)) {
2744       return OkStatus();
2745     }
2746 
2747     // Compuate updated multiples/shape values.
2748     AttrValue new_multiples_attr;
2749     if (!CreateUpdatedMultiplesProto(multiples,
2750                                      new_multiples_attr.mutable_tensor())) {
2751       return OkStatus();
2752     }
2753     AttrValue new_shape_attr;
2754     if (!CreateUpdatedShapeProto(shape, new_shape_attr.mutable_tensor())) {
2755       return OkStatus();
2756     }
2757 
2758     // At this point the graph is validated and can be updated
2759     // Note: We can assume shape/multiples are DT_INT32 only at this point since
2760     // they're checked in CreateUpdated*Proto()
2761 
2762     // 1. Create the constant nodes used by the new Reshape/Tile nodes
2763     NodeDef* new_multiples = AddEmptyNode(new_multiples_name);
2764     new_multiples->set_op("Const");
2765     SetDataTypeToAttr(DT_INT32, "dtype", new_multiples);
2766     new_multiples->mutable_attr()->insert({"value", new_multiples_attr});
2767     new_multiples->set_device(multiples->device());
2768 
2769     NodeDef* new_shape = AddEmptyNode(new_shape_name);
2770     new_shape->set_op("Const");
2771     SetDataTypeToAttr(DT_INT32, "dtype", new_shape);
2772     new_shape->mutable_attr()->insert({"value", new_shape_attr});
2773     new_shape->set_device(shape->device());
2774 
2775     // 2. Create the new Reshape/Tile nodes
2776     NodeDef* new_reshape = AddEmptyNode(new_reshape_name);
2777     CopyReshapeWithInput(reshape, new_reshape, /*input=*/reshape->input(0),
2778                          /*shape=*/new_shape->name());
2779     NodeDef* new_tile = AddEmptyNode(new_tile_name);
2780     CopyTileWithInput(tile, new_tile, /*input=*/new_reshape->name(),
2781                       /*multiples=*/new_multiples->name());
2782 
2783     // 3. Update consumer of original Tile node and add control
2784     node->set_input(0, new_tile->name());
2785     ctx().node_map->UpdateInput(node->name(), tile->name(), new_tile->name());
2786 
2787     ForwardControlDependencies(new_tile, {tile});
2788     ForwardControlDependencies(new_multiples, {multiples});
2789     ForwardControlDependencies(new_reshape, {reshape});
2790     ForwardControlDependencies(new_shape, {shape});
2791 
2792     *simplified_node_name = node->name();
2793     return OkStatus();
2794   }
2795 
2796  private:
CreateUpdatedMultiplesProto(const NodeDef * node,TensorProto * proto)2797   bool CreateUpdatedMultiplesProto(const NodeDef* node, TensorProto* proto) {
2798     Tensor multiples;
2799     if (!GetTensorFromConstNode(node->name(), &multiples)) {
2800       return false;
2801     }
2802 
2803     // Dimensions should be [X, Y, N, 1, M, 1]
2804     if (multiples.dtype() != DT_INT32 || multiples.NumElements() != 6) {
2805       return false;
2806     }
2807 
2808     const auto& multiples_values = multiples.flat<int32>();
2809     if (multiples_values(3) != 1 || multiples_values(5) != 1) {
2810       return false;
2811     }
2812 
2813     // Convert to [X, Y, N, M]
2814     Tensor new_multiples(DT_INT32, {4});
2815     new_multiples.flat<int32>()(0) = multiples_values(0);
2816     new_multiples.flat<int32>()(1) = multiples_values(1);
2817     new_multiples.flat<int32>()(2) = multiples_values(2);
2818     new_multiples.flat<int32>()(3) = multiples_values(4);
2819 
2820     new_multiples.AsProtoTensorContent(proto);
2821     return true;
2822   }
2823 
CreateUpdatedShapeProto(const NodeDef * node,TensorProto * proto)2824   bool CreateUpdatedShapeProto(const NodeDef* node, TensorProto* proto) {
2825     Tensor shape;
2826     if (!GetTensorFromConstNode(node->name(), &shape)) {
2827       return false;
2828     }
2829 
2830     // Dimensions should be [B, W, 1, H, 1, C]
2831     if (shape.dtype() != DT_INT32 || shape.NumElements() != 6) {
2832       return false;
2833     }
2834 
2835     const auto& shape_values = shape.flat<int32>();
2836     if (shape_values(2) != 1 || shape_values(4) != 1) {
2837       return false;
2838     }
2839 
2840     // Convert to [B, W, H, C]
2841     Tensor new_shape(DT_INT32, {4});
2842     new_shape.flat<int32>()(0) = shape_values(0);
2843     new_shape.flat<int32>()(1) = shape_values(1);
2844     new_shape.flat<int32>()(2) = shape_values(3);
2845     new_shape.flat<int32>()(3) = shape_values(5);
2846 
2847     new_shape.AsProtoTensorContent(proto);
2848     return true;
2849   }
2850 
CopyReshapeWithInput(const NodeDef * reshape,NodeDef * new_reshape,const string & input,const string & shape)2851   void CopyReshapeWithInput(const NodeDef* reshape, NodeDef* new_reshape,
2852                             const string& input, const string& shape) {
2853     new_reshape->set_op("Reshape");
2854     new_reshape->set_device(reshape->device());
2855     SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "T"), "T", new_reshape);
2856     SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "Tshape"), "Tshape",
2857                       new_reshape);
2858 
2859     new_reshape->add_input(input);
2860     ctx().node_map->AddOutput(NodeName(input), new_reshape->name());
2861     new_reshape->add_input(shape);
2862     ctx().node_map->AddOutput(NodeName(shape), new_reshape->name());
2863 
2864     AddToOptimizationQueue(new_reshape);
2865   }
2866 
CopyTileWithInput(const NodeDef * tile,NodeDef * new_tile,const string & input,const string & multiples)2867   void CopyTileWithInput(const NodeDef* tile, NodeDef* new_tile,
2868                          const string& input, const string& multiples) {
2869     new_tile->set_op("Tile");
2870     new_tile->set_device(tile->device());
2871     SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "T"), "T", new_tile);
2872     SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "Tmultiples"), "Tmultiples",
2873                       new_tile);
2874 
2875     new_tile->add_input(input);
2876     ctx().node_map->AddOutput(NodeName(input), new_tile->name());
2877     new_tile->add_input(multiples);
2878     ctx().node_map->AddOutput(NodeName(multiples), new_tile->name());
2879 
2880     AddToOptimizationQueue(new_tile);
2881   }
2882 };
2883 
2884 // Replace a sequence of Pack nodes with identical inputs with Tile
2885 // For example, given a Tensor X with shape (I,J,K)
2886 // Let P(x, n) = Pack([x, x], axis=n)
2887 //
2888 // P(P(X, 2), 1)
2889 //   = Tile(Reshape(Tile(Reshape(x,
2890 //              [I,    J, 1, K]), [1,    1, 2, 1]),
//              [I, 1, J, 2, K]), [1, 2, 1, 1, 1])
2892 //   = Tile(Reshape(x,
2893 //              [I, 1, J, 1, K]), [1, 2, 1, 2, 1])
2894 //   = Reshape(Tile(x, [1, 2, 2]), [I, 2, J, 2, K])
2895 //
2896 // The outermost reshape is often redundant and can be removed in another pass
class ReplacePackWithTileReshape : public ArithmeticOptimizerStage {
 public:
  explicit ReplacePackWithTileReshape(const GraphOptimizerContext& ctx,
                                      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ReplacePackWithTileReshape", ctx, ctx_ext) {}
  ~ReplacePackWithTileReshape() override = default;

  // Candidates are Pack nodes with more than one regular input that are not
  // in the preserve set.
  bool IsSupported(const NodeDef* node) const override {
    return IsPack(*node) && NumNonControlInputs(*node) > 1 &&
           !IsInPreserveSet(*node);
  }

  // Collapses a chain of Pack ops with identical inputs into
  // Reshape(Tile(input, multiples), output_shape). Requires fully defined
  // input and output shapes from graph_properties. On success
  // *simplified_node_name is the new Reshape node.
  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // 1. traverse the chain of Pack ops to get the original input
    NodeDef* input = node;
    std::vector<const NodeDef*> chain;
    // NOTE(review): the middle predicate tests *node (the chain root) on
    // every iteration while the other two test *input — looks like it may
    // have been intended as NumNonControlInputs(*input); confirm upstream.
    while (IsPack(*input) && NumNonControlInputs(*node) > 1 &&
           !IsInPreserveSet(*input)) {
      // Only pack operations with all identical inputs are supported
      if (!AllRegularInputsEqual(*input)) {
        break;
      }
      chain.push_back(input);
      TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &input));
    }

    // Bail out if no eligible Pack op was found. NOTE(review): an earlier
    // comment said "at least two Pack operations", but this check also
    // accepts a chain of length one — confirm the intended threshold.
    if (chain.empty()) {
      return OkStatus();
    }

    // Avoid optimizing the same node twice
    const NodeScopeAndName node_scope_and_name =
        ParseNodeScopeAndName(node->name());
    const string new_const_name =
        OptimizedNodeName(node_scope_and_name, "Multiples");
    const string new_tile_name = OptimizedNodeName(node_scope_and_name, "Tile");
    const string new_shape_name =
        OptimizedNodeName(node_scope_and_name, "Shape");
    const string new_reshape_name =
        OptimizedNodeName(node_scope_and_name, "Reshape");
    if (ctx().node_map->NodeExists(new_const_name) ||
        ctx().node_map->NodeExists(new_tile_name) ||
        ctx().node_map->NodeExists(new_shape_name) ||
        ctx().node_map->NodeExists(new_reshape_name)) {
      return OkStatus();
    }

    // 2. Calculate the multiples and shape tensor using the chain
    const OpInfo::TensorProperties* input_props;
    TF_RETURN_IF_ERROR(GetTensorProperties(input->name(), &input_props));
    const TensorShapeProto& input_shape = input_props->shape();
    if (!PartialTensorShape(input_shape).IsFullyDefined()) {
      return OkStatus();
    }
    Tensor multiples(DT_INT32, TensorShape({input_shape.dim_size()}));
    TF_RETURN_IF_ERROR(CalculateMultiplesFromChain(chain, &multiples));

    const OpInfo::TensorProperties* output_props;
    TF_RETURN_IF_ERROR(GetTensorProperties(node->name(), &output_props));
    const TensorShapeProto& output_shape = output_props->shape();
    if (!PartialTensorShape(output_shape).IsFullyDefined()) {
      return OkStatus();
    }
    Tensor output_shape_tensor(DT_INT32,
                               TensorShape({output_shape.dim_size()}));
    for (int i = 0; i < output_shape.dim_size(); ++i) {
      output_shape_tensor.flat<int32>()(i) = output_shape.dim(i).size();
    }

    // 3. Create constant node with correct multiples value
    NodeDef* new_const_node = AddEmptyNode(new_const_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        new_const_node->name(), TensorValue(&multiples), new_const_node));
    new_const_node->set_device(node->device());
    MaybeAddControlInput(input->name(), new_const_node, ctx().optimized_graph,
                         ctx().node_map);
    AddToOptimizationQueue(new_const_node);

    // 4. Replace the Pack node with Tile(Const(N), input);
    DataType dtype = GetDataTypeFromAttr(*node, "T");
    NodeDef* new_tile_node = AddEmptyNode(new_tile_name);
    new_tile_node->set_op("Tile");
    new_tile_node->set_device(node->device());
    SetDataTypeToAttr(dtype, "T", new_tile_node);
    SetDataTypeToAttr(DT_INT32, "Tmultiples", new_tile_node);
    new_tile_node->add_input(input->name());
    ctx().node_map->AddOutput(input->name(), new_tile_node->name());
    new_tile_node->add_input(new_const_node->name());
    ctx().node_map->AddOutput(new_const_node->name(), new_tile_node->name());

    // Tile inherits all control dependencies from the original pack chain
    ForwardControlDependencies(new_tile_node, chain);
    AddToOptimizationQueue(new_tile_node);

    // 5. Add a new Reshape node to preserve the existing shape
    NodeDef* new_shape_node = AddEmptyNode(new_shape_name);
    TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
        new_shape_node->name(), TensorValue(&output_shape_tensor),
        new_shape_node));
    new_shape_node->set_device(node->device());
    MaybeAddControlInput(input->name(), new_shape_node, ctx().optimized_graph,
                         ctx().node_map);
    AddToOptimizationQueue(new_shape_node);

    NodeDef* new_reshape_node = AddEmptyNode(new_reshape_name);
    new_reshape_node->set_op("Reshape");
    new_reshape_node->set_device(node->device());
    SetDataTypeToAttr(dtype, "T", new_reshape_node);
    SetDataTypeToAttr(DT_INT32, "Tshape", new_reshape_node);
    new_reshape_node->add_input(new_tile_node->name());
    ctx().node_map->AddOutput(new_tile_node->name(), new_reshape_node->name());
    new_reshape_node->add_input(new_shape_node->name());
    ctx().node_map->AddOutput(new_shape_node->name(), new_reshape_node->name());

    *simplified_node_name = new_reshape_node->name();

    return OkStatus();
  }

 protected:
  // Computes the Tile `multiples` equivalent to applying the Pack chain
  // (innermost-first, i.e. reverse of `chain`) to the original input.
  // `multiples` must be pre-sized to the input's rank; entries start at 1.
  // Fails with OUT_OF_RANGE when an axis falls outside the tracked dims
  // (which also covers negative axis values after unsigned comparison) or
  // when a multiple would overflow int32.
  Status CalculateMultiplesFromChain(const std::vector<const NodeDef*>& chain,
                                     Tensor* multiples) {
    // Keep track of how the multiples correspond to each shape dimension.
    // For example, given Stack([x, x], axis=1) with rank(x) = 3, we start with
    //    multiples=[1, 1, 1] , dims=[0, 1, 2]
    // After processing the stack op
    //    multiples=[1, 2, 1] , dims=[0, 1, 1, 2]
    std::vector<int32> dims(multiples->NumElements());
    std::iota(dims.begin(), dims.end(), 0);

    for (int i = 0; i < multiples->NumElements(); ++i) {
      multiples->flat<int32>()(i) = 1;
    }

    // Walk the chain from the innermost Pack outwards.
    for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
      AttrSlice attrs(**it);
      int64_t axis, n;
      TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &axis));
      TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &n));

      if (axis >= dims.size()) {
        // We don't handle the case where Pack is performed on the last axis,
        // e.g. Pack([x, x], axis=3) where rank(x) == 3
        return Status(error::OUT_OF_RANGE, "axis value out of range of dims");
      }

      // Accumulate in int64 so the overflow check below is meaningful.
      int64_t m = multiples->flat<int32>()(dims[axis]) * n;
      if (TF_PREDICT_FALSE(m > INT_MAX)) {
        return Status(error::OUT_OF_RANGE, "int32 overflow");
      }
      multiples->flat<int32>()(dims[axis]) = static_cast<int32>(m);

      // Copy index from immediate right of inserted axis
      dims.insert(dims.begin() + axis, dims[axis]);
    }

    return OkStatus();
  }
};
3057 
3058 // Simplify aggregation (e.g. AddN) nodes:
3059 //
3060 // 1. Discard aggregate nodes with a single input and no control dependencies.
3061 //
3062 // 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
3063 //    deduping or other rewrites) so we can get rid of the sum entirely.
3064 //
3065 //    The expression (using AddN as an example of an aggregate op):
3066 //      AddN(x, x, x, ... ,x)
3067 //           <-- N terms -->
3068 //    can be rewritten to:
//      Mul(Const(N), x)
3070 //
3071 class SimplifyAggregation : public ArithmeticOptimizerStage {
3072  public:
SimplifyAggregation(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)3073   explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
3074                                const ArithmeticOptimizerContext& ctx_ext)
3075       : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
3076   ~SimplifyAggregation() override = default;
3077 
IsSupported(const NodeDef * node) const3078   bool IsSupported(const NodeDef* node) const override {
3079     return IsAggregate(*node) && HasRegularInputs(*node) &&
3080            GetDataTypeFromAttr(*node, "T") !=
3081                DT_VARIANT;  // TODO(b/119787146): Enable for variants.
3082   }
3083 
TrySimplify(NodeDef * node,string * simplified_node_name)3084   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3085     // 1. Discard aggregate nodes with a single input and no control deps.
3086     if (node->input_size() == 1) {
3087       *simplified_node_name = node->input(0);
3088       return OkStatus();
3089     }
3090 
3091     // 2. Rewrite aggregations of N >= 2 identical terms.
3092 
3093     // All non-control inputs must be identical.
3094     bool all_equal = true;
3095     int num_inputs = 1;
3096     for (int i = 1; i < node->input_size(); ++i) {
3097       if (IsControlInput(node->input(i))) break;
3098       ++num_inputs;
3099       if (node->input(i) != node->input(0)) {
3100         all_equal = false;
3101         break;
3102       }
3103     }
3104     if (!all_equal) return OkStatus();
3105 
3106     // And node should not be optimized earlier.
3107     const NodeScopeAndName node_scope_and_name =
3108         ParseNodeScopeAndName(node->name());
3109     const string optimized_const_name =
3110         OptimizedNodeName(node_scope_and_name, "Const");
3111     const string optimized_mul_name =
3112         OptimizedNodeName(node_scope_and_name, "Mul");
3113 
3114     bool is_already_optimized =
3115         ctx().node_map->NodeExists(optimized_const_name) ||
3116         ctx().node_map->NodeExists(optimized_mul_name);
3117 
3118     if (is_already_optimized) return OkStatus();
3119 
3120     // At this point all preconditions are met, and we safely do the rewrite.
3121     VLOG(3) << "Simplify aggregation with identical inputs: node="
3122             << node->name() << " num_inputs=" << num_inputs;
3123 
3124     // 1. Create constant node with value N.
3125     const auto type = GetDataTypeFromAttr(*node, "T");
3126     Tensor t(type, TensorShape({}));
3127     Status status = SetTensorValue(type, num_inputs, &t);
3128     if (!status.ok()) {
3129       return errors::Internal("Failed to create const node: ",
3130                               status.error_message());
3131     }
3132 
3133     TensorValue value(&t);
3134     NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
3135     status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
3136                                             new_const_node);
3137     if (!status.ok()) {
3138       return errors::Internal("Failed to create const node: ",
3139                               status.error_message());
3140     }
3141     new_const_node->set_device(node->device());
3142     MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
3143                          ctx().optimized_graph, ctx().node_map);
3144     AddToOptimizationQueue(new_const_node);
3145 
3146     // 2. Replace the aggregate node with Mul(Const(N), x).
3147     NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
3148     new_mul_node->set_op("Mul");
3149     new_mul_node->set_device(node->device());
3150     SetDataTypeToAttr(type, "T", new_mul_node);
3151     new_mul_node->add_input(new_const_node->name());
3152     ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
3153     new_mul_node->add_input(node->input(0));
3154     ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
3155 
3156     ForwardControlDependencies(new_mul_node, {node});
3157     *simplified_node_name = new_mul_node->name();
3158 
3159     return OkStatus();
3160   }
3161 };
3162 
3163 class ConvertPowStage : public ArithmeticOptimizerStage {
3164  public:
ConvertPowStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)3165   explicit ConvertPowStage(const GraphOptimizerContext& ctx,
3166                            const ArithmeticOptimizerContext& ctx_ext)
3167       : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
3168 
IsSupported(const NodeDef * node) const3169   bool IsSupported(const NodeDef* node) const override {
3170     return IsPow(*node) &&
3171            ctx().graph_properties->HasOutputProperties(node->name()) &&
3172            ctx().graph_properties->HasInputProperties(node->name());
3173   }
3174 
TrySimplify(NodeDef * node,string * simplified_node_name)3175   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3176     Tensor pow;
3177     if (!GetTensorFromConstNode(node->input(1), &pow)) return OkStatus();
3178     complex128 prev, curr;
3179     for (int i = 0; i < pow.NumElements(); ++i) {
3180       if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
3181         // input data type is not supported by Pow. Skip.
3182         return OkStatus();
3183       }
3184       if (i != 0 && curr != prev) {
3185         // pow has different values on different elements. Skip.
3186         return OkStatus();
3187       }
3188       prev = curr;
3189     }
3190     NodeDef *x, *y;
3191     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
3192     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
3193     const auto& value_props =
3194         ctx().graph_properties->GetInputProperties(node->name())[0];
3195     const TensorShapeProto& output_shape =
3196         ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
3197     if (curr == complex128(2, 0)) {
3198       node->set_op("Square");
3199       node->set_input(1, AsControlDependency(y->name()));
3200       AddToOptimizationQueue(node);
3201       AddToOptimizationQueue(y);
3202     } else if (curr == complex128(3, 0)) {
3203       // TODO(courbet): Use 'Cube' when it's added to TF ops.
3204       if (NodeIsOnCpu(*node)) {
3205         // We create an inner square node: inner_square = square(x)
3206         const NodeScopeAndName scope_and_name =
3207             ParseNodeScopeAndName(node->name());
3208         const string inner_square_name =
3209             OptimizedNodeName(scope_and_name, "_inner");
3210         NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
3211         if (inner_square_node == nullptr) {
3212           inner_square_node = AddCopyNode(inner_square_name, node);
3213           inner_square_node->set_op("Square");
3214           inner_square_node->mutable_input()->RemoveLast();
3215         }
3216         ctx().node_map->AddOutput(x->name(), inner_square_node->name());
3217         // We modify `node`: node = mul(x, inner_square);
3218         node->set_op("Mul");
3219         node->set_input(1, inner_square_node->name());
3220         node->add_input(AsControlDependency(y->name()));
3221 
3222         AddToOptimizationQueue(node);
3223         AddToOptimizationQueue(inner_square_node);
3224         AddToOptimizationQueue(y);
3225       }
3226     } else if (curr == complex128(1, 0) &&
3227                ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
3228       // Pow could be used to broadcast, so make sure the shapes of the two
3229       // arguments are identical before replacing Pow with Identity.
3230       node->set_op("Identity");
3231       node->set_input(1, AsControlDependency(y->name()));
3232       AddToOptimizationQueue(node);
3233       AddToOptimizationQueue(y);
3234     } else if (curr == complex128(0.5, 0)) {
3235       node->set_op("Sqrt");
3236       node->set_input(1, AsControlDependency(y->name()));
3237       AddToOptimizationQueue(node);
3238       AddToOptimizationQueue(y);
3239     } else if (curr == complex128(0, 0) &&
3240                ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
3241                PartialTensorShape(output_shape).IsFullyDefined()) {
3242       const auto dtype = node->attr().at("T").type();
3243       Tensor ones(dtype, output_shape);
3244       for (int i = 0; i < ones.NumElements(); ++i) {
3245         TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
3246       }
3247       node->set_op("Const");
3248       (*node->mutable_attr())["dtype"].set_type(dtype);
3249       node->mutable_attr()->erase("T");
3250       ones.AsProtoTensorContent(
3251           (*node->mutable_attr())["value"].mutable_tensor());
3252       node->set_input(0, AsControlDependency(x->name()));
3253       node->set_input(1, AsControlDependency(y->name()));
3254       AddToOptimizationQueue(node);
3255       AddToOptimizationQueue(x);
3256       AddToOptimizationQueue(y);
3257     } else if (curr == complex128(-0.5, 0)) {
3258       node->set_op("Rsqrt");
3259       node->set_input(1, AsControlDependency(y->name()));
3260       AddToOptimizationQueue(node);
3261       AddToOptimizationQueue(y);
3262     } else if (curr == complex128(-1, 0)) {
3263       node->set_op("Reciprocal");
3264       node->set_input(1, AsControlDependency(y->name()));
3265       AddToOptimizationQueue(node);
3266       AddToOptimizationQueue(y);
3267     }
3268     return OkStatus();
3269   }
3270 
3271  private:
SetElementToOne(int i,Tensor * t)3272   Status SetElementToOne(int i, Tensor* t) {
3273     switch (t->dtype()) {
3274       case DT_INT32:
3275         t->flat<int32>()(i) = 1;
3276         return OkStatus();
3277       case DT_INT64:
3278         t->flat<int64_t>()(i) = 1L;
3279         return OkStatus();
3280       case DT_FLOAT:
3281         t->flat<float>()(i) = 1.0f;
3282         return OkStatus();
3283       case DT_DOUBLE:
3284         t->flat<double>()(i) = 1.0;
3285         return OkStatus();
3286       case DT_COMPLEX64:
3287         t->flat<complex64>()(i) = complex64(1);
3288         return OkStatus();
3289       case DT_COMPLEX128:
3290         t->flat<complex128>()(i) = complex128(1);
3291         return OkStatus();
3292       default:
3293         return errors::InvalidArgument("Invalid data type: ", t->dtype());
3294     }
3295   }
3296 };
3297 
// Rewrites Log(Add(x, 1)) (with the constant 1 on either side of the Add)
// into the numerically more accurate Log1p(x).
class ConvertLog1pStage : public ArithmeticOptimizerStage {
 public:
  explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
                             const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
  ~ConvertLog1pStage() override = default;

  bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    NodeDef* input;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
    if (!IsAdd(*input)) {
      return OkStatus();
    }

    // Need inferred properties for both Add operands before indexing them.
    if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
      return OkStatus();
    }

    // Try the constant 1 on the right first, then on the left.
    bool modified = false;
    TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
    if (!modified) {
      TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
    }
    if (modified) {
      *simplified_node_name = node->name();
    }
    return OkStatus();
  }

 private:
  // Attempts the rewrite assuming operand `i` of `add_node` is the variable
  // term and operand `j` is the candidate all-ones constant. Sets *modified
  // to true only when the rewrite was performed.
  Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
                             bool* modified) {
    const auto& t =
        ctx().graph_properties->GetInputProperties(add_node->name())[i];
    const auto& c =
        ctx().graph_properties->GetInputProperties(add_node->name())[j];
    for (int k = 0; k < c.shape().dim_size(); ++k) {
      // Skip if c shape is not fully determined.
      if (c.shape().dim(k).size() < 0) {
        return OkStatus();
      }
    }
    TensorShapeProto broadcast_shape;
    if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
      return OkStatus();
    }
    if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
      // skip if the non-constant tensor doesn't have the same shape after
      // broadcast.
      return OkStatus();
    }
    Tensor constant;
    if (GetTensorFromConstNode(add_node->input(j), &constant)) {
      complex128 element;
      // TODO(rmlarsen): Refactor the more general IsOnes from
      // constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
      // or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1),
      // and Neg(-1) to 1.
      for (int k = 0; k < constant.NumElements(); ++k) {
        if (!GetElementUnexhaustive(constant, k,
                                    {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
                                     DT_COMPLEX64, DT_COMPLEX128},
                                    &element)) {
          // input data type is not supported by log1p. Skip.
          return OkStatus();
        }
        if (element != complex128(1)) {
          // current element is not 1. Skip.
          return OkStatus();
        }
      }
      NodeDef *x, *y;
      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
      TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
      // Rewrite in place: Log becomes Log1p reading the variable operand;
      // the ones constant is retained only as a control dependency.
      node->set_op("Log1p");
      node->set_input(0, add_node->input(i));
      node->add_input(AsControlDependency(y->name()));
      ForwardControlDependencies(node, {add_node});

      AddToOptimizationQueue(node);
      AddToOptimizationQueue(add_node);
      AddToOptimizationQueue(x);
      AddToOptimizationQueue(y);
      *modified = true;
    }
    return OkStatus();
  }
};
3388 
3389 class ConvertExpm1Stage : public ArithmeticOptimizerStage {
3390  public:
ConvertExpm1Stage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)3391   explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
3392                              const ArithmeticOptimizerContext& ctx_ext)
3393       : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
3394   ~ConvertExpm1Stage() override = default;
3395 
IsSupported(const NodeDef * node) const3396   bool IsSupported(const NodeDef* node) const override {
3397     if (!IsSub(*node)) return false;
3398 
3399     NodeDef* input;
3400     if (!GetInputNode(node->input(0), &input).ok()) return false;
3401 
3402     return IsExp(*input);
3403   }
3404 
TrySimplify(NodeDef * node,string * simplified_node_name)3405   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3406     if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
3407       return OkStatus();
3408     }
3409     const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
3410     const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
3411     TensorShapeProto broadcast_shape;
3412     if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
3413       return OkStatus();
3414     }
3415     if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
3416       // skip if the non-constant tensor doesn't have the same shape after
3417       // broadcast.
3418       return OkStatus();
3419     }
3420     Tensor constant;
3421     if (!GetTensorFromConstNode(node->input(1), &constant)) return OkStatus();
3422     // TODO(rmlarsen): Use the more general IsOnes helper here.
3423     complex128 element;
3424     for (int k = 0; k < constant.NumElements(); ++k) {
3425       if (!GetElementUnexhaustive(constant, k,
3426                                   {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
3427                                    DT_COMPLEX64, DT_COMPLEX128},
3428                                   &element)) {
3429         // input data type is not supported by expm1. Skip.
3430         return OkStatus();
3431       }
3432       if (element != complex128(1)) {
3433         // current element is not 1. Skip.
3434         return OkStatus();
3435       }
3436     }
3437     NodeDef* exp;
3438     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
3439     NodeDef *exp_input, *ones;
3440     TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
3441     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
3442     node->set_op("Expm1");
3443     node->set_input(0, exp->input(0));
3444     node->set_input(1, AsControlDependency(ones->name()));
3445     ForwardControlDependencies(node, {exp});
3446 
3447     AddToOptimizationQueue(node);
3448     AddToOptimizationQueue(exp);
3449     AddToOptimizationQueue(exp_input);
3450     AddToOptimizationQueue(ones);
3451     *simplified_node_name = node->name();
3452     return OkStatus();
3453   }
3454 };
3455 
3456 // Performs conversions like:
3457 // Max(Sqrt(x)) => Sqrt(Max(x))
3458 // Checks for a max/min reduction over element-wise monotonic functions, such
3459 // as Sqrt, Sigmoid, Tanh, etc.
3460 class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
3461  public:
OptimizeMaxOrMinOfMonotonicStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)3462   explicit OptimizeMaxOrMinOfMonotonicStage(
3463       const GraphOptimizerContext& ctx,
3464       const ArithmeticOptimizerContext& ctx_ext)
3465       : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
3466                                  ctx_ext) {}
3467   ~OptimizeMaxOrMinOfMonotonicStage() override = default;
3468 
IsSupported(const NodeDef * node) const3469   bool IsSupported(const NodeDef* node) const override {
3470     // Running on (Unsorted)SegmentMax(Min) can cause issues on empty segments.
3471     return IsMax(*node) || IsMin(*node) || IsAnyMaxPool(*node) ||
3472            IsArgMax(*node) || IsArgMin(*node);
3473   }
3474 
TrySimplify(NodeDef * reduction_node,string * simplified_node_name)3475   Status TrySimplify(NodeDef* reduction_node,
3476                      string* simplified_node_name) override {
3477     if (IsInPreserveSet(*reduction_node)) {
3478       return OkStatus();
3479     }
3480 
3481     NodeDef* inner_function;
3482     TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
3483 
3484     NodeDef* inner_function_input = nullptr;
3485     if (inner_function->input_size() > 0) {
3486       TF_RETURN_IF_ERROR(
3487           GetInputNode(inner_function->input(0), &inner_function_input));
3488     }
3489 
3490     // Optimize only if:
3491     // 0. inner_function is not in the preserve set,
3492     // 1. inner_function's Op is element-wise monotonic
3493     // 2. inner_function's output is not being consumed elsewhere.
3494     // 3. is monotonic increasing if reduction_node is a pooling operation
3495     //    since we don't have MinPool operations.
3496     // 4. inner_functions is not a Relu node with an input from FusedBatchNorm
3497     //    or BiasAdd. This pattern will be fused later by remapper.
3498     auto can_be_fused_by_remapper = [](const NodeDef& consumer,
3499                                        const NodeDef& producer) -> bool {
3500       if (IsRelu(consumer) || IsRelu6(consumer)) {
3501         if (IsFusedBatchNorm(producer) || IsBiasAdd(producer)) {
3502           return true;
3503         }
3504       }
3505       return false;
3506     };
3507     bool is_non_decreasing = false;
3508     if (!IsInPreserveSet(*inner_function) &&
3509         IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
3510         ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
3511         (is_non_decreasing || !IsAnyMaxPool(*reduction_node)) &&
3512         !can_be_fused_by_remapper(*inner_function, *inner_function_input)) {
3513       // Swap the first inputs of the inner function Op & the reduction Op.
3514       NodeDef* inner_input;
3515       TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
3516       reduction_node->set_input(0, inner_input->name());
3517       ctx().node_map->UpdateInput(reduction_node->name(),
3518                                   inner_function->name(), inner_input->name());
3519       inner_function->set_input(0, reduction_node->name());
3520       TF_RETURN_IF_ERROR(
3521           UpdateConsumers(reduction_node, inner_function->name()));
3522       ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
3523                                   reduction_node->name());
3524       if (!is_non_decreasing) {
3525         // Flip Min<->Max if the function is non-increasing, e.g.
3526         // Max(Neg(x)) = Neg(Min(x)).
3527         const string opposite = FlipMinMax(*reduction_node);
3528         reduction_node->set_op(opposite);
3529       }
3530 
3531       if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
3532         // ArgMax(Sqrt(x)) = ArgMax(x)
3533         inner_function->set_op("Identity");
3534       }
3535 
3536       AddToOptimizationQueue(reduction_node);
3537       AddToOptimizationQueue(inner_function);
3538       AddToOptimizationQueue(inner_input);
3539     }
3540     return OkStatus();
3541   }
3542 
3543  private:
FlipMinMax(const NodeDef & node)3544   string FlipMinMax(const NodeDef& node) {
3545     const string& op = node.op();
3546     if (IsAnyMax(node) || IsArgMax(node)) {
3547       return str_util::StringReplace(op, "Max", "Min", false);
3548     } else {
3549       return str_util::StringReplace(op, "Min", "Max", false);
3550     }
3551   }
3552 };
3553 
3554 // Replace a chain of type&shape preserving unary ops with a
3555 // '_UnaryOpsComposition' node.
3556 // TODO(ezhulenev): It should be a part of remapper optimizer because it doesn't
3557 // have to do much with arithmetic (together with FoldMultiplyIntoConv stage?).
class UnaryOpsComposition : public ArithmeticOptimizerStage {
 public:
  explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
                               const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
    // WARN: This should be consistent with unary_ops_composition.cc.
    // clang-format off
    supported_ops_ = {// Ops defined via Eigen scalar ops.
                      {"Abs",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Acos",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Acosh",      {DT_FLOAT,          DT_DOUBLE}},
                      {"Asin",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Asinh",      {DT_FLOAT,          DT_DOUBLE}},
                      {"Atan",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Atanh",      {DT_FLOAT,          DT_DOUBLE}},
                      {"Ceil",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cos",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Cosh",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Expm1",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Exp",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Floor",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Inv",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Log1p",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Neg",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rint",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Round",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Rsqrt",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sigmoid",    {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sin",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Sinh",       {DT_FLOAT,          DT_DOUBLE}},
                      {"Sqrt",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Square",     {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Tan",        {DT_FLOAT,          DT_DOUBLE}},
                      {"Tanh",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      // Additional ops that are not part of the Eigen.
                      {"Elu",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Relu6",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
                      {"Selu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
    // clang-format on
  }
  ~UnaryOpsComposition() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return CanOptimize(*node) &&
           // Check that this node was not already a root of a fused chain. If
           // graph optimization runs twice without pruning in between,
           // fused_nodes_ will not have this information.
           !ctx().node_map->NodeExists(OptimizedNodeName(*node));
  }

  // Walks up the chain of supported unary ops ending at `root` and, if the
  // chain has more than one op, replaces it with a single
  // `_UnaryOpsComposition` node that lists the ops in computation order.
  Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
    DataType dtype = root->attr().at("T").type();

    // Keep a trace of all supported input nodes that can be fused together.
    // Both vectors are filled root-first by the predicate below as
    // GetTailOfChain walks the chain.
    std::vector<string> op_nodes = {root->name()};
    std::vector<string> op_names = {root->op()};

    // Check if we should follow input(0) while building an op composition.
    // NOTE: the predicate has the side effect of recording every accepted
    // node, so it must be invoked exactly once per candidate (which
    // GetTailOfChain does).
    const auto predicate_fn = [&](const NodeDef& input) {
      if (input.name() == root->name()) return true;

      // Only follow nodes with the same dtype, a single data consumer, and an
      // op supported by _UnaryOpsComposition on CPU.
      bool follow_input_node =
          dtype == GetDataTypeFromAttr(input, "T") &&
          NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
          CanOptimize(input);

      if (follow_input_node) {
        op_nodes.push_back(input.name());
        op_names.push_back(input.op());
      }

      return follow_input_node;
    };

    NodeDef* last_op = GetTailOfChain(
        *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);

    // We were not able to find a chain that can be replaced.
    if (op_names.size() == 1) return OkStatus();

    // Do not add fused nodes to any other chain.
    std::for_each(op_nodes.begin(), op_nodes.end(),
                  [this](const string& name) { AddToFusedNodes(name); });

    // Reverse the trace to get correct composition computation order.
    std::reverse(op_names.begin(), op_names.end());

    VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
            << absl::StrJoin(op_names, ", ") << "]";

    // Build the replacement node reading the tail's input; the "op_names"
    // attr tells the kernel which functions to apply, in order.
    NodeDef* composition_node = ctx().optimized_graph->add_node();
    composition_node->set_name(OptimizedNodeName(*root));
    composition_node->set_op("_UnaryOpsComposition");
    composition_node->add_input(last_op->input(0));
    composition_node->set_device(root->device());

    auto attr = composition_node->mutable_attr();
    SetAttrValue(dtype, &(*attr)["T"]);
    SetAttrValue(op_names, &(*attr)["op_names"]);

    ctx().node_map->AddNode(composition_node->name(), composition_node);
    ctx().node_map->AddOutput(NodeName(last_op->input(0)),
                              composition_node->name());

    *simplified_node_name = composition_node->name();

    return OkStatus();
  }

 private:
  // True if `node` may participate in a fused chain: supported op/dtype, not
  // preserved, placed on CPU, not already fused, and free of control edges
  // (control dependencies would be lost by fusing).
  bool CanOptimize(const NodeDef& node) const {
    DataType dtype = GetDataTypeFromAttr(node, "T");
    if (!IsSupported(node.op(), dtype)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    if (!NodeIsOnCpu(node)) {
      return false;
    }
    if (NodeIsAlreadyFused(node)) {
      return false;
    }
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  bool NodeIsAlreadyFused(const NodeDef& node) const {
    return fused_nodes_.count(node.name()) > 0;
  }

  string OptimizedNodeName(const NodeDef& node) const {
    return strings::StrCat(node.name(), "/unary_ops_composition");
  }

  void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }

  // Check if an op is supported by the _UnaryOpsComposition for the given type.
  bool IsSupported(const string& op_name, DataType dtype) const {
    const auto it = supported_ops_.find(op_name);
    return it != supported_ops_.end() && it->second.count(dtype) > 0;
  }

  // Op name -> dtypes for which _UnaryOpsComposition supports it.
  std::unordered_map<string, std::set<DataType>> supported_ops_;
  // Names of nodes already consumed by a fused chain in this pass.
  std::unordered_set<string> fused_nodes_;
};
3709 
3710 // Replace operations of the form:
3711 //    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
3712 // with
3713 //    a_i
3714 // when the strided slice index `i` is applied in the k'th axis.
3715 //
3716 // Similarly, replace operations of the form:
3717 //    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
3718 // with
3719 //    expand_dims(a_i, axis=k)
3720 // where the slice operator can be StridedSlice or Slice.
3721 //
3722 // TODO(ebrevdo): Extend to also replace operations of the form
3723 //    concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
3724 // with
3725 //    a_i,
3726 // when
3727 //    s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
3728 // and slicing is in the k'th axis.
3729 class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
3730  public:
RemoveStackSliceSameAxis(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)3731   explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
3732                                     const ArithmeticOptimizerContext& ctx_ext)
3733       : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
3734                                  ctx_ext) {}
3735   ~RemoveStackSliceSameAxis() override = default;
3736 
IsSupported(const NodeDef * node) const3737   bool IsSupported(const NodeDef* node) const override {
3738     return (IsStridedSlice(*node) || IsSlice(*node)) && !IsInPreserveSet(*node);
3739   }
3740 
TrySimplify(NodeDef * node,string * simplified_node_name)3741   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3742     // *node is a StridedSlice NodeDef.
3743     NodeDef* pack;
3744 
3745     // Get the input and see if it's a Pack op.
3746     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
3747     if (!IsPack(*pack)) return OkStatus();
3748 
3749     bool return_early;
3750     PartialTensorShape pack_output_shape;
3751     int pack_axis;
3752     TF_RETURN_IF_ERROR(
3753         CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
3754     if (return_early) return OkStatus();
3755 
3756     int64_t slice_start_value;
3757     bool found;
3758     bool must_expand_dims;
3759     TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
3760                                     &slice_start_value, &found,
3761                                     &must_expand_dims));
3762     if (!found) return OkStatus();
3763 
3764     return RewriteGraph(node, pack, slice_start_value, pack_axis,
3765                         must_expand_dims, simplified_node_name);
3766   }
3767 
3768  protected:
CheckInputs(const NodeDef * node,const NodeDef * pack,PartialTensorShape * pack_output_shape,int * pack_axis,bool * return_early)3769   Status CheckInputs(const NodeDef* node, const NodeDef* pack,
3770                      PartialTensorShape* pack_output_shape, int* pack_axis,
3771                      bool* return_early) {
3772     *return_early = true;
3773     TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));
3774 
3775     *pack_axis = pack->attr().at("axis").i();
3776     auto slice_properties =
3777         ctx().graph_properties->GetInputProperties(node->name());
3778     if (slice_properties.empty() ||
3779         slice_properties[0].shape().unknown_rank()) {
3780       return OkStatus();
3781     }
3782     *pack_output_shape = slice_properties[0].shape();
3783     const int pack_output_rank = pack_output_shape->dims();
3784     if (*pack_axis < 0) {
3785       *pack_axis += pack_output_rank;
3786     }
3787     if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
3788       return errors::InvalidArgument(
3789           "Pack node (", pack->name(),
3790           ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
3791     }
3792     *return_early = false;
3793     return OkStatus();
3794   }
3795 
GetSliceAxis(const NodeDef * node,const NodeDef * pack,const PartialTensorShape & pack_output_shape,int pack_axis,int64_t * slice_start_value,bool * found,bool * must_expand_dims)3796   Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
3797                       const PartialTensorShape& pack_output_shape,
3798                       int pack_axis, int64_t* slice_start_value, bool* found,
3799                       bool* must_expand_dims) {
3800     *found = false;
3801     if (IsSlice(*node)) {
3802       *must_expand_dims = true;
3803       return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
3804                                 slice_start_value, found);
3805     } else {
3806       return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
3807                                  slice_start_value, found, must_expand_dims);
3808     }
3809   }
3810 
GetSimpleSliceAxis(const NodeDef * node,const NodeDef * pack,const PartialTensorShape & pack_output_shape,int pack_axis,int64_t * slice_start_value,bool * found)3811   Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
3812                             const PartialTensorShape& pack_output_shape,
3813                             int pack_axis, int64_t* slice_start_value,
3814                             bool* found) {
3815     NodeDef* slice_begin;
3816     NodeDef* slice_size;
3817     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3818     TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
3819     for (const auto* n : {slice_begin, slice_size}) {
3820       if (!IsReallyConstant(*n)) return OkStatus();
3821     }
3822 
3823     Tensor slice_begin_t;
3824     Tensor slice_size_t;
3825     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3826     if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3827       return OkStatus();
3828     }
3829     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
3830     if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
3831       return OkStatus();
3832     }
3833 
3834     auto copy_tensor_values_to_vector =
3835         [node](const Tensor& t, gtl::InlinedVector<int64, 4>* vec) {
3836           if (t.dtype() == DT_INT32) {
3837             auto t_flat = t.flat<int32>();
3838             vec->assign(&t_flat(0), &t_flat(t.NumElements()));
3839           } else if (t.dtype() == DT_INT64) {
3840             auto t_flat = t.flat<int64_t>();
3841             vec->assign(&t_flat(0), &t_flat(t.NumElements()));
3842           } else {
3843             return errors::InvalidArgument("Node ", node->name(),
3844                                            " has invalid type for Index attr: ",
3845                                            DataTypeString(t.dtype()));
3846           }
3847           return OkStatus();
3848         };
3849 
3850     gtl::InlinedVector<int64_t, 4> slice_begin_vec;
3851     gtl::InlinedVector<int64_t, 4> slice_size_vec;
3852     TF_RETURN_IF_ERROR(
3853         copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
3854     TF_RETURN_IF_ERROR(
3855         copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));
3856 
3857     if (slice_begin_vec.size() != slice_size_vec.size()) {
3858       return errors::InvalidArgument("Node ", node->name(),
3859                                      " has mismatched lengths for begin (",
3860                                      slice_begin_vec.size(), ") and size (",
3861                                      slice_size_vec.size(), ") vectors.");
3862     }
3863     int slice_begin_vec_size = slice_begin_vec.size();
3864     if (!pack_output_shape.unknown_rank() &&
3865         slice_begin_vec_size != pack_output_shape.dims()) {
3866       return OkStatus();
3867     }
3868     if (pack_axis >= slice_begin_vec_size) {
3869       return errors::InvalidArgument(
3870           "Input to node ", node->name(), " had pack_axis ", pack_axis,
3871           " but rank was ", slice_begin_vec_size, ".");
3872     }
3873 
3874     *slice_start_value = slice_begin_vec[pack_axis];
3875     if (slice_size_vec[pack_axis] != 1) {
3876       // Not slicing a single value out.
3877       return OkStatus();
3878     }
3879 
3880     for (int i = 0; i < slice_begin_vec_size; ++i) {
3881       if (i != pack_axis) {
3882         if (slice_begin_vec[i] != 0 ||
3883             !(slice_size_vec[i] == -1 ||
3884               slice_size_vec[i] == pack_output_shape.dim_size(i))) {
3885           // Not slicing on the same axis as the Pack op.
3886           return OkStatus();
3887         }
3888       }
3889     }
3890 
3891     if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
3892       return errors::InvalidArgument(
3893           "Node ", node->name(), " requested invalid slice index ",
3894           *slice_start_value, " on axis ", pack_axis,
3895           " from tensor of shape: ", pack_output_shape.DebugString());
3896     }
3897 
3898     *found = true;  // slice_start_value is valid.
3899     return OkStatus();
3900   }
3901 
GetStridedSliceAxis(const NodeDef * node,const NodeDef * pack,const PartialTensorShape & pack_output_shape,int pack_axis,int64_t * slice_start_value,bool * found,bool * must_expand_dims)3902   Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
3903                              const PartialTensorShape& pack_output_shape,
3904                              int pack_axis, int64_t* slice_start_value,
3905                              bool* found, bool* must_expand_dims) {
3906     TF_RETURN_IF_ERROR(
3907         CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
3908                                 "new_axis_mask", "shrink_axis_mask"}));
3909 
3910     const int begin_mask = node->attr().at("begin_mask").i();
3911     const int end_mask = node->attr().at("end_mask").i();
3912     const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
3913     const int new_axis_mask = node->attr().at("new_axis_mask").i();
3914     const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();
3915 
3916     // Check that the StridedSlice is one of these at pack_axis:
3917     //   [..., i, ...]
3918     //   [..., i:i+1, ...]
3919     //   [..., :1, ...]
3920     //   [..., -1:, ...]
3921     ///  [..., s_{pack_axis}-1:, ...]
3922     NodeDef* slice_begin;
3923     NodeDef* slice_end;
3924     NodeDef* slice_strides;
3925     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
3926     TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
3927     TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));
3928 
3929     for (const auto* n : {slice_begin, slice_end, slice_strides}) {
3930       if (!IsReallyConstant(*n)) return OkStatus();
3931     }
3932 
3933     Tensor slice_begin_t;
3934     Tensor slice_end_t;
3935     Tensor slice_strides_t;
3936 
3937     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
3938     if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
3939       return OkStatus();
3940     }
3941     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
3942     if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
3943       return OkStatus();
3944     }
3945     TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
3946     if (!slice_strides_t.FromProto(
3947             slice_strides->attr().at("value").tensor())) {
3948       return OkStatus();
3949     }
3950     TensorShape processing_shape;
3951     TensorShape final_shape;
3952     bool is_identity;
3953     bool is_simple_slice;
3954     bool slice_dim0;
3955     gtl::InlinedVector<int64_t, 4> slice_begin_vec;
3956     gtl::InlinedVector<int64_t, 4> slice_end_vec;
3957     gtl::InlinedVector<int64_t, 4> slice_strides_vec;
3958     TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
3959         &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
3960         begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
3961         &processing_shape, &final_shape, &is_identity, &is_simple_slice,
3962         &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));
3963 
3964     if (!is_simple_slice) return OkStatus();
3965 
3966     int begin_index = -1;
3967     int64_t begin_value = 0;
3968     for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
3969       const int64_t v = slice_begin_vec[i];
3970       if (v != 0) {
3971         if (begin_index != -1) {
3972           // At least two start values that are nonzero.
3973           return OkStatus();
3974         }
3975         begin_index = i;
3976         begin_value = v;
3977       }
3978     }
3979 
3980     int end_index = -1;
3981     int64_t end_value = 0;
3982     for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
3983       const int64_t v = slice_end_vec[i];
3984       if (v != pack_output_shape.dim_size(i)) {
3985         if (end_index != -1) {
3986           // At least two end values that are nonzero.
3987           return OkStatus();
3988         }
3989         end_index = i;
3990         end_value = v;
3991       }
3992     }
3993 
3994     if (begin_index == -1 && end_index == -1) return OkStatus();
3995     if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
3996       // Somehow received different axes for begin/end slicing
3997       return OkStatus();
3998     }
3999     const int slice_axis = (begin_index == -1) ? end_index : begin_index;
4000     if (slice_axis != pack_axis) {
4001       // Not slicing on the same axis as the Pack op.
4002       return OkStatus();
4003     }
4004     *slice_start_value = (begin_index == -1) ? 0 : begin_value;
4005     const int64_t slice_end_value =
4006         (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
4007     if (slice_end_value != *slice_start_value + 1) {
4008       // Not slicing a single value out.
4009       return OkStatus();
4010     }
4011 
4012     if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
4013       return errors::InvalidArgument(
4014           "Node ", node->name(), " requested invalid slice index ",
4015           *slice_start_value, " on axis ", slice_axis,
4016           " from tensor of shape: ", pack_output_shape.DebugString());
4017     }
4018 
4019     if (shrink_axis_mask == 0) {
4020       *must_expand_dims = true;
4021     } else if (shrink_axis_mask == (1 << slice_axis)) {
4022       *must_expand_dims = false;
4023     } else {
4024       // Shrinking on a different axis from the one that we are slicing on.
4025       return OkStatus();
4026     }
4027 
4028     *found = true;  // slice_start_value is valid.
4029     return OkStatus();
4030   }
4031 
  // Rewrites a slice of `pack`'s output that selects exactly one packed
  // tensor: the slice node `node` is replaced by either an Identity of the
  // corresponding Pack input (when the slice dropped the packed dimension),
  // or an ExpandDims re-inserting a size-1 dimension at `pack_axis` (when
  // `must_expand_dims` is true). The replacement node's name is written to
  // `simplified_node_name`.
  //
  // NOTE(review): `slice_start_value` is used directly as an index into
  // pack's inputs — presumably validated by the matching half of this stage
  // (outside this view); confirm before reusing this helper elsewhere.
  Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
                      int64_t slice_start_value, int pack_axis,
                      bool must_expand_dims, string* simplified_node_name) {
    // The Pack input selected by the slice; it becomes the data input of the
    // replacement node.
    const string& input_slice = pack->input(slice_start_value);

    const OpInfo::TensorProperties* input_slice_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
                                           &input_slice_properties));
    // NOTE(review): `input_slice_shape` and `output_shape` below are computed
    // but never read in this function — presumably leftovers from shape
    // validation; confirm before removing.
    PartialTensorShape input_slice_shape(input_slice_properties->shape());

    const OpInfo::TensorProperties* output_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(
        strings::StrCat(node->name(), ":", 0), &output_properties));
    PartialTensorShape output_shape(output_properties->shape());
    NodeDef* output =
        AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
    if (!must_expand_dims) {
      // The selected Pack input already has the slice output's shape: just
      // forward it through an Identity.
      output->set_op("Identity");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      output->add_input(input_slice);
    } else {
      // The slice output keeps the packed dimension (size 1): re-insert it
      // with ExpandDims at `pack_axis`, using a new int32 Const for the axis.
      NodeDef* axis = AddEmptyNode(
          OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
      axis->set_op("Const");
      axis->set_device(node->device());
      // We need to add a control edge from input slice to guarantee that axis
      // constant will be executed in the same frame as `input_slice`, otherwise
      // ExpandDims might have mismatched input frames.
      axis->add_input(absl::StrCat("^", ParseTensorName(input_slice).node()));
      auto axis_attr = axis->mutable_attr();
      SetDataTypeToAttr(DT_INT32, "dtype", axis);
      auto* axis_t = (*axis_attr)["value"].mutable_tensor();
      axis_t->set_dtype(DT_INT32);
      axis_t->add_int_val(pack_axis);
      AddToOptimizationQueue(axis);
      output->set_op("ExpandDims");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      SetDataTypeToAttr(DT_INT32, "Tdim", output);
      output->add_input(input_slice);
      output->add_input(axis->name());
    }

    // Copy dependencies over.
    ForwardControlDependencies(output, {node, pack});
    AddToOptimizationQueue(output);
    *simplified_node_name = output->name();

    return OkStatus();
  }
4083 };
4084 
4085 // Eliminates unnecessary copies during sparse embedding lookup operations.
4086 //
4087 // For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
4088 // generates code of the form:
4089 //
4090 //     embeddings = <a 2D Tensor>
4091 //     sparse_ids = <a tf.int64 SparseTensor>
4092 //     segment_ids = sparse_ids.indices[:, 0]
4093 //     ids, idx = tf.unique(sparse_ids.values)
//     gathered_rows = tf.gather(embeddings, ids)
4095 //     result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
4096 //
4097 // In this case, all of the work in `tf.unique()` and `tf.gather()`
4098 // can be avoided by passing the full embeddings to
4099 // `tf.sparse.segment_<combiner>()` and performing the same amount of
4100 // computation (but fewer copies and allocations) as follows:
4101 //
4102 //     embeddings = <a 2D Tensor>
4103 //     sparse_ids = <a tf.int64 SparseTensor>
4104 //     segment_ids = sparse_ids.indices[:, 0]
4105 //     result = tf.sparse.segment_<combiner>(
4106 //          embeddings, sparse_ids.values, segment_ids)
4107 class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
4108  public:
SimplifyEmbeddingLookupStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)4109   explicit SimplifyEmbeddingLookupStage(
4110       const GraphOptimizerContext& ctx,
4111       const ArithmeticOptimizerContext& ctx_ext)
4112       : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
4113   }
4114   ~SimplifyEmbeddingLookupStage() override = default;
4115 
IsSupported(const NodeDef * node) const4116   bool IsSupported(const NodeDef* node) const override {
4117     return IsAnySparseSegmentReduction(*node);
4118   }
4119 
TrySimplify(NodeDef * reduction_node,string * simplified_node_name)4120   Status TrySimplify(NodeDef* reduction_node,
4121                      string* simplified_node_name) override {
4122     if (IsInPreserveSet(*reduction_node)) return OkStatus();
4123 
4124     // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
4125     // axis.
4126     NodeDef* gather_node = nullptr;
4127     TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
4128     if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
4129         gather_node->device() != reduction_node->device())
4130       return OkStatus();
4131     if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
4132       return OkStatus();
4133 
4134     // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
4135     // axis.
4136     NodeDef* unique_node = nullptr;
4137     TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
4138     if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
4139         unique_node->device() != gather_node->device())
4140       return OkStatus();
4141     if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
4142       return OkStatus();
4143 
4144     DataType unique_element_type;
4145     TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));
4146 
4147     // Input 1 (indices) of the reduction node must be output 1 of the unique
4148     // node.
4149     const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
4150     if (idx_tensor != TensorId(unique_node->name(), 1)) return OkStatus();
4151 
4152     // Input 1 (indices) of the reduction node becomes input 0 (x) of the unique
4153     // node.
4154     reduction_node->set_input(1, unique_node->input(0));
4155     ctx().node_map->UpdateInput(reduction_node->name(),
4156                                 reduction_node->input(1),
4157                                 unique_node->input(0));
4158     SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);
4159 
4160     // Input 0 (data) of the reduction node becomes input 1 (params) of the
4161     // gather node.
4162     const OpInfo::TensorProperties* gather_input_properties;
4163     TF_RETURN_IF_ERROR(
4164         GetTensorProperties(gather_node->input(0), &gather_input_properties));
4165     if (gather_input_properties->dtype() == DT_RESOURCE) {
4166       // If the input is a ResourceGather, we need to add
4167       // ReadVariableOp.
4168       NodeDef* variable_node = nullptr;
4169       TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(0), &variable_node));
4170       NodeDef* read_var_node = ctx().optimized_graph->add_node();
4171       read_var_node->set_name(OptimizedNodeName(
4172           ParseNodeScopeAndName(reduction_node->name()), "ReadVar"));
4173       read_var_node->set_op("ReadVariableOp");
4174       read_var_node->add_input(gather_node->input(0));
4175       read_var_node->set_device(variable_node->device());
4176 
4177       // The Variable and the Gather node should have the same
4178       // dtype, but it might not be set on both nodes.
4179       auto attr = read_var_node->mutable_attr();
4180       if (variable_node->attr().count("dtype")) {
4181         SetAttrValue(variable_node->attr().at("dtype").type(),
4182                      &(*attr)["dtype"]);
4183       }
4184       if (gather_node->attr().count("dtype")) {
4185         SetAttrValue(gather_node->attr().at("dtype").type(), &(*attr)["dtype"]);
4186       }
4187       // Copy the _class attr from the Gather node should it exist in case
4188       // of location constraints with the variable.
4189       if (gather_node->attr().count("_class")) {
4190         (*attr)["_class"] = gather_node->attr().at("_class");
4191       }
4192       if (variable_node->attr().count("shape")) {
4193         SetAttrValue(variable_node->attr().at("shape").shape(),
4194                      &(*attr)["_output_shapes"]);
4195       }
4196 
4197       ctx().node_map->AddNode(read_var_node->name(), read_var_node);
4198       reduction_node->set_input(0, read_var_node->name());
4199       ctx().node_map->UpdateInput(reduction_node->name(),
4200                                   reduction_node->input(0),
4201                                   read_var_node->name());
4202     } else {
4203       reduction_node->set_input(0, gather_node->input(0));
4204       ctx().node_map->UpdateInput(reduction_node->name(),
4205                                   reduction_node->input(0),
4206                                   gather_node->input(0));
4207     }
4208     *simplified_node_name = reduction_node->name();
4209     return OkStatus();
4210   }
4211 
4212  private:
IsAxis0(const NodeDef & node,int axis_input)4213   bool IsAxis0(const NodeDef& node, int axis_input) {
4214     Tensor axis_tensor;
4215     if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
4216       return false;
4217     if (axis_tensor.NumElements() != 1) return false;
4218     if (axis_tensor.dtype() == DT_INT32) {
4219       return axis_tensor.flat<int32>()(0) == 0;
4220     } else if (axis_tensor.dtype() == DT_INT64) {
4221       return axis_tensor.flat<int64_t>()(0) == 0;
4222     } else {
4223       return false;
4224     }
4225   }
4226 };
4227 
4228 // Eliminates unnecessary casts before sparse segment reduction operations.
4229 //
4230 // Existing graphs and library code would often insert a cast from DT_INT64 to
4231 // DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
4232 // Support for DT_INT64 indices and/or segment_ids now exists, so we can
4233 // pass the input directly without a cast.
4234 class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
4235  public:
RemoveCastIntoSegmentReductionStage(const GraphOptimizerContext & ctx,const ArithmeticOptimizerContext & ctx_ext)4236   explicit RemoveCastIntoSegmentReductionStage(
4237       const GraphOptimizerContext& ctx,
4238       const ArithmeticOptimizerContext& ctx_ext)
4239       : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
4240                                  ctx_ext) {}
4241   ~RemoveCastIntoSegmentReductionStage() override = default;
4242 
IsSupported(const NodeDef * node) const4243   bool IsSupported(const NodeDef* node) const override {
4244     return IsAnySparseSegmentReduction(*node);
4245   }
4246 
TrySimplify(NodeDef * reduction_node,string * simplified_node_name)4247   Status TrySimplify(NodeDef* reduction_node,
4248                      string* simplified_node_name) override {
4249     if (IsInPreserveSet(*reduction_node)) return OkStatus();
4250 
4251     bool optimized = false;
4252 
4253     // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
4254     // DT_INT64.
4255     std::array<std::pair<int, string>, 2> input_details = {
4256         std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};
4257 
4258     for (const auto& input : input_details) {
4259       int input_index = input.first;
4260       const string& type_attr_name = input.second;
4261       NodeDef* cast_node = nullptr;
4262       TF_RETURN_IF_ERROR(
4263           GetInputNode(reduction_node->input(input_index), &cast_node));
4264       DataType original_index_type;
4265       if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
4266         reduction_node->set_input(input_index, cast_node->input(0));
4267         ctx().node_map->UpdateInput(reduction_node->name(),
4268                                     reduction_node->input(1),
4269                                     cast_node->input(0));
4270         SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
4271         optimized = true;
4272       }
4273     }
4274 
4275     if (optimized) *simplified_node_name = reduction_node->name();
4276     return OkStatus();
4277   }
4278 
4279  private:
IsCastFromSupportedType(const NodeDef & node,DataType * out_input_type)4280   bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
4281     if (!IsCast(node)) return false;
4282     if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
4283     return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
4284   }
4285 };
4286 
4287 }  // namespace
4288 
// Runs all enabled arithmetic simplification stages over every node in the
// graph until no stage can simplify anything. `can_use_shapes` gates the
// stages that rely on statically inferred shape information.
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
  // Seed the worklist with every node in the graph; stages push newly created
  // or affected nodes back onto it via AddToOptimizationQueue.
  SetVector<NodeDef*> nodes_to_simplify;
  nodes_to_simplify.Reserve(optimized_graph_->node_size());
  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
    nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
  }

  const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                  graph_properties_.get(), node_map_.get(),
                                  &feed_nodes_, opt_level_);
  const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);

  // Stop pipeline after first stage returning non-empty simplified tensor
  // name.
  const auto stop = [](const string& result) { return !result.empty(); };
  GraphOptimizerStagePipeline<string> pipeline(stop);
  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;

  // Register stages in a fixed order; each is gated by its option flag and,
  // where required, by the availability of shape information.
  if (options_.combine_add_to_addn && can_use_shapes)
    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
  if (options_.fold_conjugate_into_transpose)
    pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
  if (options_.fold_multiply_into_conv)
    pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
  if (options_.fold_transpose_into_matmul)
    pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
  if (is_aggressive && options_.hoist_common_factor_out_of_aggregation &&
      can_use_shapes)
    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
  if (options_.minimize_broadcasts && can_use_shapes)
    pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
  if (options_.remove_identity_transpose && can_use_shapes)
    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
  if (options_.remove_involution)
    pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
  if (options_.remove_redundant_bitcast)
    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_cast)
    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
  if (options_.replace_pack_with_tile_reshape)
    pipeline.AddStage<ReplacePackWithTileReshape>(ctx, ctx_ext);
  if (options_.replace_mul_with_tile && can_use_shapes)
    pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext);
  if (options_.reduce_upsampling_dims)
    pipeline.AddStage<ReduceUpsamplingDims>(ctx, ctx_ext);
  if (options_.remove_redundant_reshape && can_use_shapes)
    pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
  if (options_.remove_negation)
    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
  if (options_.replace_mul_with_square)
    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
  if (options_.remove_logical_not)
    pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
  if (options_.reorder_cast_like_and_value_preserving)
    pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
  if (options_.simplify_aggregation)
    pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
  if (options_.hoist_cwise_unary_chains)
    pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
  if (options_.convert_sqrt_div_to_rsqrt_mul)
    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
  if (options_.remove_idempotent)
    pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
  if (options_.convert_log1p)
    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
  if (options_.convert_log_softmax)
    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
  if (options_.optimize_max_or_min_of_monotonic)
    pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
  if (options_.convert_expm1)
    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
  if (options_.unary_ops_composition)
    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
  if (options_.remove_stack_slice_same_axis)
    pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
  if (options_.simplify_embedding_lookup)
    pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
  if (options_.remove_cast_into_segment_reduction)
    pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
  if (options_.fuse_squared_diff)
    pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);

  VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
          << absl::StrJoin(pipeline.StageNames(), ", ");

  // Drain the worklist. Each pass runs the node through all stages until one
  // of them simplifies it (the `stop` predicate above).
  while (!nodes_to_simplify.Empty()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = nodes_to_simplify.PopBack();

    string simplified_tensor = "";
    bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);

    // If the node was not optimized by any of the stages, go to the next one.
    if (!optimized) continue;

    // re-wire consumers of an old node to the new one
    if (NodeName(simplified_tensor) != node->name()) {
      // Always consider simplified_tensor for further optimizations.
      NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
      if (simplified_node != nullptr) {
        nodes_to_simplify.PushBack(simplified_node);
      }
      // When `node` is simplified to another node rather than in-place, the
      // consumers of `node` are already redirected to `simplified_tensor`.
      // Re-push the consumers into `nodes_to_simplify` for further
      // optimizations.
      const std::vector<NodeDef*> consumers =
          node_map_->GetOutputsOrderedByNodeName(node->name());
      for (NodeDef* consumer : consumers) {
        // Update `consumer`'s use of `node` to `input`'s operand.
        for (int i = 0; i < consumer->input_size(); ++i) {
          int operand_pos;
          string operand_node_name =
              ParseNodeName(consumer->input(i), &operand_pos);
          if (operand_node_name == node->name()) {
            // A negative position marks a control input ("^name"); keep the
            // control form when rewiring it to the simplified node.
            *consumer->mutable_input(i) =
                (operand_pos < 0
                     ? AsControlDependency(NodeName(simplified_tensor))
                     : simplified_tensor);
          }
        }
        node_map_->UpdateInput(consumer->name(), node->name(),
                               simplified_tensor);
        nodes_to_simplify.PushBack(consumer);
      }
    }
  }
  return OkStatus();
}
4419 
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)4420 Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
4421                                      const GrapplerItem& item,
4422                                      GraphDef* optimized_graph) {
4423   // Set up helper data structures.
4424   nodes_to_preserve_ = item.NodesToPreserve();
4425   fetch_nodes_known_ = !item.fetch.empty();
4426   GrapplerItem optimized_item(item);
4427   optimized_graph_ = &optimized_item.graph;
4428 
4429   node_map_.reset(new NodeMap(optimized_graph_));
4430   for (const auto& feed : item.feed) {
4431     feed_nodes_.insert(NodeName(feed.first));
4432   }
4433 
4434   // // Disable restricted graph rewrites.
4435   options_.unary_ops_composition &=
4436       item.optimization_options().allow_non_differentiable_rewrites;
4437 
4438   // Perform topological sort on the graph in order to help DedupComputations
4439   // and AddOpsRewrite to optimize larger subgraphs starting from the roots
4440   // with more inputs.
4441   TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
4442   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4443 
4444   graph_properties_.reset(new GraphProperties(optimized_item));
4445   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
4446   const Status status =
4447       graph_properties_->InferStatically(assume_valid_feeds,
4448                                          /*aggressive_shape_inference=*/false,
4449                                          /*include_tensor_values=*/false);
4450   const bool can_use_shapes = status.ok();
4451   if (!can_use_shapes) {
4452     VLOG(1) << "Shape inference failed." << status.error_message();
4453   }
4454 
4455   // Perform the optimizations.
4456   TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
4457   *optimized_graph = std::move(*optimized_graph_);
4458   return OkStatus();
4459 }
4460 
4461 }  // namespace grappler
4462 }  // namespace tensorflow
4463