• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
18 
19 #include <unordered_set>
20 
21 #include "tensorflow/core/grappler/costs/graph_properties.h"
22 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
23 #include "tensorflow/core/grappler/utils.h"
24 #include "tensorflow/core/lib/gtl/flatset.h"
25 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
26 
27 namespace tensorflow {
28 namespace grappler {
29 
// Identifier string for the arithmetic optimizer. Note that it is spelled
// differently from the value ArithmeticOptimizer::name() returns
// ("arithmetic_optimizer").
constexpr char kArithmeticOptimizer[] = {"ArithmeticOptimizer"};
31 
32 // Optimize TF computations by reducing the arithmetic complexity required to
33 // run a model.
34 class ArithmeticOptimizer : public GraphOptimizer {
35  public:
ArithmeticOptimizer()36   ArithmeticOptimizer()
37       : opt_level_(RewriterConfig::ON),
38         options_(ArithmeticOptimizerOptions::Default(RewriterConfig::ON)) {}
39 
ArithmeticOptimizer(RewriterConfig::Toggle opt_level)40   explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level)
41       : opt_level_(opt_level),
42         options_(ArithmeticOptimizerOptions::Default(opt_level)) {}
43 
~ArithmeticOptimizer()44   ~ArithmeticOptimizer() override {}
45 
name()46   string name() const override { return "arithmetic_optimizer"; };
47 
UsesFunctionLibrary()48   bool UsesFunctionLibrary() const override { return false; }
49 
50   Status Optimize(Cluster* cluster, const GrapplerItem& item,
51                   GraphDef* optimized_graph) override;
52 
53  private:
54   friend class ArithmeticOptimizerTest;
55 
56   // Granular control for arithmetic optimizer stages
57   struct ArithmeticOptimizerOptions {
58     bool combine_add_to_addn = true;
59     bool convert_sqrt_div_to_rsqrt_mul = true;
60     bool dedup_computations = true;
61     bool fold_conjugate_into_transpose = true;
62     bool fold_multiply_into_conv = true;
63     bool fold_transpose_into_matmul = true;
64     bool fuse_squared_diff = true;
65     bool hoist_common_factor_out_of_aggregation = true;
66     bool hoist_cwise_unary_chains = true;
67     bool minimize_broadcasts = true;
68     bool optimize_max_or_min_of_monotonic = true;
69     bool remove_idempotent = true;
70     bool remove_identity_transpose = true;
71     bool remove_involution = true;
72     bool remove_logical_not = true;
73     bool remove_negation = true;
74     bool remove_redundant_bitcast = true;
75     bool remove_redundant_cast = true;
76     bool remove_redundant_reshape = true;
77     bool reduce_upsampling_dims = true;
78     bool reorder_cast_like_and_value_preserving = true;
79     bool replace_mul_with_tile = true;
80     bool replace_mul_with_square = true;
81     bool replace_pack_with_tile_reshape = true;
82     bool convert_pow = true;
83     bool convert_log1p = true;
84     bool convert_log_softmax = true;
85     bool convert_expm1 = true;
86     bool unary_ops_composition = true;
87     bool remove_stack_slice_same_axis = true;
88     bool simplify_aggregation = true;
89     bool simplify_embedding_lookup = true;
90     bool remove_cast_into_segment_reduction = true;
91 
92     // Choose which arithmetic optimizer stages will be enabled for a given
93     // optimization level by default.
DefaultArithmeticOptimizerOptions94     static ArithmeticOptimizerOptions Default(
95         RewriterConfig::Toggle opt_level) {
96       ArithmeticOptimizerOptions options;
97       return options;
98     }
99   };
100 
101   // Returns true if it is safe to dedup node from the graph.
102   bool CanDedup(const NodeDef& node) const;
103 
104   // Dedup redundant nodes in the graph.
105   void DedupComputations();
106 
107   // Forward the control dependencies anchored on src_nodes to the target_nodes.
108   void ForwardControlDependencies(NodeDef* target_node,
109                                   const std::vector<const NodeDef*>& src_nodes);
110 
111   // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
112   // transposes.
113   Status SimplifyArithmeticOps(bool can_use_shapes);
114   // Tries to simplify the expression that roots at `node` and replaces the uses
115   // of `node` to the simplified expression. Returns the name of the simplified
116   // tensor (e.g. "split:1") or an empty string if no simplification is
117   // performed.
118   //
119   // `node_map` stores the mapping from node names to NodeDef*, and will be
120   // updated according to the rewrite.
121   //
122   // `new_nodes` will be populated with the new nodes this function creates and
123   // updates. The caller can push these nodes into the simplification queue to
124   // optimize them further.
125   //
126   // TODO(jingyue): This interface is not suitable for optimizing nodes with
127   // multiple output tensors. We should pass in a tensor name instead of a
128   // NodeDef.
129   string TrySimplifyAndReplaceUses(const NodeDef* node,
130                                    SetVector<NodeDef*>* nodes_to_simplify);
131 
132   RewriterConfig::Toggle opt_level_;
133   ArithmeticOptimizerOptions options_;
134 
135   bool fetch_nodes_known_ = false;
136   std::unordered_set<string> nodes_to_preserve_;
137   std::unique_ptr<NodeMap> node_map_;
138   std::unique_ptr<GraphProperties> graph_properties_;
139   GraphDef* optimized_graph_ = nullptr;  // Not owned.
140   gtl::FlatSet<string> feed_nodes_;
141 };
142 
143 }  // end namespace grappler
144 }  // end namespace tensorflow
145 
146 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_H_
147