1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Classes to maintain a static registry of whole-graph optimization 17 // passes to be applied by the Session when it initializes a graph. 18 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ 19 #define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ 20 21 #include <functional> 22 #include <map> 23 #include <vector> 24 25 #include "tensorflow/core/common_runtime/composite_device.h" 26 #include "tensorflow/core/common_runtime/device_set.h" 27 #include "tensorflow/core/framework/function.h" 28 #include "tensorflow/core/graph/costmodel.h" 29 #include "tensorflow/core/graph/graph.h" 30 31 namespace tensorflow { 32 struct SessionOptions; 33 34 // All the parameters used by an optimization pass are packaged in 35 // this struct. They should be enough for the optimization pass to use 36 // as a key into a state dictionary if it wants to keep state across 37 // calls. 38 struct GraphOptimizationPassOptions { 39 // Filled in by DirectSession for PRE_PLACEMENT optimizations. Can be empty. 40 string session_handle; 41 const SessionOptions* session_options = nullptr; 42 const CostModel* cost_model = nullptr; 43 44 FunctionLibraryDefinition* flib_def = nullptr; // Not owned. 45 // The DeviceSet contains all the devices known to the system and is 46 // filled in for optimizations run by the session master, i.e., 47 // PRE_PLACEMENT, POST_PLACEMENT, and POST_REWRITE_FOR_EXEC. It is 48 // nullptr for POST_PARTITIONING optimizations which are run at the 49 // workers. 50 const DeviceSet* device_set = nullptr; // Not owned. 51 52 // Maps from a CompositeDevice name to a list of underlying physical 53 // devices. 54 const std::vector<CompositeDevice*>* composite_devices = 55 nullptr; // Not owned. 56 57 // The graph to optimize, for optimization passes that run before 58 // partitioning. Null for post-partitioning passes. 59 // An optimization pass may replace *graph with a new graph object. 60 std::unique_ptr<Graph>* graph = nullptr; 61 62 // Graphs for each partition, if running post-partitioning. Optimization 63 // passes may alter the graphs, but must not add or remove partitions. 64 // Null for pre-partitioning passes. 65 std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs = 66 nullptr; 67 68 // Indicator of whether or not the graph was derived from a function. 69 bool is_function_graph = false; 70 // Set when is_function_graph is true. The default device where the function 71 // runs. If nullptr, it runs on the local host. 72 const Device* default_function_device = nullptr; 73 // Set when is_function_graph is true. The function where the graph was 74 // derived. `graph` doesn't contain all the information in the function_def, 75 // e.g. function attributes. 76 const FunctionDef* function_def = nullptr; 77 }; 78 79 // Optimization passes are implemented by inheriting from 80 // GraphOptimizationPass. 81 class GraphOptimizationPass { 82 public: ~GraphOptimizationPass()83 virtual ~GraphOptimizationPass() {} 84 virtual Status Run(const GraphOptimizationPassOptions& options) = 0; set_name(const string & name)85 void set_name(const string& name) { name_ = name; } name()86 string name() const { return name_; } 87 88 private: 89 // The name of the optimization pass, which is the same as the inherited 90 // class name. 91 string name_; 92 }; 93 94 // The key is a 'phase' number. Phases are executed in increasing 95 // order. Within each phase the order of passes is undefined. 96 typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>> 97 GraphOptimizationPasses; 98 99 // A global OptimizationPassRegistry is used to hold all passes. 100 class OptimizationPassRegistry { 101 public: 102 // Groups of passes are run at different points in initialization. 103 enum Grouping { 104 PRE_PLACEMENT, // after cost model assignment, before placement. 105 POST_PLACEMENT, // after placement. 106 POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints. 107 POST_PARTITIONING, // after partitioning 108 }; 109 110 // Add an optimization pass to the registry. 111 void Register(Grouping grouping, int phase, 112 std::unique_ptr<GraphOptimizationPass> pass); 113 groups()114 const std::map<Grouping, GraphOptimizationPasses>& groups() { 115 return groups_; 116 } 117 118 // Run all passes in grouping, ordered by phase, with the same 119 // options. 120 Status RunGrouping(Grouping grouping, 121 const GraphOptimizationPassOptions& options); 122 123 // Returns the global registry of optimization passes. 124 static OptimizationPassRegistry* Global(); 125 126 // Prints registered optimization passes for debugging. 127 void LogGrouping(Grouping grouping, int vlog_level); 128 void LogAllGroupings(int vlog_level); 129 130 private: 131 std::map<Grouping, GraphOptimizationPasses> groups_; 132 }; 133 134 namespace optimization_registration { 135 136 class OptimizationPassRegistration { 137 public: OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,int phase,std::unique_ptr<GraphOptimizationPass> pass,string optimization_pass_name)138 OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping, 139 int phase, 140 std::unique_ptr<GraphOptimizationPass> pass, 141 string optimization_pass_name) { 142 pass->set_name(optimization_pass_name); 143 OptimizationPassRegistry::Global()->Register(grouping, phase, 144 std::move(pass)); 145 } 146 }; 147 148 } // namespace optimization_registration 149 150 #define REGISTER_OPTIMIZATION(grouping, phase, optimization) \ 151 REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization) 152 153 #define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \ 154 REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) 155 156 #define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \ 157 static ::tensorflow::optimization_registration::OptimizationPassRegistration \ 158 register_optimization_##ctr( \ 159 grouping, phase, \ 160 ::std::unique_ptr<::tensorflow::GraphOptimizationPass>( \ 161 new optimization()), \ 162 #optimization) 163 164 } // namespace tensorflow 165 166 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ 167