1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Classes to maintain a static registry of whole-graph optimization
17 // passes to be applied by the Session when it initializes a graph.
18 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
19 #define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
20 
21 #include <functional>
22 #include <map>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/composite_device.h"
26 #include "tensorflow/core/common_runtime/device_set.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/graph/costmodel.h"
29 #include "tensorflow/core/graph/graph.h"
30 
31 namespace tensorflow {
32 struct SessionOptions;
33 
34 // All the parameters used by an optimization pass are packaged in
35 // this struct. They should be enough for the optimization pass to use
36 // as a key into a state dictionary if it wants to keep state across
37 // calls.
struct GraphOptimizationPassOptions {
  // Filled in by DirectSession for PRE_PLACEMENT optimizations. Can be empty.
  string session_handle;
  // Session configuration; may be consulted by passes. Not owned.
  const SessionOptions* session_options = nullptr;
  // Cost model for the graph, when available. Not owned.
  const CostModel* cost_model = nullptr;

  // Function library visible to the pass. Not owned.
  FunctionLibraryDefinition* flib_def = nullptr;  // Not owned.
  // The DeviceSet contains all the devices known to the system and is
  // filled in for optimizations run by the session master, i.e.,
  // PRE_PLACEMENT, POST_PLACEMENT, and POST_REWRITE_FOR_EXEC. It is
  // nullptr for POST_PARTITIONING optimizations which are run at the
  // workers.
  const DeviceSet* device_set = nullptr;  // Not owned.

  // Maps from a CompositeDevice name to a list of underlying physical
  // devices.
  const std::vector<CompositeDevice*>* composite_devices =
      nullptr;  // Not owned.

  // The graph to optimize, for optimization passes that run before
  // partitioning. Null for post-partitioning passes.
  // An optimization pass may replace *graph with a new graph object.
  std::unique_ptr<Graph>* graph = nullptr;

  // Graphs for each partition, if running post-partitioning. Optimization
  // passes may alter the graphs, but must not add or remove partitions.
  // Null for pre-partitioning passes.
  std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs =
      nullptr;

  // Indicator of whether or not the graph was derived from a function.
  bool is_function_graph = false;
  // Set when is_function_graph is true. The default device where the function
  // runs. If nullptr, it runs on the local host.
  const Device* default_function_device = nullptr;
  // Set when is_function_graph is true. The function where the graph was
  // derived. `graph` doesn't contain all the information in the function_def,
  // e.g. function attributes.
  const FunctionDef* function_def = nullptr;

  // TODO(b/176491312): Remove this if shape inference on import flag is
  // removed. If True, allows mlir roundtrip to run shape inference on import.
  bool shape_inference_on_tfe_dialect_import = true;
};
82 
83 // Optimization passes are implemented by inheriting from
84 // GraphOptimizationPass.
85 class GraphOptimizationPass {
86  public:
~GraphOptimizationPass()87   virtual ~GraphOptimizationPass() {}
88   virtual Status Run(const GraphOptimizationPassOptions& options) = 0;
set_name(const string & name)89   void set_name(const string& name) { name_ = name; }
name()90   string name() const { return name_; }
91 
92  private:
93   // The name of the optimization pass, which is the same as the inherited
94   // class name.
95   string name_;
96 };
97 
98 // The key is a 'phase' number. Phases are executed in increasing
99 // order. Within each phase the order of passes is undefined.
100 typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>>
101     GraphOptimizationPasses;
102 
103 // A global OptimizationPassRegistry is used to hold all passes.
class OptimizationPassRegistry {
 public:
  // Groups of passes are run at different points in initialization.
  enum Grouping {
    PRE_PLACEMENT,          // after cost model assignment, before placement.
    POST_PLACEMENT,         // after placement.
    POST_REWRITE_FOR_EXEC,  // after re-write using feed/fetch endpoints.
    POST_PARTITIONING,      // after partitioning
  };

  // Add an optimization pass to the registry.
  void Register(Grouping grouping, int phase,
                std::unique_ptr<GraphOptimizationPass> pass);

  // Returns all registered passes, keyed by grouping, then by phase.
  const std::map<Grouping, GraphOptimizationPasses>& groups() {
    return groups_;
  }

  // Run all passes in grouping, ordered by phase, with the same
  // options.
  Status RunGrouping(Grouping grouping,
                     const GraphOptimizationPassOptions& options);

  // Returns the global registry of optimization passes.
  static OptimizationPassRegistry* Global();

  // Prints registered optimization passes for debugging.
  void LogGrouping(Grouping grouping, int vlog_level);
  void LogAllGroupings(int vlog_level);

 private:
  std::map<Grouping, GraphOptimizationPasses> groups_;
};
137 
138 namespace optimization_registration {
139 
// Helper whose constructor names `pass` and registers it with the global
// OptimizationPassRegistry. A static instance of this class is emitted by
// REGISTER_OPTIMIZATION, so registration happens during static
// initialization.
class OptimizationPassRegistration {
 public:
  OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,
                               int phase,
                               std::unique_ptr<GraphOptimizationPass> pass,
                               string optimization_pass_name) {
    pass->set_name(optimization_pass_name);
    OptimizationPassRegistry::Global()->Register(grouping, phase,
                                                 std::move(pass));
  }
};
151 
152 }  // namespace optimization_registration
153 
// Registers `optimization` (a GraphOptimizationPass subclass with a
// zero-argument constructor) to run in `grouping` at `phase`, with its name
// set to the stringified class name.
#define REGISTER_OPTIMIZATION(grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization)

// Extra level of expansion so that __COUNTER__ is substituted before `ctr`
// is token-pasted into the variable name below.
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
  REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)

// Emits a uniquely named static OptimizationPassRegistration whose
// constructor performs the registration at static-initialization time.
#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)         \
  static ::tensorflow::optimization_registration::OptimizationPassRegistration \
      register_optimization_##ctr(                                             \
          grouping, phase,                                                     \
          ::std::unique_ptr<::tensorflow::GraphOptimizationPass>(              \
              new optimization()),                                             \
          #optimization)
167 
168 }  // namespace tensorflow
169 
170 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_
171