/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"

#include <algorithm>
#include <deque>
#include <limits>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/graph_topology_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/traversal.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/saved_tensor_slice_util.h"
#include "tensorflow/core/util/strided_slice_op.h"

using tensorflow::strings::StrCat;

namespace tensorflow {
namespace grappler {
namespace {

// Mark nodes created or optimized by a stage with a tag.
constexpr char kAddOpsRewriteTag[] =
    "_grappler_ArithmeticOptimizer_AddOpsRewriteStage";
constexpr char kMinimizeBroadcastsTag[] =
    "_grappler_ArithmeticOptimizer_MinimizeBroadcasts";

// Extracts values from a Const op into `values`. Returns true on success.
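// Illustrative usage (a sketch; assumes `n` is a Const NodeDef holding
// DT_INT32 data):
//   std::vector<int32> vals;
//   if (ValuesFromConstNode(n, &vals)) { /* vals now holds the elements */ }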
template <typename T>
bool ValuesFromConstNode(const NodeDef& node, std::vector<T>* values) {
  if (node.op() != "Const") {
    return false;
  }

  if (node.attr().count("dtype") == 0 || node.attr().count("value") == 0 ||
      node.attr().at("dtype").type() != DataTypeToEnum<T>::value) {
    return false;
  }

  // TensorProto represents the content of the tensor in either <type>_val or
  // tensor_content.
  const TensorProto& tensor = node.attr().at("value").tensor();
  typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
      checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));

  if (!tensor_values->empty() && tensor.has_tensor_shape()) {
    // When tensor_shape is set, theoretically the representation of the data
    // could be compressed. So, before copying values to the returned vector,
    // make sure no compression happens.
    const TensorShapeProto& shape = tensor.tensor_shape();
    if (shape.dim_size() == 1 && shape.dim(0).size() == tensor_values->size()) {
      values->insert(values->end(), tensor_values->begin(),
                     tensor_values->end());
      return true;
    }
  }

  const auto tensor_content_size = tensor.tensor_content().size();
  if (tensor_content_size > 0) {
    CHECK_EQ(0, tensor_content_size % sizeof(T))
        << "tensor_content_size (" << tensor_content_size
        << ") is not a multiple of " << sizeof(T);
    values->resize(tensor_content_size / sizeof(T));
    port::CopyToArray(tensor.tensor_content(),
                      reinterpret_cast<char*>(values->data()));
    return true;
  }

  return false;
}
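
// Adds a control dependency on `new_input` to `node`, unless `node` already
// consumes it as a regular or control input; returns true if it was added.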
bool MaybeAddControlInput(const string& new_input, NodeDef* node,
                          GraphDef* graph, NodeMap* node_map) {
  bool already_exists = false;
  for (const string& input : node->input()) {
    if (input == new_input || AsControlDependency(input) == new_input) {
      already_exists = true;
      break;
    }
  }
  if (!already_exists) {
    const string ctrl_dep =
        ConstantFolding::AddControlDependency(new_input, graph, node_map);
    node->add_input(ctrl_dep);
    node_map->AddOutput(NodeName(new_input), node->name());
  }
  return !already_exists;
}

void SetDataTypeToAttr(DataType dtype, const string& attr_name, NodeDef* node) {
  (*node->mutable_attr())[attr_name].set_type(dtype);
}
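
// Walks up the chain of value-preserving ops with single non-control outputs
// that ends at `node` and returns the chain's tail; nodes in
// `nodes_to_preserve` are never made part of the chain.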
NodeDef* GetTailOfValuePreservingChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_value_preserving_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsValuePreserving(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_value_preserving_non_branching);
}
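
// Like GetTailOfValuePreservingChain, but for chains of idempotent ops.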
NodeDef* GetTailOfIdempotentChain(
    const NodeDef& node, const NodeMap& node_map,
    const std::unordered_set<string>& nodes_to_preserve) {
  auto is_idempotent_non_branching = [&](const NodeDef& node) {
    return nodes_to_preserve.find(node.name()) == nodes_to_preserve.end() &&
           IsIdempotent(node) && NumNonControlOutputs(node, node_map) == 1;
  };
  return GetTailOfChain(node, node_map, /*follow_control_input=*/false,
                        is_idempotent_non_branching);
}

// GetElementUnexhaustive tries to get the value of an element in a tensor and
// turn it into complex128. It only checks a limited number of data types, so
// it is not exhaustive.
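// Illustrative usage (a sketch; assumes `t` is a DT_FLOAT Tensor):
//   complex128 element;
//   if (GetElementUnexhaustive(t, /*i=*/0, {DT_FLOAT}, &element)) { ... }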
bool GetElementUnexhaustive(const Tensor& t, int i, const std::set<int>& dtypes,
                            complex128* element) {
  if (dtypes.find(t.dtype()) == dtypes.end()) return false;
  switch (t.dtype()) {
    case DT_BFLOAT16:
      *element = complex128(t.flat<bfloat16>()(i));
      return true;
    case DT_HALF:
      *element = complex128(static_cast<double>(t.flat<Eigen::half>()(i)), 0);
      return true;
    case DT_INT32:
      *element = complex128(t.flat<int32>()(i));
      return true;
    case DT_INT64:
      *element = complex128(t.flat<int64>()(i));
      return true;
    case DT_FLOAT:
      *element = complex128(t.flat<float>()(i));
      return true;
    case DT_DOUBLE:
      *element = complex128(t.flat<double>()(i));
      return true;
    case DT_COMPLEX64:
      *element = complex128(t.flat<complex64>()(i));
      return true;
    case DT_COMPLEX128:
      *element = t.flat<complex128>()(i);
      return true;
    default:
      return false;
  }
}
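
// Returns true if `node` is placed on a CPU device, judging by its device
// string alone.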
bool NodeIsOnCpu(const NodeDef& node) {
  string task;
  string device;
  return DeviceNameUtils::SplitDeviceName(node.device(), &task, &device) &&
         absl::StrContains(device, DEVICE_CPU);
}

// True if all regular (non-control) inputs reference the same node or if there
// are no non-control inputs.
bool AllRegularInputsEqual(const NodeDef& node) {
  if (!HasRegularInputs(node)) return true;
  for (int i = 1; i < node.input_size(); ++i) {
    if (IsControlInput(node.input(i))) {
      break;
    }
    if (node.input(0) != node.input(i)) {
      return false;
    }
  }
  return true;
}

// Replaces a node with NoOp and resets its shape inference results.
void ReplaceWithNoOp(NodeDef* node, const GraphOptimizerContext& ctx) {
  ctx.node_map->RemoveInputs(node->name());
  ctx.graph_properties->ClearInputProperties(node->name());
  ctx.graph_properties->ClearOutputProperties(node->name());
  EraseRegularNodeAttributes(node);
  node->set_op("NoOp");
  node->clear_input();
}

// Graph optimizer context extension specific to ArithmeticOptimizer.
struct ArithmeticOptimizerContext {
  explicit ArithmeticOptimizerContext(SetVector<NodeDef*>* nodes_to_simplify)
      : nodes_to_simplify(nodes_to_simplify) {}
  SetVector<NodeDef*>* nodes_to_simplify;
};

// Base class for a single arithmetic optimization: e.g. Bitcast optimization,
// AddOps optimization, etc.
class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 public:
  explicit ArithmeticOptimizerStage(const string& name,
                                    const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext ctx_ext)
      : GraphOptimizerStage("ArithmeticOptimizer", name, ctx),
        ctx_ext_(ctx_ext) {}
  ~ArithmeticOptimizerStage() override = default;

 protected:
  // A simplification graph rewrite can create additional nodes that are inputs
  // to the final simplified node; they can also be added to the arithmetic
  // optimizer queue for further optimization.
  void AddToOptimizationQueue(NodeDef* node) {
    ctx_ext_.nodes_to_simplify->PushBack(node);
  }

  // Updates consumers of `node` to take `new_input` as input instead.
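  // Overwriting a data input with a control input would make the graph
  // invalid, so such an update is rejected with an InvalidArgument error.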
  Status UpdateConsumers(NodeDef* node, const string& new_input) {
    const auto consumers = ctx().node_map->GetOutputs(node->name());
    if (consumers.empty()) return Status::OK();
    const TensorId new_tensor = ParseTensorName(new_input);
    for (NodeDef* consumer : consumers) {
      if (consumer->name() == new_tensor.node()) continue;
      bool updated = false;
      for (int i = 0; i < consumer->input_size(); ++i) {
        const TensorId input_tensor = ParseTensorName(consumer->input(i));
        if (input_tensor.node() == node->name()) {
          if (new_tensor.index() < 0 && input_tensor.index() >= 0) {
            // Overwriting a data input with a control input will make the graph
            // invalid.
            return errors::InvalidArgument(
                "Cannot override data input ", input_tensor.ToString(),
                " with control input ", new_tensor.ToString());
          }
          consumer->set_input(i, input_tensor.index() < 0
                                     ? absl::StrCat("^", new_tensor.node())
                                     : new_input);
          ctx().node_map->UpdateInput(consumer->name(), node->name(),
                                      new_input);
          updated = true;
        }
      }
      if (updated) {
        DedupControlInputs(consumer);
        AddToOptimizationQueue(consumer);
      }
    }
    return Status::OK();
  }

  // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
  // optimizations have been migrated to stages.
  void ForwardControlDependencies(
      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
    for (const auto& src : src_nodes) {
      for (int i = src->input_size() - 1; i >= 0; --i) {
        if (IsControlInput(src->input(i))) {
          *target_node->add_input() = src->input(i);
          ctx().node_map->AddOutput(NodeName(src->input(i)),
                                    target_node->name());
        } else {
          break;
        }
      }
    }
    DedupControlInputs(target_node);
  }

  bool IsReallyConstant(const NodeDef& node) const {
    if (!IsConstant(node)) {
      return false;
    }
    // If the node is fed it's not constant anymore.
    return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool IsDrivenByControlDependency(const NodeDef& node) const {
    return std::any_of(
        node.input().begin(), node.input().end(),
        [](const string& input) { return IsControlInput(input); });
  }

  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool DrivesControlDependency(const NodeDef& node) const {
    for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
      for (int i = 0; i < output->input_size(); ++i) {
        const TensorId tensor = ParseTensorName(output->input(i));
        if (tensor.node() == node.name() && tensor.index() < 0) {
          return true;
        }
      }
    }
    return false;
  }
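
  // Returns true and fills in `tensor` if `node_name_or_input` refers to a
  // constant node (that is not fed) with a parseable "value" attribute.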
  bool GetTensorFromConstNode(const string& node_name_or_input,
                              Tensor* tensor) {
    const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
    return node != nullptr && IsReallyConstant(*node) &&
           CheckAttrExists(*node, "value").ok() &&
           tensor->FromProto(node->attr().at("value").tensor());
  }

 private:
  // Extended context required for ArithmeticOptimizer.
  const ArithmeticOptimizerContext ctx_ext_;
};

// Subtype of ArithmeticOptimizerStage that does optimization by rewriting a
// group of nodes from the optimized graph.
//
// * AddOpsRewrite:
//   Rewrite a group of Add/AddN with a compact Add/AddN tree.
//
// * MinimizeBroadcasts:
//   Rewrite a group of binary associative ops, reordering
//   inputs, to minimize the cost of broadcast.
class ArithmeticNodesGroupOptimizerStage : public ArithmeticOptimizerStage {
 public:
  explicit ArithmeticNodesGroupOptimizerStage(
      const string& name, const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext ctx_ext)
      : ArithmeticOptimizerStage(name, ctx, ctx_ext) {}
  ~ArithmeticNodesGroupOptimizerStage() override = default;

  // Input name with a statically inferred shape from GraphProperties.
  struct InputAndShape {
    InputAndShape(const string& input, const TensorShapeProto& shape)
        : input(input), shape(shape) {}
    string input;
    TensorShapeProto shape;
  };

  // Subgraph (subtree) of nodes that we want to optimize in "one shot" (e.g.
  // all the Add nodes that we plan to rewrite with a single AddN). The
  // subgraph is obtained by graph traversal, starting from a root node.
  struct OptimizedNodesGroup {
    NodeDef* root_node;
    TensorShapeProto root_shape;
    // Optimized nodes that will be updated or removed by the rewrite.
    std::vector<NodeDef*> optimized_nodes;
    // Inputs to the optimized nodes.
    std::vector<InputAndShape> inputs;
  };

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    OptimizedNodesGroup group;
    TF_RETURN_IF_ERROR(CreateOptimizedNodesGroup(node, &group));

    if (!group.optimized_nodes.empty()) {
      *simplified_node_name = RewriteOptimizedNodesGroup(group);
    }

    return Status::OK();
  }

 protected:
  // Modify the optimized graph after a node group has been successfully
  // identified.
  virtual string RewriteOptimizedNodesGroup(
      const OptimizedNodesGroup& group) = 0;

  // Check if an input can become a part of the current optimized nodes group.
  virtual bool IsAbsorbableByOptimizedNodesGroup(
      const OptimizedNodesGroup& group, const NodeDef& node) const = 0;

  Status AbsorbInputByOptimizedNodesGroup(const string& input,
                                          OptimizedNodesGroup* group) const {
    std::deque<const string*> input_tensors;
    input_tensors.push_front(&input);

    while (!input_tensors.empty()) {
      const string* input_tensor = input_tensors.front();
      input_tensors.pop_front();

      // Get a node for the input tensor.
      NodeDef* input_node;
      TF_RETURN_IF_ERROR(GetInputNode(*input_tensor, &input_node));

      if (IsAbsorbableByOptimizedNodesGroup(*group, *input_node)) {
        group->optimized_nodes.push_back(input_node);
        for (int i = input_node->input_size() - 1; i >= 0; --i) {
          const string& absorbed_node_input = input_node->input(i);
          // TODO(ezhulenev): support control inputs
          if (IsControlInput(absorbed_node_input)) continue;
          input_tensors.push_front(&absorbed_node_input);
        }
      } else {
        // If the input node can't be absorbed, add it to the
        // OptimizedNodesGroup inputs.
        const OpInfo::TensorProperties* properties;
        TF_RETURN_IF_ERROR(GetTensorProperties(*input_tensor, &properties));
        group->inputs.emplace_back(*input_tensor, properties->shape());
      }
    }

    return Status::OK();
  }

  Status CreateOptimizedNodesGroup(NodeDef* root_node,
                                   OptimizedNodesGroup* group) const {
    const OpInfo::TensorProperties* root_node_output_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(root_node->name(), &root_node_output_properties));

    group->root_node = root_node;
    group->root_shape = root_node_output_properties->shape();

    group->optimized_nodes.reserve(root_node->input_size());
    for (int i = 0; i < root_node->input_size(); ++i) {
      const string& input_i = root_node->input(i);
      // TODO(ezhulenev): add support for control inputs
      if (IsControlInput(input_i)) continue;
      TF_RETURN_IF_ERROR(AbsorbInputByOptimizedNodesGroup(input_i, group));
    }

    return Status::OK();
  }

  // Check if all inputs can be broadcast to the same shape.
  // TODO(ezhulenev): move to GraphOptimizerStage?
  bool HasAllInputsBroadcastableToShape(
      const NodeDef& node, const OpInfo::TensorProperties& properties) const {
    auto is_broadcastable = [this, &properties](const string& input) {
      const OpInfo::TensorProperties* input_props;
      Status has_input_properties = GetTensorProperties(input, &input_props);
      return has_input_properties.ok() &&
             ShapesBroadcastable(properties, *input_props);
    };
    return std::all_of(node.input().begin(), node.input().end(),
                       is_broadcastable);
  }
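
  // E.g. a rank-3 shape [2, ?, 4] maps to the signature "rank:3:dim:2:-1:4"
  // (an unknown dimension has size -1 in TensorShapeProto).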
  string ShapeSignature(const TensorShapeProto& shape) const {
    string signature = strings::StrCat("rank:", shape.dim_size(), ":dim");
    for (int i = 0; i < shape.dim_size(); ++i)
      strings::StrAppend(&signature, ":", shape.dim(i).size());
    return signature;
  }

  void MarkWithTag(const StringPiece tag, NodeDef* node) {
    AddNodeAttr(tag, true, node);
  }

  void MarkAllMembersWithTag(const OptimizedNodesGroup& group,
                             const StringPiece tag) const {
    AddNodeAttr(tag, true, group.root_node);
    for (NodeDef* optimized_node : group.optimized_nodes) {
      AddNodeAttr(tag, true, optimized_node);
    }
  }

  bool IsOnTheSameDevice(const OptimizedNodesGroup& group,
                         const NodeDef& node) const {
    return group.root_node->device() == node.device();
  }

  bool IsInPreserveSet(const NodeDef& node) const {
    return ctx().nodes_to_preserve->find(node.name()) !=
           ctx().nodes_to_preserve->end();
  }

  bool IsMarkedWithTag(const NodeDef& node, const StringPiece tag) const {
    return HasNodeAttr(node, tag);
  }

  bool IsMarkedWithAnyTag(const NodeDef& node, const StringPiece tag1,
                          const StringPiece tag2) const {
    return IsMarkedWithTag(node, tag1) || IsMarkedWithTag(node, tag2);
  }
};

// Rewrite a tree of Add/AddN ops with a single AddN operation, consuming all
// the original inputs of absorbed nodes.
//
// 1) All nodes must have the same device placement.
//
// 2) If all nodes in an Add/AddN subgraph have symbolically equal shapes, the
//    tree is optimized to a single AddN node.
//
//                AddN_1
//             /    |    \
//          Add_1   z   Add_2       -> AddN(x, y, z, w, q, e)
//          /  \        /  \
//         x    y      w    Add_3
//                          / \
//                         q   e
//
// 3) If some nodes have different shapes (each needs to be broadcastable to
//    the shape of the "root"), the tree is optimized to AddN ops for
//    symbolically equal shapes, plus a tree of Add ops that minimizes
//    broadcasts.
//
//                AddN_1                                 Add
//             /    |    \                              /  \
//          Add_1   z   Add_2       ->               Add    w
//          /  \        /  \                        /   \
//         x    y      w    Add_3      AddN(x, y, q, e)  z
//                          / \
//                         q   e
class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit AddOpsRewriteStage(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("AddOpsRewrite", ctx, ctx_ext) {}
  ~AddOpsRewriteStage() override = default;

  // Check if a node can become the root of an AddOpsGroup.
  bool IsSupported(const NodeDef* node) const override {
    if (!CanOptimize(*node)) return false;

    // The shape must be symbolically defined, and all inputs compatible with
    // it.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  // Check if a node can be absorbed by the current OptimizedNodesGroup.
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!CanOptimize(node)) return false;

    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // The node must have a single output data consumer (presumably, if we
    // reach this node from a previously absorbed or a root node, it means
    // that this node is not used as an input to any other op outside of the
    // group).
    if (NumNonControlDataOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  // Node requirements both for a root node and an absorbed node.
  bool CanOptimize(const NodeDef& node) const {
    // TODO(ezhulenev): check if AccumulateNV2 can be supported too
    if (!IsAdd(node) && !IsAddN(node)) {
      return false;
    }
    if (IsInPreserveSet(node) || IsMarkedWithTag(node, kAddOpsRewriteTag)) {
      return false;
    }
    // TODO(ezhulenev): relax this condition for root node
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  // Rewrite a group of add ops into a single AddN if all input shapes are
  // symbolically equal. If not, create AddN ops for equal shapes first, and
  // then build an Add tree, minimizing the cost of broadcasts.
  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Collapse Add/AddN: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size()
            << " num_inputs=" << group.inputs.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kAddOpsRewriteTag);

    // All new nodes will be placed under the scope of the root node.
    auto root_scope_and_name = ParseNodeScopeAndName(group.root_node->name());

    // Find what shapes are present in the inputs of absorbed nodes.
    std::unordered_map<string, std::vector<InputAndShape>> shape_sig_to_inputs;
    for (const auto& input : group.inputs) {
      shape_sig_to_inputs[ShapeSignature(input.shape)].push_back(input);
    }

    using SigKV = decltype(shape_sig_to_inputs)::value_type;
    VLOG(3) << "Add/AddN group has " << shape_sig_to_inputs.size()
            << " unique shapes: "
            << absl::StrJoin(shape_sig_to_inputs, ", ",
                             [](string* out, SigKV p) {
                               strings::StrAppend(out, p.first);
                             });

    // Collect all the shapes from representative elements.
    std::vector<TensorShapeProto> shapes;
    shapes.reserve(shape_sig_to_inputs.size());
    for (const auto& el : shape_sig_to_inputs)
      shapes.push_back(el.second[0].shape);

    // If all inputs have the same shape, rewrite the whole group with a single
    // AddN.
    if (shapes.size() == 1) {
      string node_name = UniqueOptimizedNodeName(root_scope_and_name);
      AddInputsOfSymbolicallyEqualShape(*group.root_node, node_name,
                                        group.inputs);
      return node_name;
    }

    // For inputs of different shapes:
    // 1. Rewrite inputs of the same shape using AddN (leaf nodes)
    // 2. Build a tree of Add nodes, minimizing cost of broadcast
    std::sort(shapes.begin(), shapes.end(),
              [](const TensorShapeProto& left, const TensorShapeProto& right) {
                return CompareSymbolicallyShapedTensorSizes(left, right);
              });

    // Optimized name for leaf AddN nodes.
    auto leaf_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Leaf_", i));
    };
    // Optimized name for internal nodes of a tree built up from AddN leaves.
    auto internal_node_name = [&root_scope_and_name, this](int i) {
      return UniqueOptimizedNodeName(root_scope_and_name,
                                     strings::StrCat("Internal_", i));
    };

    // Add/AddN nodes that must be added to the tree.
    std::deque<InputAndShape> add_ops;

    // Prepare leaf AddN nodes for inputs of equal shape.
    for (int i = 0, end = shapes.size(); i < end; ++i) {
      const auto node_name = leaf_node_name(i);
      const auto& inputs = shape_sig_to_inputs[ShapeSignature(shapes[i])];
      add_ops.push_back(AddInputsOfSymbolicallyEqualShape(*group.root_node,
                                                          node_name, inputs));
    }

    // Build up a tree of Add ops.
    int internal_nodes = 0;
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();
      string name = add_ops.empty()
                        ? UniqueOptimizedNodeName(root_scope_and_name)
                        : internal_node_name(internal_nodes++);
      InputAndShape add = AddAggregatedInputs(*group.root_node, name, lhs, rhs);
      add_ops.push_front(add);
    } while (add_ops.size() > 1);

    InputAndShape optimized_root_node = add_ops.front();
    return optimized_root_node.input;
  }

  // Add an 'AddN' node to aggregate inputs of symbolically equal shape.
  InputAndShape AddInputsOfSymbolicallyEqualShape(
      const NodeDef& root_node, const string& node_name,
      const std::vector<InputAndShape>& inputs) {
    CHECK(!inputs.empty()) << "Inputs must be non-empty";

    // Do not create redundant AddN nodes.
    if (inputs.size() == 1 || root_node.attr().count("T") == 0) {
      return inputs[0];
    }

    // Get the shape from a representative element.
    auto shape = inputs[0].shape;

    // Copy attributes from the root node.
    DataType dtype = root_node.attr().at("T").type();

    // Add the new AddN node.
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op("AddN");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    (*node->mutable_attr())["N"].set_i(inputs.size());

    for (const auto& inputAndShape : inputs) {
      ctx().node_map->AddOutput(inputAndShape.input, node_name);
      node->add_input(inputAndShape.input);
    }

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(node_name, shape);
  }

  // Add a single 'Add' node to sum two inputs.
  InputAndShape AddAggregatedInputs(const NodeDef& root_node,
                                    const string& node_name,
                                    const InputAndShape& left,
                                    const InputAndShape& right) {
    // Copy attributes from the root node.
    DataType dtype = root_node.attr().at("T").type();

    // Add the new Add node.
    NodeDef* node = AddEmptyNode(node_name);
    node->set_op((dtype == DT_STRING || dtype == DT_STRING_REF) ? "Add"
                                                                : "AddV2");
    node->set_device(root_node.device());
    (*node->mutable_attr())["T"].set_type(dtype);
    node->add_input(left.input);
    node->add_input(right.input);

    ctx().node_map->AddOutput(left.input, node_name);
    ctx().node_map->AddOutput(right.input, node_name);

    MarkWithTag(kAddOpsRewriteTag, node);
    return InputAndShape(
        node_name, TensorShapeProto());  // shape is not important at this point
  }
};

// Use the distributive property of multiplication and division over addition,
// along with commutativity of the former, to hoist common factors/denominators
// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
// This pattern occurs frequently in regularization terms for the gradients
// during training.
//
// For example, we can rewrite an expression of the form:
//   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
// to the following:
//   Mul(x, AddN(y1, y2, y3, ... yn))
// For division, we can rewrite
//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
// to:
//   Div(AddN(y1, y2, y3, ... yn), x)
class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 public:
  explicit HoistCommonFactorOutOfAggregation(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("HoistCommonFactor", ctx, ctx_ext) {}
  ~HoistCommonFactorOutOfAggregation() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
           !IsRewritten(node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));

    bool common_factor_is_denominator = false;
    std::set<string> common_factors;
    std::vector<string> ctrl_deps;
    TF_RETURN_IF_ERROR(GetCommonFactors(
        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));

    if (common_factors.size() == 1) {
      const string& common_factor = *common_factors.begin();

      // Gather up the non-shared factors.
      bool shapes_match = true;
      std::vector<string> unique_factors;
      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
                                          common_factor_is_denominator,
                                          &shapes_match, &unique_factors));

      if (shapes_match) {
        NodeDef* input_0;
        TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));

        // Use a copy of the first node for the outer multiplication/division.
        NodeDef* new_outer_node = AddCopyNode(
            OuterNodeName(node, common_factor_is_denominator), input_0);
        // And a copy of the aggregation node as one of the inner operands.
        NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);

        new_outer_node->set_device(node->device());
        if (common_factor_is_denominator) {
          new_outer_node->set_input(0, new_add_node->name());
          new_outer_node->set_input(1, common_factor);
        } else {
          new_outer_node->set_input(0, common_factor);
          new_outer_node->set_input(1, new_add_node->name());
        }

        ctx().node_map->AddOutput(common_factor, new_outer_node->name());
        ctx().node_map->AddOutput(new_add_node->name(), new_outer_node->name());

        // Hoist non-shared factors up into the new AddN node.
        for (int i = 0, end = unique_factors.size(); i < end; ++i) {
          const string& unique_factor_i = unique_factors[i];
          new_add_node->set_input(i, unique_factor_i);
          ctx().node_map->AddOutput(unique_factor_i, new_add_node->name());
        }

        // Add control deps on the add node.
        for (const string& ctrl_dep : ctrl_deps) {
          *new_add_node->add_input() = ctrl_dep;
          ctx().node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
        }

        // Optimize the new inner aggregation node.
        AddToOptimizationQueue(new_add_node);
        // Do not optimize the same node twice.
        rewritten_nodes_.insert(node->name());
        *simplified_node_name = new_outer_node->name();
      }
    }
    return Status::OK();
  }

 private:
  // Get a name for the new outer node.
  string OuterNodeName(const NodeDef* node, bool is_div) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return is_div ? OptimizedNodeName(scope_and_name, "Div")
                  : OptimizedNodeName(scope_and_name, "Mul");
  }

  // Get a name for the new inner Add node.
  string InnerAddNodeName(const NodeDef* node) const {
    auto scope_and_name = ParseNodeScopeAndName(node->name());
    return OptimizedNodeName(scope_and_name, "AddV2");
  }

  // Determine the set of common factors if the input nodes are all Mul or
  // Div nodes.
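  // E.g. for AddN(Mul(x, y1), Mul(y2, x)) the common factors are {x}; for
  // AddN(Div(y1, x), Div(y2, x)) they are also {x}, with
  // `*common_factor_is_denominator` set to true.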
  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
                          bool* common_factor_is_denominator,
                          std::vector<string>* ctrl_deps) const {
    CHECK(common_factors->empty());
    CHECK_NOTNULL(common_factor_is_denominator);
    *common_factor_is_denominator = false;

    bool has_mul = false;
    bool has_div = false;
    for (int i = 0; i < node->input_size(); ++i) {
      if (i > 0 && common_factors->empty()) break;
      if (IsControlInput(node->input(i))) {
        ctrl_deps->push_back(node->input(i));
        continue;
      }
      NodeDef* input;
      TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));

      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
          (IsAnyDiv(*input) && has_mul)) {
        // Break if the input is neither a Mul nor a Div, or if there are both
        // Mul & Div ops.
        common_factors->clear();
        break;
      } else if (IsAnyDiv(*input)) {
        has_div = true;
        // In case of possible common dividers, we avoid hoisting out if any
        // input is not float/double, since integer division is not distributive
        // over addition.
        const OpInfo::TensorProperties* properties0;
        const OpInfo::TensorProperties* properties1;
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
        if (properties0->dtype() != DT_FLOAT &&
            properties0->dtype() != DT_DOUBLE &&
            properties1->dtype() != DT_FLOAT &&
            properties1->dtype() != DT_DOUBLE) {
          common_factors->clear();
          break;
        }
      } else if (IsMul(*input)) {
        has_mul = true;
      }

      // We only focus on common factors from denominators if any op is a Div.
      std::set<string> factors_i =
          has_mul ? std::set<string>{input->input(0), input->input(1)}
                  : std::set<string>{input->input(1)};
      if (i == 0) {
        std::swap(*common_factors, factors_i);
      } else {
        std::set<string> intersection;
        std::set_intersection(
            factors_i.begin(), factors_i.end(), common_factors->begin(),
            common_factors->end(),
            std::inserter(intersection, intersection.begin()));
        std::swap(*common_factors, intersection);
      }
      for (int i = 2; i < input->input_size(); ++i) {
        ctrl_deps->push_back(input->input(i));
      }
    }

    *common_factor_is_denominator = has_div;
    return Status::OK();
  }

  // Gather up the non-shared factors (the y's in the example).
  // Unless the aggregation is Add, we have to make sure that all the y's
  // have the same shape since the other aggregation ops do not support
  // broadcasting.
  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
                          const bool common_factor_is_denominator,
                          bool* shapes_match,
                          std::vector<string>* unique_factors) const {
    *shapes_match = true;
    unique_factors->reserve(node->input_size());

    for (int i = 0; i < node->input_size() && *shapes_match; ++i) {
      const string& input = node->input(i);
      if (IsControlInput(input)) {
        break;
      }
      NodeDef* inner_node;
      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
      const int unique_factor_index =
          common_factor_is_denominator
              ? 0
              : (inner_node->input(0) == common_factor ? 1 : 0);
      unique_factors->push_back(inner_node->input(unique_factor_index));
      if (i > 0 && !IsAdd(*node)) {
        const OpInfo::TensorProperties* lhs;
        const OpInfo::TensorProperties* rhs;
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->front(), &lhs));
        TF_RETURN_IF_ERROR(GetTensorProperties(unique_factors->back(), &rhs));
        *shapes_match = ShapesSymbolicallyEqual(*lhs, *rhs);
      }
    }
    return Status::OK();
  }

  bool IsRewritten(const NodeDef* node) const {
    // If the graph rewrite happens in multiple passes without graph pruning in
    // between, it's possible that a rewritten node already exists in the
    // graph.
    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
           ctx().node_map->NodeExists(OuterNodeName(node, false)) ||
           ctx().node_map->NodeExists(OuterNodeName(node, true)) ||
           ctx().node_map->NodeExists(InnerAddNodeName(node));
  }

  // Keep the names of the nodes that were optimized by this stage.
  std::unordered_set<string> rewritten_nodes_;
};

// Binary associative ops can be re-ordered to minimize the number of
// broadcasts and the size of temporary tensors.
//
// Example: [a, c] - scalars, [b, d] - matrices
//   @ - binary associative op (Add or Mul)
//   @* - broadcast
//
//           @                      @*
//        /     \                /      \
//      @*       @*      ->     @        @
//    /   \    /   \          /   \    /   \
//   a     b  c     d        a     c  b     d
class MinimizeBroadcasts : public ArithmeticNodesGroupOptimizerStage {
 public:
  explicit MinimizeBroadcasts(const GraphOptimizerContext& ctx,
                              const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticNodesGroupOptimizerStage("MinimizeBroadcasts", ctx, ctx_ext) {
  }
  ~MinimizeBroadcasts() override = default;

  bool IsSupported(const NodeDef* node) const override {
    if (!IsBinaryAssociative(*node)) return false;

    if (IsMarkedWithAnyTag(*node, kMinimizeBroadcastsTag, kAddOpsRewriteTag))
      return false;

    // Must have a symbolically defined shape with broadcastable inputs.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node->name(), &properties);
    return has_properties.ok() && ShapeIsSymbolicallyDefined(*properties) &&
           HasAllInputsBroadcastableToShape(*node, *properties);
  }

 protected:
  bool IsBinaryAssociative(const NodeDef& node) const {
    return IsMul(node) || IsAdd(node);
  }

  bool IsSameOp(const OptimizedNodesGroup& group, const NodeDef& node) const {
    return group.root_node->op() == node.op();
  }

  // Check if a node can be absorbed by the current OptimizedNodesGroup.
  bool IsAbsorbableByOptimizedNodesGroup(const OptimizedNodesGroup& group,
                                         const NodeDef& node) const override {
    if (!IsSameOp(group, node)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    // Nodes optimized by AddOpsRewrite already have optimal broadcasts.
    if (IsMarkedWithAnyTag(node, kMinimizeBroadcastsTag, kAddOpsRewriteTag)) {
      return false;
    }
    if (IsDrivenByControlDependency(node) || DrivesControlDependency(node)) {
      return false;
    }
    if (!IsOnTheSameDevice(group, node)) {
      return false;
    }
    // Optimized nodes are updated in place; that would break the graph if the
    // node had multiple output consumers.
    if (NumNonControlOutputs(node, *ctx().node_map) != 1) {
      return false;
    }
    // All input shapes must be broadcastable to the node shape.
    const OpInfo::TensorProperties* properties;
    Status has_properties = GetTensorProperties(node.name(), &properties);
    return has_properties.ok() &&
           HasAllInputsBroadcastableToShape(node, *properties);
  }

  std::size_t CountUniqueShapes(const std::vector<InputAndShape>& inputs) {
    std::set<string> sigs;
    for (const auto& ias : inputs) {
      sigs.insert(ShapeSignature(ias.shape));
    }
    return sigs.size();
  }

  string RewriteOptimizedNodesGroup(const OptimizedNodesGroup& group) override {
    VLOG(2) << "Minimize broadcast: root=" << group.root_node->name()
            << " op=" << group.root_node->op()
            << " num_optimized_nodes=" << group.optimized_nodes.size();

    // Do not optimize any of the nodes that are part of this group.
    MarkAllMembersWithTag(group, kMinimizeBroadcastsTag);

    if (CountUniqueShapes(group.inputs) <= 1) {
      VLOG(3) << "Skip min-bcast group with single unique shape";
      // Nothing to optimize when all shapes are the same.
      return group.root_node->name();
    }

    auto num_nodes = /*root*/ 1 + group.optimized_nodes.size();
    auto num_inputs = group.inputs.size();
    CHECK_EQ(num_nodes, num_inputs - 1)
        << "Can't build a tree with " << num_inputs << " inputs, using "
        << num_nodes << " binary op nodes.";

    std::deque<InputAndShape> add_ops(group.inputs.begin(), group.inputs.end());
    std::deque<NodeDef*> optimized_nodes(group.optimized_nodes.begin(),
                                         group.optimized_nodes.end());

    // Sort inputs by their shapes, from smallest to largest.
    std::stable_sort(add_ops.begin(), add_ops.end(),
                     [](const InputAndShape& lhs, const InputAndShape& rhs) {
                       return CompareSymbolicallyShapedTensorSizes(lhs.shape,
                                                                   rhs.shape);
                     });

    // If there is an odd number of inputs, the last one is the largest, and we
    // want to attach it to the root node, to build a well balanced tree.
    std::deque<InputAndShape> add_ops_leftover;
    if (add_ops.size() % 2 != 0) {
      add_ops_leftover.push_back(add_ops.back());
      add_ops.pop_back();
    }

    // At this point it's guaranteed that add_ops has an even number of
    // entries.
    do {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops.front();
      add_ops.pop_front();

      NodeDef* node;
      if (!optimized_nodes.empty()) {
        // Re-purpose optimized nodes to build a new tree.
        node = optimized_nodes.back();
        optimized_nodes.pop_back();
      } else {
        // Or use the root node if no optimized nodes are left.
        node = group.root_node;
      }
      InputAndShape updated_node = UpdateInputs(lhs.input, rhs.input, node);

      // Pushing the updated node to the back of the deque will create a wide
      // and short tree; pushing to the front will create a tall tree. We
      // prefer to get a wide tree, as it minimizes the potential number of
      // temporary tensors required to keep in memory, though sometimes we can
      // go up to prevent propagating a broadcast from leaves to the root.
      // Example:
      //
      // inputs: [s, s, s, M] (s - scalar, M - matrix)
      // @* - op with broadcast
      //
      //  (only push_back)           @*     (push_front first op)
      //                            /  \
      //       @*                  @    M
      //     /   \                / \
      //    @     @*      ->     @   s
      //   / \   / \            / \
      //  s   s s   M          s   s
      if (add_ops.size() >= 2 &&
          CompareSymbolicallyShapedTensorSizes(add_ops.at(0).shape,
                                               add_ops.at(1).shape)) {
        add_ops.push_front(updated_node);
      } else {
        add_ops.push_back(updated_node);
      }
    } while (add_ops.size() > 1);
    CHECK_EQ(1, add_ops.size());

    // Attach the largest tensor to the root op.
    if (!add_ops_leftover.empty()) {
      const InputAndShape lhs = add_ops.front();
      add_ops.pop_front();
      const InputAndShape rhs = add_ops_leftover.front();
      InputAndShape updated_node =
          UpdateInputs(lhs.input, rhs.input, group.root_node);
      add_ops.push_back(updated_node);
    }

    return add_ops.front().input;
  }

  InputAndShape UpdateInputs(const string& input_0, const string& input_1,
                             NodeDef* node) {
    string old_input_0 = node->input(0);
    string old_input_1 = node->input(1);

    // Update inputs only if they changed.
    if (old_input_0 != input_0 || old_input_1 != input_1) {
      node->set_input(0, input_0);
      node->set_input(1, input_1);
      // Invalidate node properties (shape).
      ctx().graph_properties->ClearOutputProperties(node->name());
      ctx().graph_properties->ClearInputProperties(node->name());
      // Update the node map.
      ctx().node_map->RemoveOutput(NodeName(old_input_0), node->name());
      ctx().node_map->RemoveOutput(NodeName(old_input_1), node->name());
      ctx().node_map->AddOutput(NodeName(input_0), node->name());
      ctx().node_map->AddOutput(NodeName(input_1), node->name());
      // Add the updated node to the optimization queue.
      AddToOptimizationQueue(node);
    }

    TensorShapeProto shape;  // shape is not important at this point
    return InputAndShape(node->name(), shape);
  }
};

// Removes inverse transpose nodes.
class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
 public:
  explicit RemoveIdentityTranspose(const GraphOptimizerContext& ctx,
                                   const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveIdentityTranspose", ctx, ctx_ext) {}
  ~RemoveIdentityTranspose() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsTranspose(*node) || IsConjugateTranspose(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
    NodeDef* tail = node;
    tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
                                    *ctx().nodes_to_preserve);
    NodeDef* first_transpose;
    TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));

    NodeDef* node_perm;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &node_perm));
    if (!IsConstant(*node_perm)) {
      return Status::OK();
    }
    std::vector<int64> node_perm_values;
    TF_RETURN_IF_ERROR(GetPermutation(*node_perm, &node_perm_values));
    if (first_transpose->op() == node->op()) {
      // Remove pairs of transposes that cancel each other.
      NodeDef* first_transpose_perm;
      TF_RETURN_IF_ERROR(
          GetInputNode(first_transpose->input(1), &first_transpose_perm));
      if (!IsConstant(*first_transpose_perm)) {
        return Status::OK();
      }
      std::vector<int64> first_transpose_perm_values;
      TF_RETURN_IF_ERROR(
          GetPermutation(*first_transpose_perm, &first_transpose_perm_values));
      if (AreInversePermutations(node_perm_values,
                                 first_transpose_perm_values)) {
        if (tail == node) {
          // Bypass adjacent pair.
          *simplified_node_name = first_transpose->input(0);
        } else {
          // Bypass pair connected through chain.
          tail->set_input(0, first_transpose->input(0));
          ctx().node_map->UpdateInput(tail->name(), first_transpose->name(),
                                      first_transpose->input(0));
          ForwardControlDependencies(tail, {first_transpose});
          *simplified_node_name = node->input(0);
        }
      }
    } else {
      // Remove simple identity transposes.
      if (IsIdentityPermutation(node_perm_values)) {
        if (IsConjugateTranspose(*node)) {
          const NodeScopeAndName transpose =
              ParseNodeScopeAndName(node->name());
          const string optimized_node_name = OptimizedNodeName(transpose);
          NodeDef* new_op = AddCopyNode(optimized_node_name, node);
          new_op->set_op("Conj");
          new_op->mutable_input()->RemoveLast();
          new_op->mutable_attr()->erase("Tperm");
          ForwardControlDependencies(new_op, {node});
          *simplified_node_name = new_op->name();
        } else {
          *simplified_node_name = node->input(0);
        }
      }
    }
    return Status::OK();
  }
1245 
1246  private:
1247   Status GetPermutation(const NodeDef& node_perm,
1248                         std::vector<int64>* perm64) const {
1249     std::vector<int> perm32;
1250     if (ValuesFromConstNode(node_perm, &perm32)) {
1251       perm64->reserve(perm32.size());
1252       for (int val : perm32) {
1253         perm64->push_back(static_cast<int64>(val));
1254       }
1255       return Status::OK();
1256     }
1257     if (ValuesFromConstNode(node_perm, perm64)) {
1258       return Status::OK();
1259     }
1260     return errors::InvalidArgument("Couldn't extract permutation from ",
1261                                    node_perm.name());
1262   }
1263 
1264   bool AreInversePermutations(const std::vector<int64>& a,
1265                               const std::vector<int64>& b) {
1266     if (a.size() != b.size()) {
1267       return false;
1268     }
1269     for (int i = 0, end = a.size(); i < end; ++i) {
1270       if (a[b[i]] != i) {
1271         return false;
1272       }
1273     }
1274     return true;
1275   }
1276 
1277   bool IsIdentityPermutation(const std::vector<int64>& perm) {
1278     for (int64_t i = 0, end = perm.size(); i < end; ++i) {
1279       if (i != perm[i]) {
1280         return false;
1281       }
1282     }
1283     return true;
1284   }
1285 };
1286 
1287 // An involution is an element-wise function f(x) that is its own inverse,
1288 // i.e. f(f(x)) = x. If we can find a chain of ops
1289 //   f->op1->op2->...opn->f
1290 // where op1 through opn preserve the values of their inputs, we can remove
1291 // the two instances of the involution from the graph, since they cancel
1292 // each other.
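// For illustration: Neg(Neg(x)) => x, and with a value-preserving op in
// between, Neg(Reshape(Neg(x))) => Reshape(x).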
1293 class RemoveInvolution : public ArithmeticOptimizerStage {
1294  public:
1295   explicit RemoveInvolution(const GraphOptimizerContext& ctx,
1296                             const ArithmeticOptimizerContext& ctx_ext)
1297       : ArithmeticOptimizerStage("RemoveInvolution", ctx, ctx_ext) {}
1298   ~RemoveInvolution() override = default;
1299 
1300   bool IsSupported(const NodeDef* node) const override {
1301     return IsInvolution(*node);
1302   }
1303 
1304   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1305     NodeDef* tail = GetTailOfValuePreservingChain(*node, *ctx().node_map,
1306                                                   *ctx().nodes_to_preserve);
1307 
1308     NodeDef* involution;
1309     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &involution));
1310 
1311     if (involution->op() == node->op()) {
1312       // Skip both *node and *involution since they cancel each other.
1313       if (tail == node) {
1314         // The two nodes to eliminate are adjacent.
1315         *simplified_node_name = involution->input(0);
1316       } else {
1317         tail->set_input(0, involution->input(0));
1318         ctx().node_map->UpdateInput(tail->name(), involution->name(),
1319                                     involution->input(0));
1320         *simplified_node_name = node->input(0);
1321       }
1322     }
1323 
1324     return Status::OK();
1325   }
1326 };
1327 
1328 // Remove redundant Bitcasts.
1329 // 1) Remove Bitcast whose source type and destination type are equal
1330 // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1331 class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
1332  public:
1333   explicit RemoveRedundantBitcastStage(
1334       const GraphOptimizerContext& ctx,
1335       const ArithmeticOptimizerContext& ctx_ext)
1336       : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx, ctx_ext) {}
1337   ~RemoveRedundantBitcastStage() override = default;
1338 
1339   bool IsSupported(const NodeDef* node) const override {
1340     return IsBitcast(*node);
1341   }
1342 
1343   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1344     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1345 
1346     // Bypass Bitcast whose source type and destination type are equal.
1347     AttrSlice attrs(*node);
1348     DataType input_type;
1349     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &input_type));
1350     DataType output_type;
1351     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "type", &output_type));
1352     if ((input_type == output_type) && !IsInPreserveSet(*node)) {
1353       *simplified_node_name = node->input(0);
1354       return Status::OK();
1355     }
1356 
1357     NodeDef* bitcast;
1358     TF_RETURN_IF_ERROR(GetInputNode(node->name(), &bitcast));
1359     NodeDef* operand;
1360     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &operand));
1361 
1362     if (IsBitcast(*operand) && !IsInPreserveSet(*operand)) {
1363       AttrSlice operand_attrs(*operand);
1364       DataType operand_input_type;
1365       TF_RETURN_IF_ERROR(GetNodeAttr(operand_attrs, "T", &operand_input_type));
1366       // Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
1367       bitcast->set_input(0, operand->input(0));
1368       SetDataTypeToAttr(operand_input_type, "T", bitcast);
1369       ctx().node_map->UpdateInput(bitcast->name(), bitcast->input(0),
1370                                   operand->input(0));
1371       AddToOptimizationQueue(bitcast);
1372       *simplified_node_name = bitcast->name();
1373     }
1374 
1375     return Status::OK();
1376   }
1377 };
1378 
1379 // Remove Casts whose source type and destination type are equal.
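// For illustration: Cast<SrcT=DT_FLOAT, DstT=DT_FLOAT>(x) => x.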
1380 class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
1381  public:
1382   explicit RemoveRedundantCastStage(const GraphOptimizerContext& ctx,
1383                                     const ArithmeticOptimizerContext& ctx_ext)
1384       : ArithmeticOptimizerStage("RemoveRedundantCast", ctx, ctx_ext) {}
1385   ~RemoveRedundantCastStage() override = default;
1386 
1387   bool IsSupported(const NodeDef* node) const override {
1388     return IsCast(*node) && !IsInPreserveSet(*node);
1389   }
1390 
1391   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1392     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
1393 
1394     // Bypass Cast whose source type and destination type are equal.
1395     AttrSlice attrs(*node);
1396     DataType input_type;
1397     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "SrcT", &input_type));
1398     DataType output_type;
1399     TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "DstT", &output_type));
1400     if (input_type == output_type) {
1401       *simplified_node_name = node->input(0);
1402     }
1403     return Status::OK();
1404   }
1405 };
1406 
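// Rewrites Add/Sub nodes with a negated operand, for illustration:
//   Add(a, Neg(b)) => Sub(a, b)
//   Sub(a, Neg(b)) => AddV2(a, b)
//   Add(Neg(a), b) => Sub(b, a)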
1407 class RemoveNegationStage : public ArithmeticOptimizerStage {
1408  public:
1409   explicit RemoveNegationStage(const GraphOptimizerContext& ctx,
1410                                const ArithmeticOptimizerContext& ctx_ext)
1411       : ArithmeticOptimizerStage("RemoveNegation", ctx, ctx_ext) {}
1412   ~RemoveNegationStage() override = default;
1413 
1414   bool IsSupported(const NodeDef* node) const override {
1415     return (IsAdd(*node) || IsSub(*node)) && !IsInPreserveSet(*node);
1416   }
1417 
1418   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1419     NodeDef* x;
1420     NodeDef* y;
1421     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1422     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1423     bool updated = false;
1424     if (IsNeg(*y)) {
1425       // a - (-b) = a + b or a + (-b) = a - b
1426       ForwardControlDependencies(node, {y});
1427       ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
1428       node->set_op(IsAdd(*node) ? "Sub" : "AddV2");
1429       node->set_input(1, y->input(0));
1430       updated = true;
1431     } else if (IsAdd(*node) && IsNeg(*x)) {
1432       // (-a) + b = b - a
1433       ForwardControlDependencies(node, {x});
1434       ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
1435       node->set_op("Sub");
1436       node->mutable_input()->SwapElements(0, 1);
1437       node->set_input(1, x->input(0));
1438       updated = true;
1439     }
1440     if (updated) {
1441       AddToOptimizationQueue(node);
1442     }
1443     return Status::OK();
1444   }
1445 };
1446 
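// Replaces LogicalNot of a comparison with the inverse comparison, for
// illustration:
//   LogicalNot(Equal(x, y))     => NotEqual(x, y)
//   LogicalNot(LessEqual(x, y)) => Greater(x, y)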
1447 class RemoveLogicalNotStage : public ArithmeticOptimizerStage {
1448  public:
1449   explicit RemoveLogicalNotStage(const GraphOptimizerContext& ctx,
1450                                  const ArithmeticOptimizerContext& ctx_ext)
1451       : ArithmeticOptimizerStage("RemoveLogicalNot", ctx, ctx_ext) {}
1452   ~RemoveLogicalNotStage() override = default;
1453 
1454   bool IsSupported(const NodeDef* node) const override {
1455     return IsLogicalNot(*node) && !IsInPreserveSet(*node);
1456   }
1457 
1458   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1459     const string node_name = node->name();
1460     NodeDef* input;
1461     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1462     if (IsInPreserveSet(*input) ||
1463         NumNonControlOutputs(*input, *ctx().node_map) > 1) {
1464       return Status::OK();
1465     }
1466     string new_op;
1467     if (IsEqual(*input)) {
1468       new_op = "NotEqual";
1469     } else if (IsNotEqual(*input)) {
1470       new_op = "Equal";
1471     } else if (IsLess(*input)) {
1472       new_op = "GreaterEqual";
1473     } else if (IsLessEqual(*input)) {
1474       new_op = "Greater";
1475     } else if (IsGreater(*input)) {
1476       new_op = "LessEqual";
1477     } else if (IsGreaterEqual(*input)) {
1478       new_op = "Less";
1479     }
1480     if (!new_op.empty()) {
1481       input->set_op(new_op);
1482       *simplified_node_name = input->name();
1483     }
1484     return Status::OK();
1485   }
1486 };
1487 
1488 // This optimization hoists the common prefix of unary ops of the inputs to
1489 // concat out of the concat, for example:
1490 //    Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))])
1491 // becomes
1492 //    Exp(Sin(Concat([x, y, z]))).
1493 // Similarly, it will hoist the common postfix of unary ops into Split or
1494 // SplitV nodes, for example:
1495 //    [Exp(Sin(y)) for y in Split(x)]
1496 // becomes
1497 //    [y for y in Split(Exp(Sin(x)))]
1498 //
1499 // TODO(rmlarsen): Support casting. We would have to change the type attribute
1500 // on the concat/split node.
1501 // TODO(rmlarsen): Handle Enter/Exit.
1502 class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
1503  public:
1504   explicit HoistCWiseUnaryChainsStage(const GraphOptimizerContext& ctx,
1505                                       const ArithmeticOptimizerContext& ctx_ext)
1506       : ArithmeticOptimizerStage("", ctx, ctx_ext) {}
1507 
1508   ~HoistCWiseUnaryChainsStage() override = default;
1509 
1510   struct ChainLink {
1511     ChainLink() = default;
1512     ChainLink(NodeDef* _node, int _port_origin)
1513         : node(_node), port_origin(_port_origin) {}
1514     NodeDef* node;    // Node in a chain.
1515     int port_origin;  // Port on concat/split node from which this chain
1516                       // originates.
1517 
1518     bool operator<(const ChainLink& other) const {
1519       if (port_origin < other.port_origin) {
1520         return true;
1521       } else if (port_origin > other.port_origin) {
1522         return false;
1523       } else {
1524         return node->name() < other.node->name();
1525       }
1526     }
1527   };
1528 
1529   // We use an ordinary set sorted on port and node name, so the order, and
1530   // hence the node name used for the hoisted chain, will be deterministic.
1531   using ChainLinkSet = std::set<ChainLink>;
1532 
1533   bool IsSupported(const NodeDef* node) const override {
1534     if (IsInPreserveSet(*node)) return false;
1535     if (IsConcat(*node) && node->attr().count("N") != 0) {
1536       const int n = node->attr().at("N").i();
1537       return n > 1 && FirstNInputsAreUnique(*node, n);
1538     } else if ((IsSplit(*node) || IsSplitV(*node)) &&
1539                node->attr().count("num_split") != 0) {
1540       const int num_split = node->attr().at("num_split").i();
1541       if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
1542         // TODO(rmlarsen): Remove this constraint when we have optimizations
1543         // in place for merging slices into splits.
1544         return false;
1545       }
1546       if (NumControlOutputs(*node, *ctx().node_map) > 0) {
1547         // TODO(ezhulenev): Unary ops after Split might have a control path to
1548         // the Split node, and we currently do not properly handle cycles.
1549         return false;
1550       }
1551       return num_split > 1 && !IsAlreadyOptimized(*node);
1552     }
1553     return false;
1554   }
1555 
1556   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1557     node_is_concat_ = IsConcat(*node);
1558     int prefix_length;
1559     std::set<string> ctrl_inputs;
1560     ChainLinkSet tails;
1561     TF_RETURN_IF_ERROR(
1562         FindCommonUnaryOpChain(*node, &prefix_length, &tails, &ctrl_inputs));
1563     if (prefix_length > 0 && !tails.empty()) {
1564       TF_RETURN_IF_ERROR(
1565           HoistUnaryOpChain(prefix_length, tails, &ctrl_inputs, node));
1566     }
1567     return Status::OK();
1568   }
1569 
1570  private:
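  // Returns true if the first n data inputs of `node` are all distinct,
  // skipping the leading axis input when the op is "Concat".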
1571   bool FirstNInputsAreUnique(const NodeDef& node, int n) const {
1572     if (n > node.input_size()) return false;
1573     absl::flat_hash_set<string> unique_inputs;
1574     const int start = node.op() == "Concat" ? 1 : 0;
1575     const int end = start + n;
1576     for (int i = start; i < end; ++i) {
1577       unique_inputs.insert(node.input(i));
1578     }
1579     int unique_input_size = unique_inputs.size();
1580     return unique_input_size == n;
1581   }
1582 
1583   // Returns the length of the common unary chain of ops that can be
1584   // hoisted to the other side of concat or split.
1585   Status FindCommonUnaryOpChain(const NodeDef& root_node, int* prefix_length,
1586                                 ChainLinkSet* tails,
1587                                 std::set<string>* ctrl_inputs) const {
1588     *prefix_length = 0;
1589     // Follow the chains starting at each concat input or split output as long
1590     // as all the following conditions hold:
1591     //   1. The ops in all chains are the same.
1592     //   2. The ops are unary elementwise op.
1593     //   3. The op output has only a single consumer (concat only).
1594     ChainLinkSet cur_tails;
1595     TF_RETURN_IF_ERROR(InitializeChains(root_node, &cur_tails));
1596     if (cur_tails.size() < 2) {
1597       return Status::OK();
1598     }
1599     ctrl_inputs->clear();
1600     bool stop = false;
1601     while (!stop && !cur_tails.empty() &&
1602            OpsAreSafeToHoist(root_node, cur_tails)) {
1603       // We found one more link that can be hoisted.
1604       ++(*prefix_length);
1605       tails->swap(cur_tails);
1606       GatherControlInputs(ctrl_inputs, *tails);
1607 
1608       // Advance tail pointers to the next level.
1609       TF_RETURN_IF_ERROR(AdvanceTails(*tails, &cur_tails, &stop));
1610     }
1611     return Status::OK();
1612   }
1613 
1614   // Hoists the chains to the other side of concat or split and attaches the
1615   // control inputs gathered from them to the concat or split node.
1616   Status HoistUnaryOpChain(const int prefix_length, const ChainLinkSet& tails,
1617                            std::set<string>* ctrl_inputs, NodeDef* root_node) {
1618     VLOG(3) << "Hoist unary op chain:"
1619             << " root=" << root_node->DebugString()
1620             << " prefix_length=" << prefix_length << " ctrl_inputs=["
1621             << absl::StrJoin(*ctrl_inputs, ", ") << "]";
1622 
1623     if (tails.empty()) {
1624       return Status::OK();
1625     }
1626     AddToOptimizationQueue(root_node);
1627     optimized_nodes_.insert(root_node->name());
1628     if (node_is_concat_) {
1629       AddControlInputs(ctrl_inputs, root_node);
1630       return HoistChainForConcat(prefix_length, tails, root_node);
1631     } else {
1632       return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
1633     }
1634   }
1635 
1636   void GatherControlInputs(std::set<string>* ctrl_inputs,
1637                            const ChainLinkSet& ops) const {
1638     for (const auto& link : ops) {
1639       const NodeDef* node = link.node;
1640       for (int i = node->input_size() - 1; i >= 0; --i) {
1641         const string& input = node->input(i);
1642         if (!IsControlInput(input)) break;
1643         ctrl_inputs->insert(input);
1644       }
1645     }
1646   }
1647 
1648   void AddControlInputs(std::set<string>* new_ctrl_inputs,
1649                         NodeDef* node) const {
1650     for (int i = node->input_size() - 1; i >= 0; --i) {
1651       const string& existing_input = node->input(i);
1652       if (!IsControlInput(existing_input)) break;
1653       new_ctrl_inputs->erase(existing_input);
1654     }
1655     for (const string& new_input : *new_ctrl_inputs) {
1656       ctx().node_map->AddOutput(NodeName(new_input), node->name());
1657       node->add_input(new_input);
1658     }
1659   }
1660 
1661   Status InitializeChains(const NodeDef& node, ChainLinkSet* tails) const {
1662     if (node_is_concat_) {
1663       // Handle concat nodes by looking backwards in the graph.
1664       TF_RETURN_IF_ERROR(CheckAttrExists(node, "N"));
1665       const int n = node.attr().at("N").i();
1666       const int start = node.op() == "Concat" ? 1 : 0;
1667       const int end = start + n;
1668       if (end > node.input_size()) {
1669         return errors::FailedPrecondition("Got attr N=", n,
1670                                           " without enough inputs.");
1671       }
1672       // Set up tail pointers to point to the immediate inputs to Concat.
1673       for (int input_port = start; input_port < end; ++input_port) {
1674         if (IsControlInput(node.input(input_port))) {
1675           return errors::FailedPrecondition(
1676               "Got control input ", node.input(input_port),
1677               " where normal input was expected.");
1678         }
1679         NodeDef* tail;
1680         TF_RETURN_IF_ERROR(GetInputNode(node.input(input_port), &tail));
1681         tails->insert(ChainLink(tail, input_port));
1682       }
1683       return Status::OK();
1684     } else {
1685       // Handle split nodes by looking forwards in the graph.
1686       const auto& outputs = ctx().node_map->GetOutputs(node.name());
1687       for (NodeDef* output : outputs) {
1688         if (output->input_size() == 0 || IsControlInput(output->input(0))) {
1689           continue;
1690         }
1691         TensorId tensor_id = ParseTensorName(output->input(0));
1692         if (tensor_id.node() == node.name()) {
1693           tails->insert(ChainLink(output, tensor_id.index()));
1694         } else {
1695           // This output node has a non-control input other than the
1696           // split node, abort.
1697           tails->clear();
1698           return Status::OK();
1699         }
1700       }
1701     }
1702     return Status::OK();
1703   }
1704 
1705   bool OpsAreSafeToHoist(const NodeDef& root_node,
1706                          const ChainLinkSet& ops) const {
1707     if (ops.empty()) return true;
1708     const NodeDef* op0 = ops.begin()->node;
1709     if (ModifiesFrameInfo(*op0) || !IsUnaryElementWise(*op0)) return false;
1710     for (const auto& link : ops) {
1711       const NodeDef* op = link.node;
1712       if (op->device() != root_node.device() || op->op() != op0->op() ||
1713           IsInPreserveSet(*op)) {
1714         return false;
1715       }
1716       if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
1717         // TODO(rmlarsen): Allow outgoing control edges.
1718         return false;
1719       }
1720       // Do not hoist Relu if it can be fused with its predecessors. This is
1721       // important because remapping runs after the arithmetic optimizer.
1722       if (IsRelu(*op) || IsRelu6(*op)) {
1723         NodeDef* operand = nullptr;
1724         if (!GetInputNode(op->input(0), &operand).ok()) {
1725           return false;
1726         }
1727         if (IsFusedBatchNorm(*operand) || IsBiasAdd(*operand)) {
1728           return false;
1729         }
1730       }
1731     }
1732     return true;
1733   }
1734 
1735   Status AdvanceTails(const ChainLinkSet& tails, ChainLinkSet* new_tails,
1736                       bool* stop) const {
1737     *stop = true;
1738     new_tails->clear();
1739     for (const auto& link : tails) {
1740       const NodeDef* tail = link.node;
1741       if (node_is_concat_) {
1742         if (tail->input_size() == 0 || IsControlInput(tail->input(0))) {
1743           return Status::OK();
1744         }
1745         NodeDef* new_tail;
1746         TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &new_tail));
1747         // Remember original port.
1748         new_tails->insert(ChainLink(new_tail, link.port_origin));
1749       } else {
1750         for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
1751           const TensorId tensor = ParseTensorName(new_tail->input(0));
1752           if (tensor.node() != tail->name()) {
1753             return Status::OK();
1754           }
1755           // Skip control outputs.
1756           if (tensor.index() >= 0) {
1757             // Remember original port.
1758             new_tails->insert(ChainLink(new_tail, link.port_origin));
1759           }
1760         }
1761       }
1762     }
1763     *stop = false;
1764     return Status::OK();
1765   }
1766 
1767   Status HoistChainForConcat(const int prefix_length, const ChainLinkSet& tails,
1768                              NodeDef* concat_node) {
1769     const string& concat_name = concat_node->name();
1770     const int first_input = concat_node->op() == "Concat" ? 1 : 0;
1771     for (const auto& link : tails) {
1772       NodeDef* tail = CHECK_NOTNULL(link.node);
1773       const int concat_port = link.port_origin;
1774       CHECK_GE(concat_port, 0);
1775       CHECK_LT(concat_port, concat_node->input_size());
1776       const string concat_input = concat_node->input(concat_port);
1777       // Hook the node following tail directly into the concat node.
1778       const string tail_input = tail->input(0);
1779       concat_node->set_input(concat_port, tail_input);
1780       ctx().node_map->UpdateInput(concat_name, concat_input, tail_input);
1781 
1782       if (concat_port == first_input) {
1783         // Update the consumers of concat to consume the end of the chain
1784         // instead.
1785         TF_RETURN_IF_ERROR(UpdateConsumers(concat_node, concat_input));
1786         // Reuse nodes in the first chain to process output of concat.
1787         tail->set_input(0, concat_name);
1788         ctx().node_map->UpdateInput(tail->name(), tail_input, concat_name);
1789       }
1790     }
1791     return Status::OK();
1792   }
1793 
1794   Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
1795                             std::set<string>* ctrl_inputs,
1796                             NodeDef* split_node) {
1797     // Create a new chain before the split node to process the input tensor.
1798     const string& split_name = split_node->name();
1799     auto root_scope_and_name = ParseNodeScopeAndName(split_name);
1800 
1801     // We use the first tail node in the set as a template to get the list of
1802     // ops to apply (starting from the end).
1803     NodeDef* cur_tail = tails.begin()->node;
1804     NodeDef* cur_copy = AddCopyNode(
1805         OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
1806     cur_copy->clear_input();
1807 
1808     // Update the split to take its input from the tail of the new chain.
1809     const int value_slot = split_node->op() == "SplitV" ? 0 : 1;
1810     const string orig_input = split_node->input(value_slot);
1811     split_node->set_input(value_slot, cur_copy->name());
1812     ctx().node_map->UpdateInput(split_node->name(), orig_input,
1813                                 cur_copy->name());
1814     TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
1815 
1816     // Now walk backwards creating the rest of the chain.
1817     while (cur_tail != split_node) {
1818       NodeDef* new_copy = AddCopyNode(
1819           OptimizedNodeName(root_scope_and_name, cur_tail->name()), cur_tail);
1820       new_copy->clear_input();
1821       cur_copy->add_input(new_copy->name());
1822       ctx().node_map->AddOutput(new_copy->name(), cur_copy->name());
1823       cur_copy = new_copy;
1824       TF_RETURN_IF_ERROR(GetInputNode(cur_tail->input(0), &cur_tail));
1825     }
1826     // Connect the original input to the head of the new chain.
1827     cur_copy->add_input(orig_input);
1828     ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
1829                                  cur_copy->name());
1830     // Make sure all the control inputs are satisfied before running the first
1831     // node in the new chain.
1832     AddControlInputs(ctrl_inputs, cur_copy);
1833 
1834     // Connect all consumers of the tail nodes directly to the
1835     // output port of Split from which the chain started.
1836     for (const auto& link : tails) {
1837       TF_RETURN_IF_ERROR(UpdateConsumers(
1838           link.node, link.port_origin == 0
1839                          ? split_name
1840                          : strings::StrCat(split_name, ":", link.port_origin)));
1841     }
1842     return Status::OK();
1843   }
1844 
1845   bool IsAlreadyOptimized(const NodeDef& node) const {
1846     return optimized_nodes_.find(node.name()) != optimized_nodes_.end();
1847   }
1848 
1849  private:
1850   bool node_is_concat_;
1851   std::unordered_set<string> optimized_nodes_;
1852 };
1853 
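// Removes one of two consecutive applications of the same idempotent op, for
// illustration: Op(Op(x)) => Op(x) whenever both nodes run the same op on the
// same device.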
1854 class RemoveIdempotentStage : public ArithmeticOptimizerStage {
1855  public:
1856   explicit RemoveIdempotentStage(const GraphOptimizerContext& ctx,
1857                                  const ArithmeticOptimizerContext& ctx_ext)
1858       : ArithmeticOptimizerStage("RemoveIdempotent", ctx, ctx_ext) {}
1859   ~RemoveIdempotentStage() override = default;
1860 
1861   bool IsSupported(const NodeDef* node) const override {
1862     return node->input_size() == 1 && IsIdempotent(*node) &&
1863            !IsInPreserveSet(*node);
1864   }
1865 
1866   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1867     NodeDef* input;
1868     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
1869     if (input->op() == node->op() && input->device() == node->device()) {
1870       *simplified_node_name = node->input(0);
1871     }
1872     return Status::OK();
1873   }
1874 };
1875 
1876 // Performs the conversion:
1877 // Div(x, Sqrt(y)) => Mul(x, Rsqrt(y))
1878 // TODO(srjoglekar): Generalize to optimize cases like (x / pow(y, z)).
1879 class SqrtDivToRsqrtMulStage : public ArithmeticOptimizerStage {
1880  public:
1881   explicit SqrtDivToRsqrtMulStage(const GraphOptimizerContext& ctx,
1882                                   const ArithmeticOptimizerContext& ctx_ext)
1883       : ArithmeticOptimizerStage("SqrtDivToRsqrtMul", ctx, ctx_ext) {}
1884   ~SqrtDivToRsqrtMulStage() override = default;
1885 
1886   bool IsSupported(const NodeDef* node) const override {
1887     // Note: div_no_nan(a, sqrt(b)) => mul_no_nan(a, rsqrt(b)) is invalid
1888     // for b == 0: it would compute a * Inf instead of 0.
1889     return IsAnyDiv(*node) && !IsDivNoNan(*node) && !IsFloorDiv(*node);
1890   }
1891 
1892   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1893     NodeDef* y;
1894     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
1895     // Optimize only if divisor is a Sqrt whose output is not being consumed
1896     // elsewhere.
1897     if (IsSqrt(*y) && !IsInPreserveSet(*y) &&
1898         (NumNonControlOutputs(*y, *ctx().node_map) == 1)) {
1899       if (IsXdivy(*node)) {
1900         // xdivy(a, sqrt(b)) => mul_no_nan(rsqrt(b), a)
1901         node->set_op("MulNoNan");
1902         node->mutable_input()->SwapElements(0, 1);
1903       } else {
1904         // div(a, sqrt(b)) => mul(a, rsqrt(b))
1905         node->set_op("Mul");
1906       }
1907       y->set_op("Rsqrt");
1908       AddToOptimizationQueue(node);
1909       AddToOptimizationQueue(y);
1910     }
1911     return Status::OK();
1912   }
1913 };
1914 
1915 // Performs the following conversion for real types:
1916 //   Square(Sub(x, y)) => Identity(SquaredDifference(x, y))
1917 class FuseSquaredDiffStage : public ArithmeticOptimizerStage {
1918  public:
1919   explicit FuseSquaredDiffStage(const GraphOptimizerContext& ctx,
1920                                 const ArithmeticOptimizerContext& ctx_ext)
1921       : ArithmeticOptimizerStage("FuseSquaredDiffStage", ctx, ctx_ext) {}
1922   ~FuseSquaredDiffStage() override = default;
1923 
1924   bool IsSupported(const NodeDef* node) const override {
1925     return IsSquare(*node);
1926   }
1927 
1928   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1929     NodeDef* b;
1930     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &b));
1931     // Optimize only if base is a Sub whose output is not being consumed
1932     // elsewhere.
1933     if (IsSub(*b) && !IsInPreserveSet(*b) &&
1934         (NumNonControlOutputs(*b, *ctx().node_map) == 1)) {
1935       // For complex, SquaredDiff computes conj(x-y)*(x-y), so this rewrite is
1936       // invalid.
1937       const DataType type = GetDataTypeFromAttr(*b, "T");
1938       if ((type == DT_COMPLEX64) || (type == DT_COMPLEX128))
1939         return Status::OK();
1940       node->set_op("Identity");
1941       b->set_op("SquaredDifference");
1942       AddToOptimizationQueue(node);
1943       AddToOptimizationQueue(b);
1944     }
1945     return Status::OK();
1946   }
1947 };
1948 
1949 // Performs the conversion:
1950 // Log(Softmax(x)) => LogSoftmax(x)
1951 class LogSoftmaxStage : public ArithmeticOptimizerStage {
1952  public:
1953   explicit LogSoftmaxStage(const GraphOptimizerContext& ctx,
1954                            const ArithmeticOptimizerContext& ctx_ext)
1955       : ArithmeticOptimizerStage("LogSoftmaxStage", ctx, ctx_ext) {}
1956   ~LogSoftmaxStage() override = default;
1957 
1958   bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
1959 
1960   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
1961     NodeDef* x;
1962     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
1963     // Optimize only if arg is a Softmax whose output is not being consumed
1964     // elsewhere.
1965     if (IsSoftmax(*x) && !IsInPreserveSet(*x) &&
1966         (NumNonControlOutputs(*x, *ctx().node_map) == 1)) {
1967       // Log(Softmax(x)) => LogSoftmax(Identity(x))
1968       node->set_op("LogSoftmax");
1969       x->set_op("Identity");
1970       AddToOptimizationQueue(node);
1971       AddToOptimizationQueue(x);
1972     }
1973     return Status::OK();
1974   }
1975 };
1976 
1977 // Bypass redundant reshape nodes:
1978 //
1979 //   Reshape                    Reshape  <-+
1980 //      ^                                  |
1981 //      |                                  |
1982 //   Reshape       becomes      Reshape    |
1983 //      ^                                  |
1984 //      |                                  |
1985 //    input                      input  ---+
1986 //
1987 // Additionally, Reshape and BroadcastTo nodes where the
1988 // input and target shapes are equal are bypassed.
1989 //
1990 class RemoveRedundantReshapeOrBroadcastTo : public ArithmeticOptimizerStage {
1991  public:
1992   explicit RemoveRedundantReshapeOrBroadcastTo(
1993       const GraphOptimizerContext& ctx,
1994       const ArithmeticOptimizerContext& ctx_ext)
1995       : ArithmeticOptimizerStage("RemoveRedundantReshapeOrBroadcastTo", ctx,
1996                                  ctx_ext) {}
1997   ~RemoveRedundantReshapeOrBroadcastTo() override = default;
1998 
1999   bool IsSupported(const NodeDef* node) const override {
2000     return IsReshape(*node) || IsBroadcastTo(*node);
2001   }
2002 
2003   // TODO(rmlarsen): Handle unary ops with multiple outputs.
2004   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2005     // 1. If the reshape is a no-op, forward its input to its consumers, unless
2006     // it anchors a control dependency since we want to make sure that control
2007     // dependency is triggered.
2008     if (!IsInPreserveSet(*node) && InputMatchesTargetShape(*node) &&
2009         !HasControlInputs(*node)) {
2010       *simplified_node_name = node->input(0);
2011       return Status::OK();
2012     }
2013 
2014     // 2. Bypass reshape followed by reshape, possibly separated by a simple
2015     // chain of unary elementwise ops that are not outputs.
2016     if (IsReshape(*node)) {
2017       bool skip = false;
2018       gtl::InlinedVector<const NodeDef*, 4> nodes_in_chain;
2019       const auto predicate_fn = [this, node, &skip,
2020                                  &nodes_in_chain](const NodeDef& input) {
2021         nodes_in_chain.push_back(&input);
2022         if ((input.name() != node->name() &&
2023              NumNonControlOutputs(input, *ctx().node_map) > 1) ||
2024             IsInPreserveSet(input) || ModifiesFrameInfo(input)) {
2025           skip = true;
2026           return false;
2027         }
2028         return IsUnaryElementWise(input);
2029       };
2030 
2031       // Walk up the input chain until we find a node that is not unary
2032       // element-wise. If it is another Reshape node, we can bypass it.
2033       NodeDef* tail =
2034           GetTailOfChain(*node, *ctx().node_map,
2035                          /*follow_control_input*/ false, predicate_fn);
2036 
2037       if (!skip && tail != nullptr && !IsInPreserveSet(*tail)) {
2038         NodeDef* reshape_to_bypass;
2039         TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &reshape_to_bypass));
2040         if (reshape_to_bypass == nullptr ||
2041             (!IsReshape(*reshape_to_bypass) ||
2042              NumNonControlOutputs(*reshape_to_bypass, *ctx().node_map) > 1 ||
2043              IsInPreserveSet(*reshape_to_bypass))) {
2044           return Status::OK();
2045         }
2046         // Clearing invalid shape inference results of nodes in chain.
2047         for (const NodeDef* node_in_chain : nodes_in_chain) {
2048           ctx().graph_properties->ClearInputProperties(node_in_chain->name());
2049           if (node_in_chain != node) {
2050             ctx().graph_properties->ClearOutputProperties(
2051                 node_in_chain->name());
2052           }
2053         }
2054         // We now have
2055         //    reshape_to_bypass -> tail -> ... -> node
2056         // where tail may be equal to node.
2057         TF_RETURN_IF_ERROR(
2058             UpdateConsumers(reshape_to_bypass, reshape_to_bypass->input(0)));
2059         ForwardControlDependencies(tail, {reshape_to_bypass});
2060         // Change the bypassed reshape to NoOp.
2061         ReplaceWithNoOp(reshape_to_bypass, ctx());
2062         *simplified_node_name = node->name();
2063         return Status::OK();
2064       }
2065     }
2066 
2067     return Status::OK();
2068   }
2069 
2070  private:
2071   // Returns whether `reshape` is an identity op.
2072   bool InputMatchesTargetShape(const NodeDef& reshape) {
2073     const OpInfo::TensorProperties* reshape_props;
2074     const OpInfo::TensorProperties* input_props;
2075     if (!GetTensorProperties(reshape.name(), &reshape_props).ok() ||
2076         !GetTensorProperties(reshape.input(0), &input_props).ok()) {
2077       return false;
2078     }
2079 
2080     return ShapesSymbolicallyEqual(input_props->shape(),
2081                                    reshape_props->shape());
2082   }
2083 };
2084 
2085 // Reorder casting and value-preserving ops if beneficial.
2086 //
2087 // Original motivation: A common pattern after the layout optimizer is
2088 // casting an uint8 NHWC image to float before transposing it to NCHW. It
2089 // is beneficial to reorder the cast and the transpose to make the transpose
2090 // process smaller amount of data. More generally, this optimization converts
2091 //   Op(Cast(tensor, dst_type))
2092 // to
2093 //   Cast(Op(tensor), dst_type)
2094 // when sizeof(tensor.type) < sizeof(dst_type), and Op is any value-preserving
2095 // Op, i.e. an op that only reorders the elements in its first input. Similarly,
2096 // this optimization converts
2097 //   Cast(Op(tensor), dst_type)
2098 // to
2099 //   Op(Cast(tensor, dst_type))
2100 // when sizeof(tensor.type) > sizeof(dst_type)
2101 //
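// For illustration, with a uint8 image cast to float before a transpose:
//   Transpose(Cast(image, DT_FLOAT), perm)
// becomes
//   Cast(Transpose(image, perm), DT_FLOAT)
// so the transpose moves 1-byte elements instead of 4-byte elements.
//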
2102 class ReorderCastLikeAndValuePreserving : public ArithmeticOptimizerStage {
2103  public:
2104   explicit ReorderCastLikeAndValuePreserving(
2105       const GraphOptimizerContext& ctx,
2106       const ArithmeticOptimizerContext& ctx_ext)
2107       : ArithmeticOptimizerStage("ReorderCastLikeAndValuePreserving", ctx,
2108                                  ctx_ext) {}
2109   ~ReorderCastLikeAndValuePreserving() override = default;
2110 
2111   bool IsSupported(const NodeDef* node) const override {
2112     return (IsValuePreserving(*node) || IsCastLike(*node)) &&
2113            !IsCheckNumerics(*node) && NodeIsOnCpuOrGpu(node) &&
2114            !IsControlFlow(*node) && !IsInPreserveSet(*node);
2115   }
2116 
2117   Status TrySimplify(NodeDef* consumer, string* simplified_node_name) override {
2118     NodeDef* producer;
2119 
2120     if (consumer->input_size() < 1) {
2121       return errors::FailedPrecondition("Node ", consumer->name(),
2122                                         " lacks inputs");
2123     }
2124 
2125     TF_RETURN_IF_ERROR(GetInputNode(consumer->input(0), &producer));
2126     const bool producer_is_cast = IsCastLike(*producer);
2127     const bool can_optimize =
2128         !IsCheckNumerics(*producer) &&
2129         ((producer_is_cast && IsValuePreserving(*consumer)) ||
2130          (IsValuePreserving(*producer) && IsCastLike(*consumer)));
2131     if (!can_optimize || IsControlFlow(*producer) ||
2132         IsInPreserveSet(*producer) ||
2133         producer->device() != consumer->device()) {
2134       return Status::OK();
2135     }
2136 
2137     const NodeDef* cast_like_node = producer_is_cast ? producer : consumer;
2138     const OpDef* cast_like_op_def = nullptr;
2139     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(cast_like_node->op(),
2140                                                          &cast_like_op_def));
2141     DataType cast_src_type;
2142     TF_RETURN_IF_ERROR(InputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
2143                                         &cast_src_type));
2144     DataType cast_dst_type;
2145     TF_RETURN_IF_ERROR(OutputTypeForNode(*cast_like_node, *cast_like_op_def, 0,
2146                                          &cast_dst_type));
2147     if (!IsFixedSizeType(cast_src_type) || !IsFixedSizeType(cast_dst_type)) {
2148       return Status::OK();
2149     } else if (producer_is_cast &&
2150                DataTypeSize(cast_dst_type) <= DataTypeSize(cast_src_type)) {
2151       return Status::OK();
2152     } else if (!producer_is_cast &&
2153                DataTypeSize(cast_dst_type) >= DataTypeSize(cast_src_type)) {
2154       return Status::OK();
2155     }
2156 
2157     // Check that nodes were not already optimized.
2158     const string optimized_producer_name = OptimizedNodeName(
2159         ParseNodeScopeAndName(producer->name()), DataTypeString(cast_dst_type));
2160     const string optimized_consumer_name = OptimizedNodeName(
2161         ParseNodeScopeAndName(consumer->name()), DataTypeString(cast_src_type));
2162     const bool is_already_optimized =
2163         ctx().node_map->NodeExists(optimized_consumer_name) ||
2164         ctx().node_map->NodeExists(optimized_producer_name);
2165     if (is_already_optimized) {
2166       return Status::OK();
2167     }
2168 
2169     // Add copies of consumer and producer in reverse order.
2170     NodeDef* input;
2171     TF_RETURN_IF_ERROR(GetInputNode(producer->input(0), &input));
2172     // Create new producer node.
2173     NodeDef* new_producer = AddCopyNode(optimized_consumer_name, consumer);
2174     new_producer->set_input(0, producer->input(0));
2175     ctx().node_map->AddOutput(input->name(), new_producer->name());
2176 
2177     // Create new consumer node.
2178     NodeDef* new_consumer = AddCopyNode(optimized_producer_name, producer);
2179     new_consumer->set_input(0, new_producer->name());
2180 
2181     NodeDef* new_value_preserving =
2182         producer_is_cast ? new_producer : new_consumer;
2183     const DataType new_input_type =
2184         producer_is_cast ? cast_src_type : cast_dst_type;
2185     // Update the input type of the value-preserving node. The input and
2186     // output types of the cast-like nodes remain the same.
2187     TF_RETURN_IF_ERROR(SetInputType(new_input_type, new_value_preserving));
2188     // Make sure there is a kernel registered for the value preserving op
2189     // with the new input type.
2190     TF_RETURN_IF_ERROR(IsKernelRegisteredForNode(*new_value_preserving));
2191     ctx().node_map->AddOutput(new_producer->name(), new_consumer->name());
2192 
2193     AddToOptimizationQueue(new_producer);
2194     *simplified_node_name = new_consumer->name();
2195 
2196     return Status::OK();
2197   }
2198 
2199  private:
2200   // Sets the type of the first input to dtype.
2201   Status SetInputType(DataType dtype, NodeDef* node) {
2202     const OpDef* op_def = nullptr;
2203     TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node->op(), &op_def));
2204     const OpDef::ArgDef& input_arg = op_def->input_arg(0);
2205     const string& type_attr_name = input_arg.type_attr();
2206     if (type_attr_name.empty()) {
2207       if (input_arg.type() == DT_INVALID || input_arg.type() != dtype) {
2208         return errors::InvalidArgument("Could not set input type of ",
2209                                        node->op(), " op to ",
2210                                        DataTypeString(dtype));
2211       } else {
2212         // Op has fixed input type that already matches dtype.
2213         return Status::OK();
2214       }
2215     }
2216     SetDataTypeToAttr(dtype, type_attr_name, node);
2217     return Status::OK();
2218   }
2219   // This optimization can be dangerous on devices other than CPU and
2220   // GPU. The transpose might not be implemented for image.type, or
2221   // might be slower with image.type than with cast_dst_type.
2222   bool NodeIsOnCpuOrGpu(const NodeDef* node) const {
2223     using absl::StrContains;
2224 
2225     string task;
2226     string device;
2227 
2228     return DeviceNameUtils::SplitDeviceName(node->device(), &task, &device) &&
2229            (StrContains(device, DEVICE_CPU) || StrContains(device, DEVICE_GPU));
2230   }
2231 
2232   bool IsFixedSizeType(DataType dtype) {
2233     return dtype != DT_STRING && dtype != DT_VARIANT && dtype != DT_RESOURCE &&
2234            !kQuantizedTypes.Contains(dtype);
2235   }
2236 };
2237 
2238 // Fold a multiply of a scalar into the following convolution. This folding
2239 // can jump across nodes that merely reorders data (such as reshape and
2240 // transpose). For example, we can optimize
2241 //
2242 //
2243 //         Conv2D                             Conv2D
2244 //        /      \                           /      \
2245 //    Transpose  weights*       ->     Transpose    Mul
2246 //       |                                |        /   \
2247 //      Mul                               |    weights  scale
2248 //     /   \                              |
2249 //   input  scale**                     input
2250 //
2251 //  *) weights must be a const
2252 // **) scale must be a const scalar
2253 //
2254 // When `weights` and `scale` are constant, `Mul` in the optimized graph can be
2255 // constant-folded, also weights tend to be smaller than the activations.
2256 //
2257 // TODO(jingyue): Fold scalar multiplies to Conv?DBackpropFilter and
2258 // Conv?DBackpropInput.
2259 class FoldMultiplyIntoConv : public ArithmeticOptimizerStage {
2260  public:
2261   explicit FoldMultiplyIntoConv(const GraphOptimizerContext& ctx,
2262                                 const ArithmeticOptimizerContext& ctx_ext)
2263       : ArithmeticOptimizerStage("FoldMultiplyIntoConv", ctx, ctx_ext) {}
2264   ~FoldMultiplyIntoConv() override = default;
2265 
2266   bool IsSupported(const NodeDef* node) const override {
2267     return IsConv2D(*node) || IsConv3D(*node);
2268   }
2269 
2270   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2271 #define TF_RETURN_IF_TRUE(...) \
2272   if ((__VA_ARGS__)) return Status::OK()
2273 
2274     NodeDef* conv = node;
2275 
2276     NodeDef* weights;
2277     TF_RETURN_IF_ERROR(GetInputNode(conv->input(1), &weights));
2278 
2279     // Fold the multiply to conv only when the weights are constant, so the
2280     // multiply can be constant-folded.
2281     //
2282     // TODO(jingyue): When the weights aren't constant, this should also help
2283     // performance a bit and memory usage a lot, since the weights tend to be
2284     // smaller than the activations.
2285     TF_RETURN_IF_TRUE(!IsConstant(*weights));
2286 
2287     // Verify that this node was not already optimized.
2288     const string scaled_weights_node_name =
2289         OptimizedNodeName(ParseNodeScopeAndName(weights->name()),
2290                           strings::StrCat("scaled", "_", conv->name()));
2291 
2292     TF_RETURN_IF_TRUE(ctx().node_map->NodeExists(scaled_weights_node_name));
2293 
2294     // Find the tail of value preserving chain entering the Conv node.
2295     NodeDef* tail = GetTailOfValuePreservingChain(*conv, *ctx().node_map,
2296                                                   *ctx().nodes_to_preserve);
2297 
2298     NodeDef* source;
2299     TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &source));
2300 
2301     // Check that value preserving chain is the only consumer of the Mul output.
2302     TF_RETURN_IF_TRUE(!IsAnyMul(*source));
2303     TF_RETURN_IF_TRUE(NumNonControlOutputs(*source, *ctx().node_map) != 1);
2304     // And that Mul is not in the preserve set.
2305     TF_RETURN_IF_TRUE(IsInPreserveSet(*source));
2306 
2307     const NodeDef* mul = source;
2308     int input_idx = 0;
2309     int scale_idx = 1;
2310     NodeDef* scale;  // scalar multiplier for the input tensor
2311     NodeDef* input;
2312     TF_RETURN_IF_ERROR(GetInputNode(mul->input(scale_idx), &scale));
2313     TF_RETURN_IF_ERROR(GetInputNode(mul->input(input_idx), &input));
2314     if (!IsConstant(*scale) && IsConstant(*input)) {
2315       VLOG(3) << "Swapped inputs to mul";
2316       std::swap(scale_idx, input_idx);
2317       std::swap(scale, input);
2318     }
2319     TF_RETURN_IF_TRUE(!IsConstant(*scale));
2320 
2321     // Check that one of the inputs to mul is a constant scalar.
2322     const TensorProto& scale_tensor = scale->attr().at("value").tensor();
2323     bool scale_is_a_scalar = scale_tensor.has_tensor_shape() &&
2324                              scale_tensor.tensor_shape().dim_size() == 0;
2325     TF_RETURN_IF_TRUE(!scale_is_a_scalar);
2326 
2327     // Check that 'scale * weight' can be const folded.
2328     TF_RETURN_IF_TRUE(!IsConstant(*scale));
2329     TF_RETURN_IF_ERROR(CheckAttrsExist(*scale, {"dtype"}));
2330     TF_RETURN_IF_ERROR(CheckAttrExists(*weights, "dtype"));
2331     TF_RETURN_IF_TRUE(scale->attr().at("dtype").type() !=
2332                       weights->attr().at("dtype").type());
2333 
2334     // At this point all preconditions are met, and we safely do the rewrite.
2335     VLOG(3) << "Fold multiply into conv: conv=" << conv->name()
2336             << " mul=" << mul->name() << " weights=" << weights->name();
2337 
2338     // Create new node `scaled_weights`.
2339     NodeDef* scaled_weights = AddEmptyNode(scaled_weights_node_name);
2340     scaled_weights->set_op(source->op());
2341     scaled_weights->set_device(weights->device());
2342     (*scaled_weights->mutable_attr())["T"] = weights->attr().at("dtype");
2343     AddToOptimizationQueue(scaled_weights);
2344 
2345     // Link in its inputs.
2346     scaled_weights->add_input(conv->input(1));
2347     ctx().node_map->AddOutput(weights->name(), scaled_weights->name());
2348     scaled_weights->add_input(mul->input(scale_idx));
2349     ctx().node_map->AddOutput(scale->name(), scaled_weights->name());
2350     ForwardControlDependencies(scaled_weights, {source});
2351 
2352     // Update `conv`'s weights to `scaled_weights`.
2353     conv->set_input(1, scaled_weights->name());
2354     ctx().node_map->UpdateInput(conv->name(), weights->name(),
2355                                 scaled_weights->name());
2356     AddToOptimizationQueue(conv);
2357 
2358     // Update `tail` node to bypass `mul` because it's folded to the weights.
2359     tail->set_input(0, mul->input(input_idx));
2360     ctx().node_map->UpdateInput(tail->name(), mul->name(), input->name());
2361     AddToOptimizationQueue(tail);
2362     *simplified_node_name = conv->name();
2363 
2364     return Status::OK();
2365 #undef TF_RETURN_IF_TRUE
2366   }
2367 };
2368 
2369 // Fold Transpose into matrix multiplication.
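// For illustration (only when the transpose swaps the two innermost dims):
//   MatMul(Transpose(x), y) => MatMul(x, y, transpose_a=true)
//   BatchMatMul(ConjugateTranspose(x), y) => BatchMatMul(x, y, adj_x=true)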
2370 class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
2371  public:
2372   explicit FoldTransposeIntoMatMul(const GraphOptimizerContext& ctx,
2373                                    const ArithmeticOptimizerContext& ctx_ext)
2374       : ArithmeticOptimizerStage("FoldTransposeIntoMatMul", ctx, ctx_ext) {}
2375   ~FoldTransposeIntoMatMul() override = default;
2376 
2377   bool IsSupported(const NodeDef* node) const override {
2378     return IsAnyMatMul(*node) && !IsInPreserveSet(*node);
2379   }
2380 
2381   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2382     const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2383     const string optimized_node_name = OptimizedNodeName(matmul);
2384     if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2385 
2386     NodeDef* a;
2387     NodeDef* b;
2388     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &a));
2389     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &b));
2390 
2391     bool is_complex = false;
2392     if (node->op() != "SparseMatMul") {
2393       const DataType type = GetDataTypeFromAttr(*node, "T");
2394       is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2395     }
2396 
2397     const std::set<string> foldable_transpose_ops =
2398         !is_complex
2399             ? std::set<string>{"ConjugateTranspose", "Transpose"}
2400             : (IsAnyBatchMatMul(*node) ? std::set<string>{"ConjugateTranspose"}
2401                                        : std::set<string>{"Transpose"});
2402 
2403     const bool a_is_foldable = foldable_transpose_ops.count(a->op()) > 0 &&
2404                                IsInnerMatrixTransposeNode(*a, ctx().node_map);
2405     const bool b_is_foldable = foldable_transpose_ops.count(b->op()) > 0 &&
2406                                IsInnerMatrixTransposeNode(*b, ctx().node_map);
2407     if (!a_is_foldable && !b_is_foldable) return Status::OK();
2408 
2409     NodeDef* new_op = AddCopyNode(optimized_node_name, node);
2410 
2411     if (a_is_foldable) {
2412       const string attr_a = IsAnyBatchMatMul(*node) ? "adj_x" : "transpose_a";
2413       FlipBooleanAttr(attr_a, new_op);
2414       new_op->set_input(0, a->input(0));
2415       ctx().node_map->UpdateInput(new_op->name(), a->name(), a->input(0));
2416     } else {
2417       ctx().node_map->UpdateOutput(a->name(), node->name(), new_op->name());
2418     }
2419 
2420     if (b_is_foldable) {
2421       const string attr_b = IsAnyBatchMatMul(*node) ? "adj_y" : "transpose_b";
2422       FlipBooleanAttr(attr_b, new_op);
2423       new_op->set_input(1, b->input(0));
2424       ctx().node_map->UpdateInput(new_op->name(), b->name(), b->input(0));
2425     } else {
2426       ctx().node_map->UpdateOutput(b->name(), node->name(), new_op->name());
2427     }
2428 
2429     std::vector<const NodeDef*> deps_to_forward = {node};
2430     if (a_is_foldable) deps_to_forward.push_back(a);
2431     if (b_is_foldable) deps_to_forward.push_back(b);
2432     ForwardControlDependencies(new_op, deps_to_forward);
2433     *simplified_node_name = new_op->name();
2434 
2435     return Status::OK();
2436   }
2437 
2438  private:
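  // Toggles a boolean attribute in place; a missing attribute is treated as
  // false (and hence becomes true).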
2439   void FlipBooleanAttr(const string& attr_name, NodeDef* node) {
2440     const bool old_value =
2441         !node->attr().count(attr_name) ? false : node->attr().at(attr_name).b();
2442     (*node->mutable_attr())[attr_name].set_b(!old_value);
2443   }
2444 
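  // Returns true iff `perm` swaps only the two innermost dimensions, e.g.
  // perm == [0, 1, 3, 2]. This is exactly the transpose that MatMul's
  // transpose/adjoint attributes can absorb.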
2445   template <typename T>
2446   bool IsInnerMatrixTranspose(const std::vector<T>& perm) {
2447     const T n = perm.size();
2448     if (n < 2) {
2449       return false;
2450     }
2451     for (T i = 0; i < n - 2; ++i) {
2452       if (perm[i] != i) {
2453         return false;
2454       }
2455     }
2456     return perm[n - 1] == n - 2 && perm[n - 2] == n - 1;
2457   }
2458 
2459   bool IsInnerMatrixTransposeNode(const NodeDef& transpose_node,
2460                                   const NodeMap* node_map) {
2461     if (transpose_node.op() != "Transpose" &&
2462         transpose_node.op() != "ConjugateTranspose") {
2463       return false;
2464     }
2465     const NodeDef* perm_node = node_map->GetNode(transpose_node.input(1));
2466     std::vector<int> perm32;
2467     if (ValuesFromConstNode(*perm_node, &perm32)) {
2468       return IsInnerMatrixTranspose(perm32);
2469     }
2470     std::vector<int64> perm64;
2471     if (ValuesFromConstNode(*perm_node, &perm64)) {
2472       return IsInnerMatrixTranspose(perm64);
2473     }
2474     return false;
2475   }
2476 };
2477 
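// Fold Conj into an adjacent (Conjugate)Transpose, in either order:
//   Transpose(Conj(x))          => ConjugateTranspose(x)
//   Conj(Transpose(x))          => ConjugateTranspose(x)
//   Conj(ConjugateTranspose(x)) => Transpose(x)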
2478 class FoldConjugateIntoTranspose : public ArithmeticOptimizerStage {
2479  public:
2480   explicit FoldConjugateIntoTranspose(const GraphOptimizerContext& ctx,
2481                                       const ArithmeticOptimizerContext& ctx_ext)
2482       : ArithmeticOptimizerStage("FoldConjugateIntoTranspose", ctx, ctx_ext) {}
2483   ~FoldConjugateIntoTranspose() override = default;
2484 
2485   bool IsSupported(const NodeDef* node) const override {
2486     return IsConj(*node) || IsTranspose(*node) || IsConjugateTranspose(*node);
2487   }
2488 
2489   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2490     const NodeScopeAndName matmul = ParseNodeScopeAndName(node->name());
2491     const string optimized_node_name = OptimizedNodeName(matmul);
2492     if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2493 
2494     NodeDef* input;
2495     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2496 
2497     const NodeDef* transpose_op = node->op() == "Conj" ? input : node;
2498     const NodeDef* conj_op = node->op() == "Conj" ? node : input;
2499 
2500     if ((IsTranspose(*transpose_op) || IsConjugateTranspose(*transpose_op)) &&
2501         IsConj(*conj_op)) {
2502       NodeDef* new_op = AddCopyNode(optimized_node_name, transpose_op);
2503 
2504       // Flip the type of transpose op to absorb the conjugation.
2505       new_op->set_op(transpose_op->op() == "Transpose" ? "ConjugateTranspose"
2506                                                        : "Transpose");
2507       new_op->set_input(0, input->input(0));
2508       ctx().node_map->UpdateInput(new_op->name(), node->name(),
2509                                   input->input(0));
2510       ForwardControlDependencies(new_op, {node, input});
2511       *simplified_node_name = new_op->name();
2512     }
2513 
2514     return Status::OK();
2515   }
2516 };
2517 
2518 // Replace Mul node with identical inputs with a Square.
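// For example, Mul(x, x) => Square(x). For complex types the rewrite is only
// applied to nodes placed on CPU.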
2519 class ReplaceMulWithSquare : public ArithmeticOptimizerStage {
2520  public:
2521   explicit ReplaceMulWithSquare(const GraphOptimizerContext& ctx,
2522                                 const ArithmeticOptimizerContext& ctx_ext)
2523       : ArithmeticOptimizerStage("ReplaceMulWithSquare", ctx, ctx_ext) {}
2524   ~ReplaceMulWithSquare() override = default;
2525 
2526   bool IsSupported(const NodeDef* node) const override {
2527     if (!node || node->input_size() < 2) {
2528       // Invalid node
2529       return false;
2530     }
2531 
2532     return IsAnyMul(*node) && node->input(0) == node->input(1);
2533   }
2534 
2535   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2536     const NodeScopeAndName mul = ParseNodeScopeAndName(node->name());
2537     const string optimized_node_name = OptimizedNodeName(mul);
2538     if (ctx().node_map->NodeExists(optimized_node_name)) return Status::OK();
2539 
2540     const DataType type = GetDataTypeFromAttr(*node, "T");
2541     bool is_complex = (type == DT_COMPLEX64) || (type == DT_COMPLEX128);
2542 
2543     if (!is_complex || NodeIsOnCpu(*node)) {
2544       NodeDef* new_square_node = AddCopyNode(optimized_node_name, node);
2545       new_square_node->set_op("Square");
2546       for (int i = 1; i < new_square_node->input_size(); ++i) {
2547         new_square_node->set_input(i - 1, new_square_node->input(i));
2548       }
2549       new_square_node->mutable_input()->RemoveLast();
2550       for (const string& input : new_square_node->input()) {
2551         ctx().node_map->AddOutput(NodeName(input), new_square_node->name());
2552       }
2553       *simplified_node_name = new_square_node->name();
2554     }
2555 
2556     return Status::OK();
2557   }
2558 };
2559 
2560 // Replace a Mul whose other operand is all ones (pure broadcasting) with Tile. E.g. replace
2561 //
2562 // input(1x22x1x48x1x64) -> Mul (1x22x2x48x2x64) -> output
2563 // Ones (1x22x2x48x2x64) -^
2564 //
2565 // with
2566 //
2567 // input -> Tile(1x22x2x48x2x64) -> output
2568 class ReplaceMulWithBroadcastByTile : public ArithmeticOptimizerStage {
2569  public:
2570   explicit ReplaceMulWithBroadcastByTile(
2571       const GraphOptimizerContext& ctx,
2572       const ArithmeticOptimizerContext& ctx_ext)
2573       : ArithmeticOptimizerStage("ReplaceMulWithBroadcastByTile", ctx,
2574                                  ctx_ext) {}
2575   ~ReplaceMulWithBroadcastByTile() override = default;
2576 
2577   bool IsSupported(const NodeDef* node) const override {
2578     return IsMul(*node) && !IsInPreserveSet(*node);
2579   }
2580 
2581   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2582     NodeDef *input, *ones;
2583     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
2584     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
2585     if (IsInPreserveSet(*node) || IsInPreserveSet(*input) ||
2586         IsInPreserveSet(*ones)) {
2587       return Status::OK();
2588     }
2589 
2590     // TODO(kkiningh): Generalize using IsOnes from constant_folding.cc
2591     if (IsConstant(*input) || !IsOnes(*ones)) return Status::OK();
2592 
2593     // Avoid optimizing the same node twice
2594     const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
2595     const string tile_node_name = OptimizedNodeName(scope_and_name, "Tile");
2596     const string const_node_name = OptimizedNodeName(scope_and_name, "Const");
2597     if (ctx().node_map->NodeExists(tile_node_name) ||
2598         ctx().node_map->NodeExists(const_node_name)) {
2599       return Status::OK();
2600     }
2601 
2602     const std::vector<OpInfo::TensorProperties>& props =
2603         ctx().graph_properties->GetInputProperties(node->name());
2604     if (props.size() != 2) return Status::OK();
2605 
2606     // Ignore ops where the shape doesn't change
2607     const TensorShapeProto& input_shape = props[0].shape();
2608     const TensorShapeProto& ones_shape = props[1].shape();
2609     TensorShapeProto output_shape;
2610     if (!ShapeAfterBroadcast(input_shape, ones_shape, &output_shape)) {
2611       return Status::OK();
2612     }
2613     if (ShapesSymbolicallyEqual(input_shape, output_shape)) {
2614       return Status::OK();
2615     }
2616 
2617     // All inputs must have the same rank as the output
2618     if (input_shape.dim_size() != output_shape.dim_size() ||
2619         ones_shape.dim_size() != output_shape.dim_size())
2620       return Status::OK();
2621 
2622     // At this point all preconditions are met. Can proceed with rewrite.
2623     VLOG(3) << "Simplify multiply with all ones input: node=" << node->name()
2624             << "@" << output_shape << " ones=" << ones->name() << "@"
2625             << ones_shape << " input=" << input->name() << "@" << input_shape;
2626 
2627     // 1. Create constant node with correct tile multiples
2628     Tensor multiples(DT_INT32, TensorShape({output_shape.dim_size()}));
2629     for (int i = 0; i < output_shape.dim_size(); ++i) {
2630       int64_t size = output_shape.dim(i).size() / input_shape.dim(i).size();
2631       if (TF_PREDICT_FALSE(size >= INT_MAX)) {
2632         return Status(error::OUT_OF_RANGE, "int32 overflow");
2633       }
2634       multiples.flat<int32>()(i) = static_cast<int32>(size);
2635     }
2636 
2637     NodeDef* const_node = AddEmptyNode(const_node_name);
2638     TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
2639         const_node->name(), TensorValue(&multiples), const_node));
2640     const_node->set_device(node->device());
2641     ForwardControlDependencies(const_node, {ones});
2642     AddToOptimizationQueue(const_node);
2643 
2644     // 2. Replace multiply node with Tile(Const, input);
2645     const DataType type = GetDataTypeFromAttr(*node, "T");
2646     NodeDef* tile_node = AddEmptyNode(tile_node_name);
2647     tile_node->set_op("Tile");
2648     tile_node->set_device(node->device());
2649     SetDataTypeToAttr(type, "T", tile_node);
2650     SetDataTypeToAttr(DT_INT32, "Tmultiples", tile_node);
2651     tile_node->add_input(input->name());
2652     tile_node->add_input(const_node->name());
2653 
2654     ForwardControlDependencies(tile_node, {node});
2655     *simplified_node_name = tile_node->name();
2656 
2657     return Status::OK();
2658   }
2659 
2660  protected:
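  // Returns true iff `node` is a float Const whose elements are all exactly
  // 1.0f (see the TODO in TrySimplify about reusing the more general IsOnes
  // from constant_folding.cc).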
2661   bool IsOnes(const NodeDef& node) const {
2662     if (!IsReallyConstant(node)) return false;
2663     if (node.attr().at("dtype").type() != DT_FLOAT) return false;
2664 
2665     Tensor tensor;
2666     if (!tensor.FromProto(node.attr().at("value").tensor())) {
2667       return false;
2668     }
2669 
2670     auto values = tensor.flat<float>();
2671     for (int i = 0; i < tensor.NumElements(); ++i) {
2672       if (values(i) != 1.0f) {
2673         return false;
2674       }
2675     }
2676 
2677     return true;
2678   }
2679 };
2680 
2681 // Image upsampling often produces an unnecessary reshape that is difficult to
2682 // eliminate in other stages. This stage reduces the number of dimensions
2683 // involved, allowing the reshape to be removed.
2684 //
2685 // For example, given
2686 //   B,W,H,C -> Reshape(B,W,1,H,1,C) -> Tile(1,1,2,1,2,1) -> Reshape(B,2W,2H,C)
2687 // this pass converts the sequence to
2688 //   B,W,H,C -> Reshape(B,W,H,C) -> Tile(1,1,2,2) -> Reshape(B,2W,2H,C)
2689 //
2690 // The first reshape is now redundant and can be removed in a later pass.
2691 //
2692 // Note: This only optimizes the simple (but extremely common) case of 2D
2693 // upsampling.
2694 //
2695 // TODO(kkiningh): Generalize to more complex upsampling patterns.
2696 class ReduceUpsamplingDims : public ArithmeticOptimizerStage {
2697  public:
2698   explicit ReduceUpsamplingDims(const GraphOptimizerContext& ctx,
2699                                 const ArithmeticOptimizerContext& ctx_ext)
2700       : ArithmeticOptimizerStage("ReduceUpsamplingDims", ctx, ctx_ext) {}
2701   ~ReduceUpsamplingDims() override = default;
2702 
2703   bool IsSupported(const NodeDef* node) const override {
2704     return IsReshape(*node) && !IsInPreserveSet(*node);
2705   }
2706 
2707   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2708     NodeDef* tile;
2709     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &tile));
2710     if (!IsTile(*tile) || IsInPreserveSet(*tile)) {
2711       return Status::OK();
2712     }
2713 
2714     if (NumNonControlOutputs(*tile, *ctx().node_map) != 1) {
2715       // Optimization is only worthwhile when there is a single output from Tile.
2716       // Otherwise, we need to insert additional Reshape ops that can't be easily
2717       // removed.
2718       return Status::OK();
2719     }
2720 
2721     NodeDef* reshape;
2722     TF_RETURN_IF_ERROR(GetInputNode(tile->input(0), &reshape));
2723     if (!IsReshape(*reshape) || IsInPreserveSet(*reshape)) {
2724       return Status::OK();
2725     }
2726 
2727     NodeDef* multiples;
2728     TF_RETURN_IF_ERROR(GetInputNode(tile->input(1), &multiples));
2729 
2730     NodeDef* shape;
2731     TF_RETURN_IF_ERROR(GetInputNode(reshape->input(1), &shape));
2732 
2733     // Avoid optimizing the same nodes twice
2734     const NodeScopeAndName scope_and_name = ParseNodeScopeAndName(node->name());
2735     const string new_reshape_name =
2736         OptimizedNodeName(scope_and_name, "Reshape");
2737     const string new_tile_name = OptimizedNodeName(scope_and_name, "Tile");
2738     const string new_multiples_name =
2739         OptimizedNodeName(scope_and_name, "Multiples");
2740     const string new_shape_name = OptimizedNodeName(scope_and_name, "Shape");
2741     if (ctx().node_map->NodeExists(new_reshape_name) ||
2742         ctx().node_map->NodeExists(new_tile_name) ||
2743         ctx().node_map->NodeExists(new_shape_name) ||
2744         ctx().node_map->NodeExists(new_multiples_name)) {
2745       return Status::OK();
2746     }
2747 
2748     // Compute updated multiples/shape values.
2749     AttrValue new_multiples_attr;
2750     if (!CreateUpdatedMultiplesProto(multiples,
2751                                      new_multiples_attr.mutable_tensor())) {
2752       return Status::OK();
2753     }
2754     AttrValue new_shape_attr;
2755     if (!CreateUpdatedShapeProto(shape, new_shape_attr.mutable_tensor())) {
2756       return Status::OK();
2757     }
2758 
2759     // At this point the graph is validated and can be updated.
2760     // Note: We can assume shape/multiples are DT_INT32 only at this point,
2761     // since they're checked in CreateUpdated*Proto().
2762 
2763     // 1. Create the constant nodes used by the new Reshape/Tile nodes
2764     NodeDef* new_multiples = AddEmptyNode(new_multiples_name);
2765     new_multiples->set_op("Const");
2766     SetDataTypeToAttr(DT_INT32, "dtype", new_multiples);
2767     new_multiples->mutable_attr()->insert({"value", new_multiples_attr});
2768     new_multiples->set_device(multiples->device());
2769 
2770     NodeDef* new_shape = AddEmptyNode(new_shape_name);
2771     new_shape->set_op("Const");
2772     SetDataTypeToAttr(DT_INT32, "dtype", new_shape);
2773     new_shape->mutable_attr()->insert({"value", new_shape_attr});
2774     new_shape->set_device(shape->device());
2775 
2776     // 2. Create the new Reshape/Tile nodes
2777     NodeDef* new_reshape = AddEmptyNode(new_reshape_name);
2778     CopyReshapeWithInput(reshape, new_reshape, /*input=*/reshape->input(0),
2779                          /*shape=*/new_shape->name());
2780     NodeDef* new_tile = AddEmptyNode(new_tile_name);
2781     CopyTileWithInput(tile, new_tile, /*input=*/new_reshape->name(),
2782                       /*multiples=*/new_multiples->name());
2783 
2784     // 3. Update consumer of original Tile node and add control
2785     node->set_input(0, new_tile->name());
2786     ctx().node_map->UpdateInput(node->name(), tile->name(), new_tile->name());
2787 
2788     ForwardControlDependencies(new_tile, {tile});
2789     ForwardControlDependencies(new_multiples, {multiples});
2790     ForwardControlDependencies(new_reshape, {reshape});
2791     ForwardControlDependencies(new_shape, {shape});
2792 
2793     *simplified_node_name = node->name();
2794     return Status::OK();
2795   }
2796 
2797  private:
2798   bool CreateUpdatedMultiplesProto(const NodeDef* node, TensorProto* proto) {
2799     Tensor multiples;
2800     if (!GetTensorFromConstNode(node->name(), &multiples)) {
2801       return false;
2802     }
2803 
2804     // Dimensions should be [X, Y, N, 1, M, 1]
2805     if (multiples.dtype() != DT_INT32 || multiples.NumElements() != 6) {
2806       return false;
2807     }
2808 
2809     const auto& multiples_values = multiples.flat<int32>();
2810     if (multiples_values(3) != 1 || multiples_values(5) != 1) {
2811       return false;
2812     }
2813 
2814     // Convert to [X, Y, N, M]
2815     Tensor new_multiples(DT_INT32, {4});
2816     new_multiples.flat<int32>()(0) = multiples_values(0);
2817     new_multiples.flat<int32>()(1) = multiples_values(1);
2818     new_multiples.flat<int32>()(2) = multiples_values(2);
2819     new_multiples.flat<int32>()(3) = multiples_values(4);
2820 
2821     new_multiples.AsProtoTensorContent(proto);
2822     return true;
2823   }
2824 
2825   bool CreateUpdatedShapeProto(const NodeDef* node, TensorProto* proto) {
2826     Tensor shape;
2827     if (!GetTensorFromConstNode(node->name(), &shape)) {
2828       return false;
2829     }
2830 
2831     // Dimensions should be [B, W, 1, H, 1, C]
2832     if (shape.dtype() != DT_INT32 || shape.NumElements() != 6) {
2833       return false;
2834     }
2835 
2836     const auto& shape_values = shape.flat<int32>();
2837     if (shape_values(2) != 1 || shape_values(4) != 1) {
2838       return false;
2839     }
2840 
2841     // Convert to [B, W, H, C]
2842     Tensor new_shape(DT_INT32, {4});
2843     new_shape.flat<int32>()(0) = shape_values(0);
2844     new_shape.flat<int32>()(1) = shape_values(1);
2845     new_shape.flat<int32>()(2) = shape_values(3);
2846     new_shape.flat<int32>()(3) = shape_values(5);
2847 
2848     new_shape.AsProtoTensorContent(proto);
2849     return true;
2850   }
2851 
2852   void CopyReshapeWithInput(const NodeDef* reshape, NodeDef* new_reshape,
2853                             const string& input, const string& shape) {
2854     new_reshape->set_op("Reshape");
2855     new_reshape->set_device(reshape->device());
2856     SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "T"), "T", new_reshape);
2857     SetDataTypeToAttr(GetDataTypeFromAttr(*reshape, "Tshape"), "Tshape",
2858                       new_reshape);
2859 
2860     new_reshape->add_input(input);
2861     ctx().node_map->AddOutput(NodeName(input), new_reshape->name());
2862     new_reshape->add_input(shape);
2863     ctx().node_map->AddOutput(NodeName(shape), new_reshape->name());
2864 
2865     AddToOptimizationQueue(new_reshape);
2866   }
2867 
2868   void CopyTileWithInput(const NodeDef* tile, NodeDef* new_tile,
2869                          const string& input, const string& multiples) {
2870     new_tile->set_op("Tile");
2871     new_tile->set_device(tile->device());
2872     SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "T"), "T", new_tile);
2873     SetDataTypeToAttr(GetDataTypeFromAttr(*tile, "Tmultiples"), "Tmultiples",
2874                       new_tile);
2875 
2876     new_tile->add_input(input);
2877     ctx().node_map->AddOutput(NodeName(input), new_tile->name());
2878     new_tile->add_input(multiples);
2879     ctx().node_map->AddOutput(NodeName(multiples), new_tile->name());
2880 
2881     AddToOptimizationQueue(new_tile);
2882   }
2883 };
2884 
2885 // Replace a sequence of Pack ops with identical inputs by Tile (plus a Reshape).
2886 // For example, given a Tensor X with shape (I,J,K)
2887 // Let P(x, n) = Pack([x, x], axis=n)
2888 //
2889 // P(P(X, 2), 1)
2890 //   = Tile(Reshape(Tile(Reshape(x,
2891 //              [I,    J, 1, K]), [1,    1, 2, 1]),
2892 //              [I, 1, J, 2, K]), [1, 2, 1, 1, 1]))
2893 //   = Tile(Reshape(x,
2894 //              [I, 1, J, 1, K]), [1, 2, 1, 2, 1])
2895 //   = Reshape(Tile(x, [1, 2, 2]), [I, 2, J, 2, K])
2896 //
2897 // The outermost reshape is often redundant and can be removed in another pass
2898 class ReplacePackWithTileReshape : public ArithmeticOptimizerStage {
2899  public:
2900   explicit ReplacePackWithTileReshape(const GraphOptimizerContext& ctx,
2901                                       const ArithmeticOptimizerContext& ctx_ext)
2902       : ArithmeticOptimizerStage("ReplacePackWithTileReshape", ctx, ctx_ext) {}
2903   ~ReplacePackWithTileReshape() override = default;
2904 
2905   bool IsSupported(const NodeDef* node) const override {
2906     return IsPack(*node) && NumNonControlInputs(*node) > 1 &&
2907            !IsInPreserveSet(*node);
2908   }
2909 
2910   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
2911     // 1. Traverse the chain of Pack ops to get the original input.
2912     NodeDef* input = node;
2913     std::vector<const NodeDef*> chain;
2914     while (IsPack(*input) && NumNonControlInputs(*node) > 1 &&
2915            !IsInPreserveSet(*input)) {
2916       // Only pack operations with all identical inputs are supported
2917       if (!AllRegularInputsEqual(*input)) {
2918         break;
2919       }
2920       chain.push_back(input);
2921       TF_RETURN_IF_ERROR(GetInputNode(input->input(0), &input));
2922     }
2923 
2924     // Must be at least one Pack operation to consider for replacement
2925     if (chain.empty()) {
2926       return Status::OK();
2927     }
2928 
2929     // Avoid optimizing the same node twice
2930     const NodeScopeAndName node_scope_and_name =
2931         ParseNodeScopeAndName(node->name());
2932     const string new_const_name =
2933         OptimizedNodeName(node_scope_and_name, "Multiples");
2934     const string new_tile_name = OptimizedNodeName(node_scope_and_name, "Tile");
2935     const string new_shape_name =
2936         OptimizedNodeName(node_scope_and_name, "Shape");
2937     const string new_reshape_name =
2938         OptimizedNodeName(node_scope_and_name, "Reshape");
2939     if (ctx().node_map->NodeExists(new_const_name) ||
2940         ctx().node_map->NodeExists(new_tile_name) ||
2941         ctx().node_map->NodeExists(new_shape_name) ||
2942         ctx().node_map->NodeExists(new_reshape_name)) {
2943       return Status::OK();
2944     }
2945 
2946     // 2. Calculate the multiples and shape tensor using the chain
2947     const OpInfo::TensorProperties* input_props;
2948     TF_RETURN_IF_ERROR(GetTensorProperties(input->name(), &input_props));
2949     const TensorShapeProto& input_shape = input_props->shape();
2950     if (!PartialTensorShape(input_shape).IsFullyDefined()) {
2951       return Status::OK();
2952     }
2953     Tensor multiples(DT_INT32, TensorShape({input_shape.dim_size()}));
2954     TF_RETURN_IF_ERROR(CalculateMultiplesFromChain(chain, &multiples));
2955 
2956     const OpInfo::TensorProperties* output_props;
2957     TF_RETURN_IF_ERROR(GetTensorProperties(node->name(), &output_props));
2958     const TensorShapeProto& output_shape = output_props->shape();
2959     if (!PartialTensorShape(output_shape).IsFullyDefined()) {
2960       return Status::OK();
2961     }
2962     Tensor output_shape_tensor(DT_INT32,
2963                                TensorShape({output_shape.dim_size()}));
2964     for (int i = 0; i < output_shape.dim_size(); ++i) {
2965       output_shape_tensor.flat<int32>()(i) = output_shape.dim(i).size();
2966     }
2967 
2968     // 3. Create constant node with correct multiples value
2969     NodeDef* new_const_node = AddEmptyNode(new_const_name);
2970     TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
2971         new_const_node->name(), TensorValue(&multiples), new_const_node));
2972     new_const_node->set_device(node->device());
2973     MaybeAddControlInput(input->name(), new_const_node, ctx().optimized_graph,
2974                          ctx().node_map);
2975     AddToOptimizationQueue(new_const_node);
2976 
2977     // 4. Replace the Pack node with Tile(Const(N), input);
2978     DataType dtype = GetDataTypeFromAttr(*node, "T");
2979     NodeDef* new_tile_node = AddEmptyNode(new_tile_name);
2980     new_tile_node->set_op("Tile");
2981     new_tile_node->set_device(node->device());
2982     SetDataTypeToAttr(dtype, "T", new_tile_node);
2983     SetDataTypeToAttr(DT_INT32, "Tmultiples", new_tile_node);
2984     new_tile_node->add_input(input->name());
2985     ctx().node_map->AddOutput(input->name(), new_tile_node->name());
2986     new_tile_node->add_input(new_const_node->name());
2987     ctx().node_map->AddOutput(new_const_node->name(), new_tile_node->name());
2988 
2989     // Tile inherits all control dependencies from the original pack chain
2990     ForwardControlDependencies(new_tile_node, chain);
2991     AddToOptimizationQueue(new_tile_node);
2992 
2993     // 5. Add a new Reshape node to preserve the existing shape
2994     NodeDef* new_shape_node = AddEmptyNode(new_shape_name);
2995     TF_RETURN_IF_ERROR(ConstantFolding::CreateNodeDef(
2996         new_shape_node->name(), TensorValue(&output_shape_tensor),
2997         new_shape_node));
2998     new_shape_node->set_device(node->device());
2999     MaybeAddControlInput(input->name(), new_shape_node, ctx().optimized_graph,
3000                          ctx().node_map);
3001     AddToOptimizationQueue(new_shape_node);
3002 
3003     NodeDef* new_reshape_node = AddEmptyNode(new_reshape_name);
3004     new_reshape_node->set_op("Reshape");
3005     new_reshape_node->set_device(node->device());
3006     SetDataTypeToAttr(dtype, "T", new_reshape_node);
3007     SetDataTypeToAttr(DT_INT32, "Tshape", new_reshape_node);
3008     new_reshape_node->add_input(new_tile_node->name());
3009     ctx().node_map->AddOutput(new_tile_node->name(), new_reshape_node->name());
3010     new_reshape_node->add_input(new_shape_node->name());
3011     ctx().node_map->AddOutput(new_shape_node->name(), new_reshape_node->name());
3012 
3013     *simplified_node_name = new_reshape_node->name();
3014 
3015     return Status::OK();
3016   }
3017 
3018  protected:
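  // Derives the Tile `multiples` implied by a chain of Pack ops (ordered from
  // outermost to innermost), multiplying each op's repeat count N into the
  // dimension its axis expands; see the worked trace below.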
3019   Status CalculateMultiplesFromChain(const std::vector<const NodeDef*>& chain,
3020                                      Tensor* multiples) {
3021     // Keep track of how the multiples correspond to each shape dimension.
3022     // For example, given Stack([x, x], axis=1) with rank(x) = 3, we start with
3023     //    multiples=[1, 1, 1] , dims=[0, 1, 2]
3024     // After processing the stack op
3025     //    multiples=[1, 2, 1] , dims=[0, 1, 1, 2]
3026     std::vector<int32> dims(multiples->NumElements());
3027     std::iota(dims.begin(), dims.end(), 0);
3028 
3029     for (int i = 0; i < multiples->NumElements(); ++i) {
3030       multiples->flat<int32>()(i) = 1;
3031     }
3032 
3033     for (auto it = chain.rbegin(); it != chain.rend(); ++it) {
3034       AttrSlice attrs(**it);
3035       int64_t axis, n;
3036       TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "axis", &axis));
3037       TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "N", &n));
3038 
3039       if (axis >= dims.size()) {
3040         // We don't handle the case where Pack is performed on the last axis,
3041         // e.g. Pack([x, x], axis=3) where rank(x) == 3
3042         return Status(error::OUT_OF_RANGE, "axis value out of range of dims");
3043       }
3044 
3045       int64_t m = multiples->flat<int32>()(dims[axis]) * n;
3046       if (TF_PREDICT_FALSE(m > INT_MAX)) {
3047         return Status(error::OUT_OF_RANGE, "int32 overflow");
3048       }
3049       multiples->flat<int32>()(dims[axis]) = static_cast<int32>(m);
3050 
3051       // Copy index from immediate right of inserted axis
3052       dims.insert(dims.begin() + axis, dims[axis]);
3053     }
3054 
3055     return Status::OK();
3056   }
3057 };
3058 
3059 // Simplify aggregation (e.g. AddN) nodes:
3060 //
3061 // 1. Discard aggregate nodes with a single input and no control dependencies.
3062 //
3063 // 2. Try to rewrite aggregations of N >= 2 identical terms (possibly due to
3064 //    deduping or other rewrites) so we can get rid of the sum entirely.
3065 //
3066 //    The expression (using AddN as an example of an aggregate op):
3067 //      AddN(x, x, x, ... ,x)
3068 //           <-- N terms -->
3069 //    can be rewritten to:
3070 //      Mul(Const(N), x))
3071 //
3072 class SimplifyAggregation : public ArithmeticOptimizerStage {
3073  public:
3074   explicit SimplifyAggregation(const GraphOptimizerContext& ctx,
3075                                const ArithmeticOptimizerContext& ctx_ext)
3076       : ArithmeticOptimizerStage("SimplifyAggregation", ctx, ctx_ext) {}
3077   ~SimplifyAggregation() override = default;
3078 
3079   bool IsSupported(const NodeDef* node) const override {
3080     return IsAggregate(*node) && HasRegularInputs(*node) &&
3081            GetDataTypeFromAttr(*node, "T") !=
3082                DT_VARIANT;  // TODO(b/119787146): Enable for variants.
3083   }
3084 
3085   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3086     // 1. Discard aggregate nodes with a single input and no control deps.
3087     if (node->input_size() == 1) {
3088       *simplified_node_name = node->input(0);
3089       return Status::OK();
3090     }
3091 
3092     // 2. Rewrite aggregations of N >= 2 identical terms.
3093 
3094     // All non-control inputs must be identical.
3095     bool all_equal = true;
3096     int num_inputs = 1;
3097     for (int i = 1; i < node->input_size(); ++i) {
3098       if (IsControlInput(node->input(i))) break;
3099       ++num_inputs;
3100       if (node->input(i) != node->input(0)) {
3101         all_equal = false;
3102         break;
3103       }
3104     }
3105     if (!all_equal) return Status::OK();
3106 
3107     // And the node should not have been optimized earlier.
3108     const NodeScopeAndName node_scope_and_name =
3109         ParseNodeScopeAndName(node->name());
3110     const string optimized_const_name =
3111         OptimizedNodeName(node_scope_and_name, "Const");
3112     const string optimized_mul_name =
3113         OptimizedNodeName(node_scope_and_name, "Mul");
3114 
3115     bool is_already_optimized =
3116         ctx().node_map->NodeExists(optimized_const_name) ||
3117         ctx().node_map->NodeExists(optimized_mul_name);
3118 
3119     if (is_already_optimized) return Status::OK();
3120 
3121     // At this point all preconditions are met, and we safely do the rewrite.
3122     VLOG(3) << "Simplify aggregation with identical inputs: node="
3123             << node->name() << " num_inputs=" << num_inputs;
3124 
3125     // 1. Create constant node with value N.
3126     const auto type = GetDataTypeFromAttr(*node, "T");
3127     Tensor t(type, TensorShape({}));
3128     Status status = SetTensorValue(type, num_inputs, &t);
3129     if (!status.ok()) {
3130       return errors::Internal("Failed to create const node: ",
3131                               status.error_message());
3132     }
3133 
3134     TensorValue value(&t);
3135     NodeDef* new_const_node = AddEmptyNode(optimized_const_name);
3136     status = ConstantFolding::CreateNodeDef(new_const_node->name(), value,
3137                                             new_const_node);
3138     if (!status.ok()) {
3139       return errors::Internal("Failed to create const node: ",
3140                               status.error_message());
3141     }
3142     new_const_node->set_device(node->device());
3143     MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
3144                          ctx().optimized_graph, ctx().node_map);
3145     AddToOptimizationQueue(new_const_node);
3146 
3147     // 2. Replace the aggregate node with Mul(Const(N), x).
3148     NodeDef* new_mul_node = AddEmptyNode(optimized_mul_name);
3149     new_mul_node->set_op("Mul");
3150     new_mul_node->set_device(node->device());
3151     SetDataTypeToAttr(type, "T", new_mul_node);
3152     new_mul_node->add_input(new_const_node->name());
3153     ctx().node_map->AddOutput(new_const_node->name(), new_mul_node->name());
3154     new_mul_node->add_input(node->input(0));
3155     ctx().node_map->AddOutput(node->input(0), new_mul_node->name());
3156 
3157     ForwardControlDependencies(new_mul_node, {node});
3158     *simplified_node_name = new_mul_node->name();
3159 
3160     return Status::OK();
3161   }
3162 };
3163 
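// Strength-reduce Pow when the exponent is a Const with the same value in
// every element:
//   Pow(x, 2)    => Square(x)
//   Pow(x, 3)    => Mul(x, Square(x))  (CPU only)
//   Pow(x, 1)    => Identity(x)        (only if no broadcasting happens)
//   Pow(x, 0.5)  => Sqrt(x)
//   Pow(x, 0)    => Const(1)           (only for fully defined, equal shapes)
//   Pow(x, -0.5) => Rsqrt(x)
//   Pow(x, -1)   => Reciprocal(x)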
3164 class ConvertPowStage : public ArithmeticOptimizerStage {
3165  public:
3166   explicit ConvertPowStage(const GraphOptimizerContext& ctx,
3167                            const ArithmeticOptimizerContext& ctx_ext)
3168       : ArithmeticOptimizerStage("ConvertPow", ctx, ctx_ext) {}
3169 
3170   bool IsSupported(const NodeDef* node) const override {
3171     return IsPow(*node) &&
3172            ctx().graph_properties->HasOutputProperties(node->name()) &&
3173            ctx().graph_properties->HasInputProperties(node->name());
3174   }
3175 
3176   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3177     Tensor pow;
3178     if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
3179     complex128 prev, curr;
3180     for (int i = 0; i < pow.NumElements(); ++i) {
3181       if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
3182         // input data type is not supported by Pow. Skip.
3183         return Status::OK();
3184       }
3185       if (i != 0 && curr != prev) {
3186         // pow has different values on different elements. Skip.
3187         return Status::OK();
3188       }
3189       prev = curr;
3190     }
3191     NodeDef *x, *y;
3192     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
3193     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
3194     const auto& value_props =
3195         ctx().graph_properties->GetInputProperties(node->name())[0];
3196     const TensorShapeProto& output_shape =
3197         ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
3198     if (curr == complex128(2, 0)) {
3199       node->set_op("Square");
3200       node->set_input(1, AsControlDependency(y->name()));
3201       AddToOptimizationQueue(node);
3202       AddToOptimizationQueue(y);
3203     } else if (curr == complex128(3, 0)) {
3204       // TODO(courbet): Use 'Cube' when it's added to TF ops.
3205       if (NodeIsOnCpu(*node)) {
3206         // We create an inner square node: inner_square = square(x)
3207         const NodeScopeAndName scope_and_name =
3208             ParseNodeScopeAndName(node->name());
3209         const string inner_square_name =
3210             OptimizedNodeName(scope_and_name, "_inner");
3211         NodeDef* inner_square_node = ctx().node_map->GetNode(inner_square_name);
3212         if (inner_square_node == nullptr) {
3213           inner_square_node = AddCopyNode(inner_square_name, node);
3214           inner_square_node->set_op("Square");
3215           inner_square_node->mutable_input()->RemoveLast();
3216         }
3217         ctx().node_map->AddOutput(x->name(), inner_square_node->name());
3218         // We modify `node`: node = mul(x, inner_square);
3219         node->set_op("Mul");
3220         node->set_input(1, inner_square_node->name());
3221         node->add_input(AsControlDependency(y->name()));
3222 
3223         AddToOptimizationQueue(node);
3224         AddToOptimizationQueue(inner_square_node);
3225         AddToOptimizationQueue(y);
3226       }
3227     } else if (curr == complex128(1, 0) &&
3228                ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
3229       // Pow could be used to broadcast, so make sure the shapes of the two
3230       // arguments are identical before replacing Pow with Identity.
3231       node->set_op("Identity");
3232       node->set_input(1, AsControlDependency(y->name()));
3233       AddToOptimizationQueue(node);
3234       AddToOptimizationQueue(y);
3235     } else if (curr == complex128(0.5, 0)) {
3236       node->set_op("Sqrt");
3237       node->set_input(1, AsControlDependency(y->name()));
3238       AddToOptimizationQueue(node);
3239       AddToOptimizationQueue(y);
3240     } else if (curr == complex128(0, 0) &&
3241                ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
3242                PartialTensorShape(output_shape).IsFullyDefined()) {
3243       const auto dtype = node->attr().at("T").type();
3244       Tensor ones(dtype, output_shape);
3245       for (int i = 0; i < ones.NumElements(); ++i) {
3246         TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
3247       }
3248       node->set_op("Const");
3249       (*node->mutable_attr())["dtype"].set_type(dtype);
3250       node->mutable_attr()->erase("T");
3251       ones.AsProtoTensorContent(
3252           (*node->mutable_attr())["value"].mutable_tensor());
3253       node->set_input(0, AsControlDependency(x->name()));
3254       node->set_input(1, AsControlDependency(y->name()));
3255       AddToOptimizationQueue(node);
3256       AddToOptimizationQueue(x);
3257       AddToOptimizationQueue(y);
3258     } else if (curr == complex128(-0.5, 0)) {
3259       node->set_op("Rsqrt");
3260       node->set_input(1, AsControlDependency(y->name()));
3261       AddToOptimizationQueue(node);
3262       AddToOptimizationQueue(y);
3263     } else if (curr == complex128(-1, 0)) {
3264       node->set_op("Reciprocal");
3265       node->set_input(1, AsControlDependency(y->name()));
3266       AddToOptimizationQueue(node);
3267       AddToOptimizationQueue(y);
3268     }
3269     return Status::OK();
3270   }
3271 
3272  private:
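  // Writes the value one into element `i` of `t`; only a fixed set of numeric
  // dtypes is supported.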
3273   Status SetElementToOne(int i, Tensor* t) {
3274     switch (t->dtype()) {
3275       case DT_INT32:
3276         t->flat<int32>()(i) = 1;
3277         return Status::OK();
3278       case DT_INT64:
3279         t->flat<int64>()(i) = 1L;
3280         return Status::OK();
3281       case DT_FLOAT:
3282         t->flat<float>()(i) = 1.0f;
3283         return Status::OK();
3284       case DT_DOUBLE:
3285         t->flat<double>()(i) = 1.0;
3286         return Status::OK();
3287       case DT_COMPLEX64:
3288         t->flat<complex64>()(i) = complex64(1);
3289         return Status::OK();
3290       case DT_COMPLEX128:
3291         t->flat<complex128>()(i) = complex128(1);
3292         return Status::OK();
3293       default:
3294         return errors::InvalidArgument("Invalid data type: ", t->dtype());
3295     }
3296   }
3297 };
3298 
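// Rewrite Log(Add(x, 1)) as the more numerically stable Log1p(x). The
// constant may appear as either input of the Add; it must be exactly 1 in
// every element and must not change x's shape via broadcasting.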
3299 class ConvertLog1pStage : public ArithmeticOptimizerStage {
3300  public:
3301   explicit ConvertLog1pStage(const GraphOptimizerContext& ctx,
3302                              const ArithmeticOptimizerContext& ctx_ext)
3303       : ArithmeticOptimizerStage("ConvertLog1p", ctx, ctx_ext) {}
3304   ~ConvertLog1pStage() override = default;
3305 
3306   bool IsSupported(const NodeDef* node) const override { return IsLog(*node); }
3307 
3308   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3309     NodeDef* input;
3310     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input));
3311     if (!IsAdd(*input)) {
3312       return Status::OK();
3313     }
3314 
3315     if (ctx().graph_properties->GetInputProperties(input->name()).size() < 2) {
3316       return Status::OK();
3317     }
3318 
3319     bool modified = false;
3320     TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 0, 1, &modified));
3321     if (!modified) {
3322       TF_RETURN_IF_ERROR(TrySimplifyInternal(node, input, 1, 0, &modified));
3323     }
3324     if (modified) {
3325       *simplified_node_name = node->name();
3326     }
3327     return Status::OK();
3328   }
3329 
3330  private:
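  // Attempts the rewrite with input `i` of `add_node` as the variable operand
  // and input `j` as the candidate all-ones constant.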
3331   Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
3332                              bool* modified) {
3333     const auto& t =
3334         ctx().graph_properties->GetInputProperties(add_node->name())[i];
3335     const auto& c =
3336         ctx().graph_properties->GetInputProperties(add_node->name())[j];
3337     for (int k = 0; k < c.shape().dim_size(); ++k) {
3338       // Skip if c shape is not fully determined.
3339       if (c.shape().dim(k).size() < 0) {
3340         return Status::OK();
3341       }
3342     }
3343     TensorShapeProto broadcast_shape;
3344     if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
3345       return Status::OK();
3346     }
3347     if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
3348       // skip if the non-constant tensor doesn't have the same shape after
3349       // broadcast.
3350       return Status::OK();
3351     }
3352     Tensor constant;
3353     if (GetTensorFromConstNode(add_node->input(j), &constant)) {
3354       complex128 element;
3355       // TODO(rmlarsen): Refactor the more general IsOnes from
3356       // constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
3357       // or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1),
3358       // and Neg(-1) to 1.
3359       for (int k = 0; k < constant.NumElements(); ++k) {
3360         if (!GetElementUnexhaustive(constant, k,
3361                                     {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
3362                                      DT_COMPLEX64, DT_COMPLEX128},
3363                                     &element)) {
3364           // input data type is not supported by log1p. Skip.
3365           return Status::OK();
3366         }
3367         if (element != complex128(1)) {
3368           // current element is not 1. Skip.
3369           return Status::OK();
3370         }
3371       }
3372       NodeDef *x, *y;
3373       TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
3374       TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
3375       node->set_op("Log1p");
3376       node->set_input(0, add_node->input(i));
3377       node->add_input(AsControlDependency(y->name()));
3378       ForwardControlDependencies(node, {add_node});
3379 
3380       AddToOptimizationQueue(node);
3381       AddToOptimizationQueue(add_node);
3382       AddToOptimizationQueue(x);
3383       AddToOptimizationQueue(y);
3384       *modified = true;
3385     }
3386     return Status::OK();
3387   }
3388 };
3389 
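// Rewrite Sub(Exp(x), 1) as the more numerically stable Expm1(x), provided
// the subtrahend is a Const equal to 1 in every element and broadcasting
// does not change the shape of Exp(x).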
3390 class ConvertExpm1Stage : public ArithmeticOptimizerStage {
3391  public:
3392   explicit ConvertExpm1Stage(const GraphOptimizerContext& ctx,
3393                              const ArithmeticOptimizerContext& ctx_ext)
3394       : ArithmeticOptimizerStage("ConvertExpm1", ctx, ctx_ext) {}
3395   ~ConvertExpm1Stage() override = default;
3396 
3397   bool IsSupported(const NodeDef* node) const override {
3398     if (!IsSub(*node)) return false;
3399 
3400     NodeDef* input;
3401     if (!GetInputNode(node->input(0), &input).ok()) return false;
3402 
3403     return IsExp(*input);
3404   }
3405 
3406   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
3407     if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
3408       return Status::OK();
3409     }
3410     const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
3411     const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
3412     TensorShapeProto broadcast_shape;
3413     if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
3414       return Status::OK();
3415     }
3416     if (!ShapesSymbolicallyEqual(t.shape(), broadcast_shape)) {
3417       // skip if the non-constant tensor doesn't have the same shape after
3418       // broadcast.
3419       return Status::OK();
3420     }
3421     Tensor constant;
3422     if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
3423     // TODO(rmlarsen): Use the more general IsOnes helper here.
3424     complex128 element;
3425     for (int k = 0; k < constant.NumElements(); ++k) {
3426       if (!GetElementUnexhaustive(constant, k,
3427                                   {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
3428                                    DT_COMPLEX64, DT_COMPLEX128},
3429                                   &element)) {
3430         // input data type is not supported by expm1. Skip.
3431         return Status::OK();
3432       }
3433       if (element != complex128(1)) {
3434         // current element is not 1. Skip.
3435         return Status::OK();
3436       }
3437     }
3438     NodeDef* exp;
3439     TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
3440     NodeDef *exp_input, *ones;
3441     TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
3442     TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
3443     node->set_op("Expm1");
3444     node->set_input(0, exp->input(0));
3445     node->set_input(1, AsControlDependency(ones->name()));
3446     ForwardControlDependencies(node, {exp});
3447 
3448     AddToOptimizationQueue(node);
3449     AddToOptimizationQueue(exp);
3450     AddToOptimizationQueue(exp_input);
3451     AddToOptimizationQueue(ones);
3452     *simplified_node_name = node->name();
3453     return Status::OK();
3454   }
3455 };
3456 
3457 // Performs conversions like:
3458 // Max(Sqrt(x)) => Sqrt(Max(x))
3459 // Checks for a max/min reduction over element-wise monotonic functions, such
3460 // as Sqrt, Sigmoid, Tanh, etc.
3461 class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
3462  public:
3463   explicit OptimizeMaxOrMinOfMonotonicStage(
3464       const GraphOptimizerContext& ctx,
3465       const ArithmeticOptimizerContext& ctx_ext)
3466       : ArithmeticOptimizerStage("OptimizeMaxOrMinOfMonotonicStage", ctx,
3467                                  ctx_ext) {}
3468   ~OptimizeMaxOrMinOfMonotonicStage() override = default;
3469 
3470   bool IsSupported(const NodeDef* node) const override {
3471     return IsAnyMax(*node) || IsAnyMin(*node) || IsAnyMaxPool(*node) ||
3472            IsArgMax(*node) || IsArgMin(*node);
3473   }
3474 
3475   Status TrySimplify(NodeDef* reduction_node,
3476                      string* simplified_node_name) override {
3477     if (IsInPreserveSet(*reduction_node)) {
3478       return Status::OK();
3479     }
3480 
3481     NodeDef* inner_function;
3482     TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
3483 
3484     NodeDef* inner_function_input = nullptr;
3485     if (inner_function->input_size() > 0) {
3486       TF_RETURN_IF_ERROR(
3487           GetInputNode(inner_function->input(0), &inner_function_input));
3488     }
3489 
3490     // Optimize only if:
3491     // 0. inner_function is not in the preserve set,
3492     // 1. inner_function's Op is element-wise monotonic
3493     // 2. inner_function's output is not being consumed elsewhere.
3494     // 3. inner_function is monotonically increasing if reduction_node is a
3495     //    pooling operation, since we don't have MinPool operations.
3496     // 4. inner_function is not a Relu node with an input from FusedBatchNorm
3497     //    or BiasAdd. This pattern will be fused later by remapper.
3498     auto can_be_fused_by_remapper = [](const NodeDef& consumer,
3499                                        const NodeDef& producer) -> bool {
3500       if (IsRelu(consumer) || IsRelu6(consumer)) {
3501         if (IsFusedBatchNorm(producer) || IsBiasAdd(producer)) {
3502           return true;
3503         }
3504       }
3505       return false;
3506     };
3507     bool is_non_decreasing = false;
3508     if (!IsInPreserveSet(*inner_function) &&
3509         IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
3510         ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
3511         (is_non_decreasing || !IsAnyMaxPool(*reduction_node)) &&
3512         !can_be_fused_by_remapper(*inner_function, *inner_function_input)) {
3513       // Swap the first inputs of the inner function Op & the reduction Op.
3514       NodeDef* inner_input;
3515       TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
3516       reduction_node->set_input(0, inner_input->name());
3517       ctx().node_map->UpdateInput(reduction_node->name(),
3518                                   inner_function->name(), inner_input->name());
3519       inner_function->set_input(0, reduction_node->name());
3520       TF_RETURN_IF_ERROR(
3521           UpdateConsumers(reduction_node, inner_function->name()));
3522       ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
3523                                   reduction_node->name());
3524       if (!is_non_decreasing) {
3525         // Flip Min<->Max if the function is non-increasing, e.g.
3526         // Max(Neg(x)) = Neg(Min(x)).
3527         const string opposite = FlipMinMax(*reduction_node);
3528         reduction_node->set_op(opposite);
3529       }
3530 
3531       if (IsArgMax(*reduction_node) || IsArgMin(*reduction_node)) {
3532         // ArgMax(Sqrt(x)) = ArgMax(x)
3533         inner_function->set_op("Identity");
3534       }
3535 
3536       AddToOptimizationQueue(reduction_node);
3537       AddToOptimizationQueue(inner_function);
3538       AddToOptimizationQueue(inner_input);
3539     }
3540     return Status::OK();
3541   }
3542 
3543  private:
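  // Returns the op name with "Max" replaced by "Min" or vice versa, e.g.
  // "Max" -> "Min" and "ArgMin" -> "ArgMax".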
3544   string FlipMinMax(const NodeDef& node) {
3545     const string& op = node.op();
3546     if (IsAnyMax(node) || IsArgMax(node)) {
3547       return str_util::StringReplace(op, "Max", "Min", false);
3548     } else {
3549       return str_util::StringReplace(op, "Min", "Max", false);
3550     }
3551   }
3552 };
3553 
3554 // Replace a chain of type&shape preserving unary ops with a
3555 // '_UnaryOpsComposition' node.
3556 // TODO(ezhulenev): This should be part of the remapper optimizer because it
3557 // doesn't have much to do with arithmetic (together with FoldMultiplyIntoConv stage?).
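// For example, assuming every intermediate result has a single consumer and
// the same dtype, the chain
//   y = Relu(Tanh(Sqrt(x)))
// is replaced with
//   y = _UnaryOpsComposition(x, op_names=["Sqrt", "Tanh", "Relu"])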
3558 class UnaryOpsComposition : public ArithmeticOptimizerStage {
3559  public:
3560   explicit UnaryOpsComposition(const GraphOptimizerContext& ctx,
3561                                const ArithmeticOptimizerContext& ctx_ext)
3562       : ArithmeticOptimizerStage("UnaryOpsComposition", ctx, ctx_ext) {
3563     // WARN: This should be consistent with unary_ops_composition.cc.
3564     // clang-format off
3565     supported_ops_ = {// Ops defined via Eigen scalar ops.
3566                       {"Abs",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3567                       {"Acos",       {DT_FLOAT,          DT_DOUBLE}},
3568                       {"Acosh",      {DT_FLOAT,          DT_DOUBLE}},
3569                       {"Asin",       {DT_FLOAT,          DT_DOUBLE}},
3570                       {"Asinh",      {DT_FLOAT,          DT_DOUBLE}},
3571                       {"Atan",       {DT_FLOAT,          DT_DOUBLE}},
3572                       {"Atanh",      {DT_FLOAT,          DT_DOUBLE}},
3573                       {"Ceil",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3574                       {"Cos",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3575                       {"Cosh",       {DT_FLOAT,          DT_DOUBLE}},
3576                       {"Expm1",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3577                       {"Exp",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3578                       {"Floor",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3579                       {"Inv",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3580                       {"Log",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3581                       {"Log1p",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3582                       {"Neg",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3583                       {"Reciprocal", {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3584                       {"Rint",       {DT_FLOAT,          DT_DOUBLE}},
3585                       {"Round",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3586                       {"Rsqrt",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3587                       {"Sigmoid",    {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3588                       {"Sin",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3589                       {"Sinh",       {DT_FLOAT,          DT_DOUBLE}},
3590                       {"Sqrt",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3591                       {"Square",     {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3592                       {"Tan",        {DT_FLOAT,          DT_DOUBLE}},
3593                       {"Tanh",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3594                       // Additional ops that are not part of Eigen.
3595                       {"Elu",        {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3596                       {"Relu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3597                       {"Relu6",      {DT_FLOAT, DT_HALF, DT_DOUBLE}},
3598                       {"Selu",       {DT_FLOAT, DT_HALF, DT_DOUBLE}}};
3599     // clang-format on
3600   }
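
  // Illustrative example (not from the original source): a chain
  //   y = Sqrt(Tanh(Relu(x)))
  // in which every node runs on CPU with dtype DT_FLOAT and each intermediate
  // node has a single data consumer is collapsed into one node
  //   _UnaryOpsComposition(x, op_names = ["Relu", "Tanh", "Sqrt"])
  // with op_names listed in computation order.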
  ~UnaryOpsComposition() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return CanOptimize(*node) &&
           // Check that this node was not already a root of a fused chain. If
           // graph optimization runs twice without pruning in between,
           // fused_nodes_ will not have this information.
           !ctx().node_map->NodeExists(OptimizedNodeName(*node));
  }

  Status TrySimplify(NodeDef* root, string* simplified_node_name) override {
    TF_RETURN_IF_ERROR(CheckAttrExists(*root, "T"));
    DataType dtype = root->attr().at("T").type();

    // Keep a trace of all supported input nodes that can be fused together.
    std::vector<string> op_nodes = {root->name()};
    std::vector<string> op_names = {root->op()};

    // Check if we should follow input(0) while building an op composition.
    const auto predicate_fn = [&](const NodeDef& input) {
      if (input.name() == root->name()) return true;

      bool follow_input_node =
          dtype == GetDataTypeFromAttr(input, "T") &&
          NumNonControlDataOutputs(input, *ctx().node_map) == 1 &&
          CanOptimize(input);

      if (follow_input_node) {
        op_nodes.push_back(input.name());
        op_names.push_back(input.op());
      }

      return follow_input_node;
    };

    NodeDef* last_op = GetTailOfChain(
        *root, *ctx().node_map, /*follow_control_input*/ false, predicate_fn);

    // We were not able to find a chain that can be replaced.
    if (op_names.size() == 1) return Status::OK();

    // Do not add fused nodes to any other chain.
    std::for_each(op_nodes.begin(), op_nodes.end(),
                  [this](const string& name) { AddToFusedNodes(name); });

    // Reverse the trace to get correct composition computation order.
    std::reverse(op_names.begin(), op_names.end());

    VLOG(2) << "Fuse unary ops: root=" << root->name() << " op_names=["
            << absl::StrJoin(op_names, ", ") << "]";

    NodeDef* composition_node = ctx().optimized_graph->add_node();
    composition_node->set_name(OptimizedNodeName(*root));
    composition_node->set_op("_UnaryOpsComposition");
    composition_node->add_input(last_op->input(0));
    composition_node->set_device(root->device());

    auto attr = composition_node->mutable_attr();
    SetAttrValue(dtype, &(*attr)["T"]);
    SetAttrValue(op_names, &(*attr)["op_names"]);

    ctx().node_map->AddNode(composition_node->name(), composition_node);
    ctx().node_map->AddOutput(NodeName(last_op->input(0)),
                              composition_node->name());

    *simplified_node_name = composition_node->name();

    return Status::OK();
  }

 private:
  bool CanOptimize(const NodeDef& node) const {
    DataType dtype = GetDataTypeFromAttr(node, "T");
    if (!IsSupported(node.op(), dtype)) {
      return false;
    }
    if (IsInPreserveSet(node)) {
      return false;
    }
    if (!NodeIsOnCpu(node)) {
      return false;
    }
    if (NodeIsAlreadyFused(node)) {
      return false;
    }
    return !(IsDrivenByControlDependency(node) ||
             DrivesControlDependency(node));
  }

  bool NodeIsAlreadyFused(const NodeDef& node) const {
    return fused_nodes_.count(node.name()) > 0;
  }

  string OptimizedNodeName(const NodeDef& node) const {
    return strings::StrCat(node.name(), "/unary_ops_composition");
  }

  void AddToFusedNodes(const string& name) { fused_nodes_.insert(name); }

  // Check if an op is supported by _UnaryOpsComposition for the given type.
  bool IsSupported(const string& op_name, DataType dtype) const {
    const auto it = supported_ops_.find(op_name);
    return it != supported_ops_.end() && it->second.count(dtype) > 0;
  }

  std::unordered_map<string, std::set<DataType>> supported_ops_;
  std::unordered_set<string> fused_nodes_;
};

// Replace operations of the form:
//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i,...]
// with
//    a_i
// when the strided slice index `i` is applied in the k'th axis.
//
// Similarly, replace operations of the form:
//    x = stack((a_0, a_1, ..., a_{n-1}), axis=k)[:,...,i:i+1,...]
// with
//    expand_dims(a_i, axis=k)
// where the slice operator can be StridedSlice or Slice.
//
// TODO(ebrevdo): Extend to also replace operations of the form
//    concat((a_0, a_1, ..., ), axis=k)[:, ..., s_i:s_{i+1}, ...]
// with
//    a_i,
// when
//    s_i = cumsum(shape(a)[k] for a in (a_0, ...,))[i]
// and slicing is in the k'th axis.
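//
// Concrete example (illustrative, not from the original source):
//    x = tf.stack([a, b, c], axis=0)
//    y = x[1]  # StridedSlice with shrink_axis_mask == 1
// is rewritten to simply
//    y = b
// so the stacked intermediate never has to be materialized just to read one
// of its components back out.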
class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
 public:
  explicit RemoveStackSliceSameAxis(const GraphOptimizerContext& ctx,
                                    const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveStackStridedSliceSameAxis", ctx,
                                 ctx_ext) {}
  ~RemoveStackSliceSameAxis() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return (IsStridedSlice(*node) || IsSlice(*node)) && !IsInPreserveSet(*node);
  }

  Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
    // *node is a StridedSlice or Slice NodeDef.
    NodeDef* pack;

    // Get the input and see if it's a Pack op.
    TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &pack));
    if (!IsPack(*pack)) return Status::OK();

    bool return_early;
    PartialTensorShape pack_output_shape;
    int pack_axis;
    TF_RETURN_IF_ERROR(
        CheckInputs(node, pack, &pack_output_shape, &pack_axis, &return_early));
    if (return_early) return Status::OK();

    int64_t slice_start_value;
    bool found;
    bool must_expand_dims;
    TF_RETURN_IF_ERROR(GetSliceAxis(node, pack, pack_output_shape, pack_axis,
                                    &slice_start_value, &found,
                                    &must_expand_dims));
    if (!found) return Status::OK();

    return RewriteGraph(node, pack, slice_start_value, pack_axis,
                        must_expand_dims, simplified_node_name);
  }

 protected:
  Status CheckInputs(const NodeDef* node, const NodeDef* pack,
                     PartialTensorShape* pack_output_shape, int* pack_axis,
                     bool* return_early) {
    *return_early = true;
    TF_RETURN_IF_ERROR(CheckAttrExists(*pack, "axis"));

    *pack_axis = pack->attr().at("axis").i();
    auto slice_properties =
        ctx().graph_properties->GetInputProperties(node->name());
    if (slice_properties.empty() ||
        slice_properties[0].shape().unknown_rank()) {
      return Status::OK();
    }
    *pack_output_shape = slice_properties[0].shape();
    const int pack_output_rank = pack_output_shape->dims();
    if (*pack_axis < 0) {
      *pack_axis += pack_output_rank;
    }
    if (*pack_axis < 0 || *pack_axis >= pack_output_rank) {
      return errors::InvalidArgument(
          "Pack node (", pack->name(),
          ") axis attribute is out of bounds: ", pack->attr().at("axis").i());
    }
    *return_early = false;
    return Status::OK();
  }

  Status GetSliceAxis(const NodeDef* node, const NodeDef* pack,
                      const PartialTensorShape& pack_output_shape,
                      int pack_axis, int64_t* slice_start_value, bool* found,
                      bool* must_expand_dims) {
    *found = false;
    if (IsSlice(*node)) {
      *must_expand_dims = true;
      return GetSimpleSliceAxis(node, pack, pack_output_shape, pack_axis,
                                slice_start_value, found);
    } else {
      return GetStridedSliceAxis(node, pack, pack_output_shape, pack_axis,
                                 slice_start_value, found, must_expand_dims);
    }
  }

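  // For a plain Slice, the rewrite applies only when the slice takes exactly
  // one element at pack_axis (size == 1 there) and spans the full range of
  // every other dimension (begin == 0 and size == -1 or the whole dim).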
  Status GetSimpleSliceAxis(const NodeDef* node, const NodeDef* pack,
                            const PartialTensorShape& pack_output_shape,
                            int pack_axis, int64_t* slice_start_value,
                            bool* found) {
    NodeDef* slice_begin;
    NodeDef* slice_size;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_size));
    for (const auto* n : {slice_begin, slice_size}) {
      if (!IsReallyConstant(*n)) return Status::OK();
    }

    Tensor slice_begin_t;
    Tensor slice_size_t;
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_size, "value"));
    if (!slice_size_t.FromProto(slice_size->attr().at("value").tensor())) {
      return Status::OK();
    }

    auto copy_tensor_values_to_vector =
        [node](const Tensor& t, gtl::InlinedVector<int64_t, 4>* vec) {
          if (t.dtype() == DT_INT32) {
            auto t_flat = t.flat<int32>();
            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
          } else if (t.dtype() == DT_INT64) {
            auto t_flat = t.flat<int64_t>();
            vec->assign(&t_flat(0), &t_flat(t.NumElements()));
          } else {
            return errors::InvalidArgument("Node ", node->name(),
                                           " has invalid type for Index attr: ",
                                           DataTypeString(t.dtype()));
          }
          return Status::OK();
        };

    gtl::InlinedVector<int64_t, 4> slice_begin_vec;
    gtl::InlinedVector<int64_t, 4> slice_size_vec;
    TF_RETURN_IF_ERROR(
        copy_tensor_values_to_vector(slice_begin_t, &slice_begin_vec));
    TF_RETURN_IF_ERROR(
        copy_tensor_values_to_vector(slice_size_t, &slice_size_vec));

    if (slice_begin_vec.size() != slice_size_vec.size()) {
      return errors::InvalidArgument("Node ", node->name(),
                                     " has mismatched lengths for begin (",
                                     slice_begin_vec.size(), ") and size (",
                                     slice_size_vec.size(), ") vectors.");
    }
    int slice_begin_vec_size = slice_begin_vec.size();
    if (!pack_output_shape.unknown_rank() &&
        slice_begin_vec_size != pack_output_shape.dims()) {
      return Status::OK();
    }
    if (pack_axis >= slice_begin_vec_size) {
      return errors::InvalidArgument(
          "Input to node ", node->name(), " had pack_axis ", pack_axis,
          " but rank was ", slice_begin_vec_size, ".");
    }

    *slice_start_value = slice_begin_vec[pack_axis];
    if (slice_size_vec[pack_axis] != 1) {
      // Not slicing a single value out.
      return Status::OK();
    }

    for (int i = 0; i < slice_begin_vec_size; ++i) {
      if (i != pack_axis) {
        if (slice_begin_vec[i] != 0 ||
            !(slice_size_vec[i] == -1 ||
              slice_size_vec[i] == pack_output_shape.dim_size(i))) {
          // Not slicing on the same axis as the Pack op.
          return Status::OK();
        }
      }
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", pack_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    *found = true;  // slice_start_value is valid.
    return Status::OK();
  }

  Status GetStridedSliceAxis(const NodeDef* node, const NodeDef* pack,
                             const PartialTensorShape& pack_output_shape,
                             int pack_axis, int64_t* slice_start_value,
                             bool* found, bool* must_expand_dims) {
    TF_RETURN_IF_ERROR(
        CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask",
                                "new_axis_mask", "shrink_axis_mask"}));

    const int begin_mask = node->attr().at("begin_mask").i();
    const int end_mask = node->attr().at("end_mask").i();
    const int ellipsis_mask = node->attr().at("ellipsis_mask").i();
    const int new_axis_mask = node->attr().at("new_axis_mask").i();
    const int shrink_axis_mask = node->attr().at("shrink_axis_mask").i();

    // Check that the StridedSlice is one of these at pack_axis:
    //   [..., i, ...]
    //   [..., i:i+1, ...]
    //   [..., :1, ...]
    //   [..., -1:, ...]
    //   [..., s_{pack_axis}-1:, ...]
    NodeDef* slice_begin;
    NodeDef* slice_end;
    NodeDef* slice_strides;
    TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &slice_begin));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(2), &slice_end));
    TF_RETURN_IF_ERROR(GetInputNode(node->input(3), &slice_strides));

    for (const auto* n : {slice_begin, slice_end, slice_strides}) {
      if (!IsReallyConstant(*n)) return Status::OK();
    }

    Tensor slice_begin_t;
    Tensor slice_end_t;
    Tensor slice_strides_t;

    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_begin, "value"));
    if (!slice_begin_t.FromProto(slice_begin->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_end, "value"));
    if (!slice_end_t.FromProto(slice_end->attr().at("value").tensor())) {
      return Status::OK();
    }
    TF_RETURN_IF_ERROR(CheckAttrExists(*slice_strides, "value"));
    if (!slice_strides_t.FromProto(
            slice_strides->attr().at("value").tensor())) {
      return Status::OK();
    }
    TensorShape processing_shape;
    TensorShape final_shape;
    bool is_identity;
    bool is_simple_slice;
    bool slice_dim0;
    gtl::InlinedVector<int64_t, 4> slice_begin_vec;
    gtl::InlinedVector<int64_t, 4> slice_end_vec;
    gtl::InlinedVector<int64_t, 4> slice_strides_vec;
    TF_RETURN_IF_ERROR(ValidateStridedSliceOp(
        &slice_begin_t, &slice_end_t, slice_strides_t, pack_output_shape,
        begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask,
        &processing_shape, &final_shape, &is_identity, &is_simple_slice,
        &slice_dim0, &slice_begin_vec, &slice_end_vec, &slice_strides_vec));

    if (!is_simple_slice) return Status::OK();

    int begin_index = -1;
    int64_t begin_value = 0;
    for (int i = 0, end = slice_begin_vec.size(); i < end; ++i) {
      const int64_t v = slice_begin_vec[i];
      if (v != 0) {
        if (begin_index != -1) {
          // At least two start values that are nonzero.
          return Status::OK();
        }
        begin_index = i;
        begin_value = v;
      }
    }

    int end_index = -1;
    int64_t end_value = 0;
    for (int i = 0, end = slice_end_vec.size(); i < end; ++i) {
      const int64_t v = slice_end_vec[i];
      if (v != pack_output_shape.dim_size(i)) {
        if (end_index != -1) {
          // At least two end values that don't cover the full dimension.
          return Status::OK();
        }
        end_index = i;
        end_value = v;
      }
    }

    if (begin_index == -1 && end_index == -1) return Status::OK();
    if (begin_index != -1 && end_index != -1 && begin_index != end_index) {
      // Somehow received different axes for begin/end slicing.
      return Status::OK();
    }
    const int slice_axis = (begin_index == -1) ? end_index : begin_index;
    if (slice_axis != pack_axis) {
      // Not slicing on the same axis as the Pack op.
      return Status::OK();
    }
    *slice_start_value = (begin_index == -1) ? 0 : begin_value;
    const int64_t slice_end_value =
        (end_index == -1) ? pack_output_shape.dim_size(slice_axis) : end_value;
    if (slice_end_value != *slice_start_value + 1) {
      // Not slicing a single value out.
      return Status::OK();
    }

    if (*slice_start_value < 0 || *slice_start_value >= pack->input_size()) {
      return errors::InvalidArgument(
          "Node ", node->name(), " requested invalid slice index ",
          *slice_start_value, " on axis ", slice_axis,
          " from tensor of shape: ", pack_output_shape.DebugString());
    }

    if (shrink_axis_mask == 0) {
      *must_expand_dims = true;
    } else if (shrink_axis_mask == (1 << slice_axis)) {
      *must_expand_dims = false;
    } else {
      // Shrinking on a different axis from the one that we are slicing on.
      return Status::OK();
    }

    *found = true;  // slice_start_value is valid.
    return Status::OK();
  }

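  // Builds the replacement: when the slice already shrinks the pack axis, the
  // result is just an Identity of pack's input #slice_start_value; otherwise
  // an ExpandDims re-inserts the unit dimension at pack_axis.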
  Status RewriteGraph(const NodeDef* node, const NodeDef* pack,
                      int64_t slice_start_value, int pack_axis,
                      bool must_expand_dims, string* simplified_node_name) {
    const string& input_slice = pack->input(slice_start_value);

    const OpInfo::TensorProperties* input_slice_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(pack->input(slice_start_value),
                                           &input_slice_properties));
    PartialTensorShape input_slice_shape(input_slice_properties->shape());

    const OpInfo::TensorProperties* output_properties;
    TF_RETURN_IF_ERROR(GetTensorProperties(
        strings::StrCat(node->name(), ":", 0), &output_properties));
    PartialTensorShape output_shape(output_properties->shape());
    NodeDef* output =
        AddEmptyNode(OptimizedNodeName(ParseNodeScopeAndName(node->name())));
    if (!must_expand_dims) {
      output->set_op("Identity");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      output->add_input(input_slice);
    } else {
      NodeDef* axis = AddEmptyNode(
          OptimizedNodeName(ParseNodeScopeAndName(node->name()), "Axis"));
      axis->set_op("Const");
      axis->set_device(node->device());
      // We need a control edge from the input slice to guarantee that the
      // axis constant is executed in the same frame as `input_slice`;
      // otherwise ExpandDims might have mismatched input frames.
      axis->add_input(absl::StrCat("^", ParseTensorName(input_slice).node()));
      auto axis_attr = axis->mutable_attr();
      SetDataTypeToAttr(DT_INT32, "dtype", axis);
      auto* axis_t = (*axis_attr)["value"].mutable_tensor();
      axis_t->set_dtype(DT_INT32);
      axis_t->add_int_val(pack_axis);
      AddToOptimizationQueue(axis);
      output->set_op("ExpandDims");
      output->set_device(node->device());
      SetDataTypeToAttr(output_properties->dtype(), "T", output);
      SetDataTypeToAttr(DT_INT32, "Tdim", output);
      output->add_input(input_slice);
      output->add_input(axis->name());
    }

    // Copy dependencies over.
    ForwardControlDependencies(output, {node, pack});
    AddToOptimizationQueue(output);
    *simplified_node_name = output->name();

    return Status::OK();
  }
};

// Eliminates unnecessary copies during sparse embedding lookup operations.
//
// For non-partitioned variables, the `tf.nn.embedding_lookup_sparse()` function
// generates code of the form:
//
//     embeddings = <a 2D Tensor>
//     sparse_ids = <a tf.int64 SparseTensor>
//     segment_ids = sparse_ids.indices[:, 0]
//     ids, idx = tf.unique(sparse_ids.values)
//     gathered_rows = tf.gather(embeddings, ids)
//     result = tf.sparse.segment_<combiner>(gathered_rows, idx, segment_ids)
//
// In this case, all of the work in `tf.unique()` and `tf.gather()`
// can be avoided by passing the full embeddings to
// `tf.sparse.segment_<combiner>()` and performing the same amount of
// computation (but fewer copies and allocations) as follows:
//
//     embeddings = <a 2D Tensor>
//     sparse_ids = <a tf.int64 SparseTensor>
//     segment_ids = sparse_ids.indices[:, 0]
//     result = tf.sparse.segment_<combiner>(
//          embeddings, sparse_ids.values, segment_ids)
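//
// The rewrite is safe because the sparse segment reduction reads only the
// rows named by its indices: gathering the unique ids first merely
// materializes those rows ahead of time, so indexing the full embeddings
// with the original (non-deduplicated) values yields the same result.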
class SimplifyEmbeddingLookupStage : public ArithmeticOptimizerStage {
 public:
  explicit SimplifyEmbeddingLookupStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("SimplifyEmbeddingLookupStage", ctx, ctx_ext) {
  }
  ~SimplifyEmbeddingLookupStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return Status::OK();

    // Input 0 (data) of the reduction node must be a tf.gather() on the 0th
    // axis.
    NodeDef* gather_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &gather_node));
    if (!IsGather(*gather_node) || IsInPreserveSet(*gather_node) ||
        gather_node->device() != reduction_node->device())
      return Status::OK();
    if (gather_node->op() == "GatherV2" && !IsAxis0(*gather_node, 2))
      return Status::OK();

    // Input 1 (indices) of the gather node must be a tf.unique() on the 0th
    // axis.
    NodeDef* unique_node = nullptr;
    TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(1), &unique_node));
    if (!IsUnique(*unique_node) || IsInPreserveSet(*unique_node) ||
        unique_node->device() != gather_node->device())
      return Status::OK();
    if (unique_node->op() == "UniqueV2" && !IsAxis0(*unique_node, 1))
      return Status::OK();

    DataType unique_element_type;
    TF_RETURN_IF_ERROR(GetNodeAttr(*unique_node, "T", &unique_element_type));

    // Input 1 (indices) of the reduction node must be output 1 of the unique
    // node.
    const TensorId idx_tensor = ParseTensorName(reduction_node->input(1));
    if (idx_tensor != TensorId(unique_node->name(), 1)) return Status::OK();

    // Input 1 (indices) of the reduction node becomes input 0 (x) of the
    // unique node. Capture the old input name before overwriting it so the
    // node map is updated against the correct edge.
    const string old_indices_input = reduction_node->input(1);
    reduction_node->set_input(1, unique_node->input(0));
    ctx().node_map->UpdateInput(reduction_node->name(), old_indices_input,
                                unique_node->input(0));
    SetDataTypeToAttr(unique_element_type, "Tidx", reduction_node);

    // Input 0 (data) of the reduction node becomes input 1 (params) of the
    // gather node.
    const OpInfo::TensorProperties* gather_input_properties;
    TF_RETURN_IF_ERROR(
        GetTensorProperties(gather_node->input(0), &gather_input_properties));
    const string old_data_input = reduction_node->input(0);
    if (gather_input_properties->dtype() == DT_RESOURCE) {
      // If the input is a ResourceGather, we need to add a ReadVariableOp.
      NodeDef* variable_node = nullptr;
      TF_RETURN_IF_ERROR(GetInputNode(gather_node->input(0), &variable_node));
      NodeDef* read_var_node = ctx().optimized_graph->add_node();
      read_var_node->set_name(OptimizedNodeName(
          ParseNodeScopeAndName(reduction_node->name()), "ReadVar"));
      read_var_node->set_op("ReadVariableOp");
      read_var_node->add_input(gather_node->input(0));
      read_var_node->set_device(variable_node->device());

      // The Variable and the Gather node should have the same
      // dtype, but it might not be set on both nodes.
      auto attr = read_var_node->mutable_attr();
      if (variable_node->attr().count("dtype")) {
        SetAttrValue(variable_node->attr().at("dtype").type(),
                     &(*attr)["dtype"]);
      }
      if (gather_node->attr().count("dtype")) {
        SetAttrValue(gather_node->attr().at("dtype").type(), &(*attr)["dtype"]);
      }
      // Copy the _class attr from the Gather node, if it exists, in case of
      // colocation constraints with the variable.
      if (gather_node->attr().count("_class")) {
        (*attr)["_class"] = gather_node->attr().at("_class");
      }
      if (variable_node->attr().count("shape")) {
        SetAttrValue(variable_node->attr().at("shape").shape(),
                     &(*attr)["_output_shapes"]);
      }

      ctx().node_map->AddNode(read_var_node->name(), read_var_node);
      reduction_node->set_input(0, read_var_node->name());
      ctx().node_map->UpdateInput(reduction_node->name(), old_data_input,
                                  read_var_node->name());
    } else {
      reduction_node->set_input(0, gather_node->input(0));
      ctx().node_map->UpdateInput(reduction_node->name(), old_data_input,
                                  gather_node->input(0));
    }
    *simplified_node_name = reduction_node->name();
    return Status::OK();
  }

 private:
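  // Returns true iff input `axis_input` of `node` is a scalar constant with
  // value 0.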
  bool IsAxis0(const NodeDef& node, int axis_input) {
    Tensor axis_tensor;
    if (!GetTensorFromConstNode(node.input(axis_input), &axis_tensor))
      return false;
    if (axis_tensor.NumElements() != 1) return false;
    if (axis_tensor.dtype() == DT_INT32) {
      return axis_tensor.flat<int32>()(0) == 0;
    } else if (axis_tensor.dtype() == DT_INT64) {
      return axis_tensor.flat<int64_t>()(0) == 0;
    } else {
      return false;
    }
  }
};

// Eliminates unnecessary casts before sparse segment reduction operations.
//
// Existing graphs and library code would often insert a cast from DT_INT64 to
// DT_INT32 on the indices and/or segment_ids inputs to "SparseSegment*" ops.
// Support for DT_INT64 indices and/or segment_ids now exists, so we can
// pass the input directly without a cast.
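//
// Illustrative example (hypothetical node names):
//     SparseSegmentSum(data, Cast<DT_INT64 -> DT_INT32>(indices), segment_ids)
// becomes
//     SparseSegmentSum(data, indices, segment_ids)  // with Tidx = DT_INT64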
class RemoveCastIntoSegmentReductionStage : public ArithmeticOptimizerStage {
 public:
  explicit RemoveCastIntoSegmentReductionStage(
      const GraphOptimizerContext& ctx,
      const ArithmeticOptimizerContext& ctx_ext)
      : ArithmeticOptimizerStage("RemoveCastIntoSegmentReductionStage", ctx,
                                 ctx_ext) {}
  ~RemoveCastIntoSegmentReductionStage() override = default;

  bool IsSupported(const NodeDef* node) const override {
    return IsAnySparseSegmentReduction(*node);
  }

  Status TrySimplify(NodeDef* reduction_node,
                     string* simplified_node_name) override {
    if (IsInPreserveSet(*reduction_node)) return Status::OK();

    bool optimized = false;

    // Inputs 1 (indices) and 2 (segment_ids) can be either DT_INT32 or
    // DT_INT64.
    std::array<std::pair<int, string>, 2> input_details = {
        std::make_pair(1, "Tidx"), std::make_pair(2, "Tsegmentids")};

    for (const auto& input : input_details) {
      int input_index = input.first;
      const string& type_attr_name = input.second;
      NodeDef* cast_node = nullptr;
      TF_RETURN_IF_ERROR(
          GetInputNode(reduction_node->input(input_index), &cast_node));
      DataType original_index_type;
      if (IsCastFromSupportedType(*cast_node, &original_index_type)) {
        // Capture the old input name before overwriting it so the node map
        // is updated against the correct edge for either input index.
        const string old_input = reduction_node->input(input_index);
        reduction_node->set_input(input_index, cast_node->input(0));
        ctx().node_map->UpdateInput(reduction_node->name(), old_input,
                                    cast_node->input(0));
        SetDataTypeToAttr(original_index_type, type_attr_name, reduction_node);
        optimized = true;
      }
    }

    if (optimized) *simplified_node_name = reduction_node->name();
    return Status::OK();
  }

 private:
  bool IsCastFromSupportedType(const NodeDef& node, DataType* out_input_type) {
    if (!IsCast(node)) return false;
    if (!GetNodeAttr(node, "SrcT", out_input_type).ok()) return false;
    return *out_input_type == DT_INT32 || *out_input_type == DT_INT64;
  }
};

}  // namespace

Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
  SetVector<NodeDef*> nodes_to_simplify;
  nodes_to_simplify.Reserve(optimized_graph_->node_size());
  for (int i = 0; i < optimized_graph_->node_size(); ++i) {
    nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
  }

  const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
                                  graph_properties_.get(), node_map_.get(),
                                  &feed_nodes_, opt_level_);
  const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);

  // Stop the pipeline after the first stage that returns a non-empty
  // simplified tensor name.
  const auto stop = [](const string& result) { return !result.empty(); };
  GraphOptimizerStagePipeline<string> pipeline(stop);
  const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;

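  // Each stage below is registered only when its option flag is set; a node
  // is passed through the registered stages in order until one of them
  // rewrites it (see the `stop` predicate above).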
  if (options_.combine_add_to_addn && can_use_shapes)
    pipeline.AddStage<AddOpsRewriteStage>(ctx, ctx_ext);
  if (options_.fold_conjugate_into_transpose)
    pipeline.AddStage<FoldConjugateIntoTranspose>(ctx, ctx_ext);
  if (options_.fold_multiply_into_conv)
    pipeline.AddStage<FoldMultiplyIntoConv>(ctx, ctx_ext);
  if (options_.fold_transpose_into_matmul)
    pipeline.AddStage<FoldTransposeIntoMatMul>(ctx, ctx_ext);
  if (is_aggressive && options_.hoist_common_factor_out_of_aggregation &&
      can_use_shapes)
    pipeline.AddStage<HoistCommonFactorOutOfAggregation>(ctx, ctx_ext);
  if (options_.minimize_broadcasts && can_use_shapes)
    pipeline.AddStage<MinimizeBroadcasts>(ctx, ctx_ext);
  if (options_.remove_identity_transpose && can_use_shapes)
    pipeline.AddStage<RemoveIdentityTranspose>(ctx, ctx_ext);
  if (options_.remove_involution)
    pipeline.AddStage<RemoveInvolution>(ctx, ctx_ext);
  if (options_.remove_redundant_bitcast)
    pipeline.AddStage<RemoveRedundantBitcastStage>(ctx, ctx_ext);
  if (options_.remove_redundant_cast)
    pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
  if (options_.replace_pack_with_tile_reshape)
    pipeline.AddStage<ReplacePackWithTileReshape>(ctx, ctx_ext);
  if (options_.replace_mul_with_tile && can_use_shapes)
    pipeline.AddStage<ReplaceMulWithBroadcastByTile>(ctx, ctx_ext);
  if (options_.reduce_upsampling_dims)
    pipeline.AddStage<ReduceUpsamplingDims>(ctx, ctx_ext);
  if (options_.remove_redundant_reshape && can_use_shapes)
    pipeline.AddStage<RemoveRedundantReshapeOrBroadcastTo>(ctx, ctx_ext);
  if (options_.remove_negation)
    pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
  if (options_.replace_mul_with_square)
    pipeline.AddStage<ReplaceMulWithSquare>(ctx, ctx_ext);
  if (options_.remove_logical_not)
    pipeline.AddStage<RemoveLogicalNotStage>(ctx, ctx_ext);
  if (options_.reorder_cast_like_and_value_preserving)
    pipeline.AddStage<ReorderCastLikeAndValuePreserving>(ctx, ctx_ext);
  if (options_.simplify_aggregation)
    pipeline.AddStage<SimplifyAggregation>(ctx, ctx_ext);
  if (options_.hoist_cwise_unary_chains)
    pipeline.AddStage<HoistCWiseUnaryChainsStage>(ctx, ctx_ext);
  if (options_.convert_sqrt_div_to_rsqrt_mul)
    pipeline.AddStage<SqrtDivToRsqrtMulStage>(ctx, ctx_ext);
  if (options_.remove_idempotent)
    pipeline.AddStage<RemoveIdempotentStage>(ctx, ctx_ext);
  if (options_.convert_pow) pipeline.AddStage<ConvertPowStage>(ctx, ctx_ext);
  if (options_.convert_log1p)
    pipeline.AddStage<ConvertLog1pStage>(ctx, ctx_ext);
  if (options_.convert_log_softmax)
    pipeline.AddStage<LogSoftmaxStage>(ctx, ctx_ext);
  if (options_.optimize_max_or_min_of_monotonic)
    pipeline.AddStage<OptimizeMaxOrMinOfMonotonicStage>(ctx, ctx_ext);
  if (options_.convert_expm1)
    pipeline.AddStage<ConvertExpm1Stage>(ctx, ctx_ext);
  if (options_.unary_ops_composition)
    pipeline.AddStage<UnaryOpsComposition>(ctx, ctx_ext);
  if (options_.remove_stack_slice_same_axis)
    pipeline.AddStage<RemoveStackSliceSameAxis>(ctx, ctx_ext);
  if (options_.simplify_embedding_lookup)
    pipeline.AddStage<SimplifyEmbeddingLookupStage>(ctx, ctx_ext);
  if (options_.remove_cast_into_segment_reduction)
    pipeline.AddStage<RemoveCastIntoSegmentReductionStage>(ctx, ctx_ext);
  if (options_.fuse_squared_diff)
    pipeline.AddStage<FuseSquaredDiffStage>(ctx, ctx_ext);

  VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
          << absl::StrJoin(pipeline.StageNames(), ", ");

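  // Worklist loop: pop a node, run it through the stage pipeline, and push
  // any newly created or re-wired nodes back so they are considered for
  // further simplification, until no candidates remain.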
  while (!nodes_to_simplify.Empty()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = nodes_to_simplify.PopBack();

    string simplified_tensor = "";
    bool optimized = pipeline.PassThroughAllStages(node, &simplified_tensor);

    // If the node was not optimized by any of the stages, go to the next one.
    if (!optimized) continue;

    // Re-wire consumers of the old node to the new one.
    if (NodeName(simplified_tensor) != node->name()) {
      // Always consider simplified_tensor for further optimizations.
      NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
      if (simplified_node != nullptr) {
        nodes_to_simplify.PushBack(simplified_node);
      }
      // When `node` is simplified to another node rather than modified in
      // place, redirect the consumers of `node` to `simplified_tensor` and
      // re-push them into `nodes_to_simplify` for further optimizations.
      const std::vector<NodeDef*> consumers =
          node_map_->GetOutputsOrderedByNodeName(node->name());
      for (NodeDef* consumer : consumers) {
        // Update any of `consumer`'s inputs that refer to `node`.
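        // A data input that referred to `node` is replaced with
        // `simplified_tensor` (which carries its own output port); a control
        // input "^node" becomes a control dependency on the simplified node.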
        for (int i = 0; i < consumer->input_size(); ++i) {
          int operand_pos;
          string operand_node_name =
              ParseNodeName(consumer->input(i), &operand_pos);
          if (operand_node_name == node->name()) {
            *consumer->mutable_input(i) =
                (operand_pos < 0
                     ? AsControlDependency(NodeName(simplified_tensor))
                     : simplified_tensor);
          }
        }
        node_map_->UpdateInput(consumer->name(), node->name(),
                               simplified_tensor);
        nodes_to_simplify.PushBack(consumer);
      }
    }
  }
  return Status::OK();
}

Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
                                     const GrapplerItem& item,
                                     GraphDef* optimized_graph) {
  // Set up helper data structures.
  nodes_to_preserve_ = item.NodesToPreserve();
  fetch_nodes_known_ = !item.fetch.empty();
  GrapplerItem optimized_item(item);
  optimized_graph_ = &optimized_item.graph;

  node_map_.reset(new NodeMap(optimized_graph_));
  for (const auto& feed : item.feed) {
    feed_nodes_.insert(NodeName(feed.first));
  }

  // Disable restricted graph rewrites.
  options_.unary_ops_composition &=
      item.optimization_options().allow_non_differentiable_rewrites;

  // Perform topological sort on the graph in order to help DedupComputations
  // and AddOpsRewrite to optimize larger subgraphs starting from the roots
  // with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  graph_properties_.reset(new GraphProperties(optimized_item));
  const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
  const Status status =
      graph_properties_->InferStatically(assume_valid_feeds,
                                         /*aggressive_shape_inference=*/false,
                                         /*include_tensor_values=*/false);
  const bool can_use_shapes = status.ok();
  if (!can_use_shapes) {
    VLOG(1) << "Shape inference failed: " << status.error_message();
  }

  // Perform the optimizations.
  TF_RETURN_IF_ERROR(SimplifyArithmeticOps(can_use_shapes));
  *optimized_graph = std::move(*optimized_graph_);
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow