1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Classes to maintain a static registry of whole-graph optimization 17 // passes to be applied by the Session when it initializes a graph. 18 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ 19 #define TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ 20 21 #include <functional> 22 #include <map> 23 #include <vector> 24 25 #include "tensorflow/core/common_runtime/device_set.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/graph/costmodel.h" 28 #include "tensorflow/core/graph/graph.h" 29 30 namespace tensorflow { 31 struct SessionOptions; 32 33 // All the parameters used by an optimization pass are packaged in 34 // this struct. They should be enough for the optimization pass to use 35 // as a key into a state dictionary if it wants to keep state across 36 // calls. 37 struct GraphOptimizationPassOptions { 38 // Filled in by DirectSession for PRE_PLACEMENT optimizations. Can be empty. 39 string session_handle; 40 const SessionOptions* session_options = nullptr; 41 const CostModel* cost_model = nullptr; 42 43 FunctionLibraryDefinition* flib_def = nullptr; // Not owned. 44 // The DeviceSet contains all the devices known to the system and is 45 // filled in for optimizations run by the session master, i.e., 46 // PRE_PLACEMENT, POST_PLACEMENT, and POST_REWRITE_FOR_EXEC. It is 47 // nullptr for POST_PARTITIONING optimizations which are run at the 48 // workers. 49 const DeviceSet* device_set = nullptr; // Not owned. 50 51 // The graph to optimize, for optimization passes that run before 52 // partitioning. Null for post-partitioning passes. 53 // An optimization pass may replace *graph with a new graph object. 54 std::unique_ptr<Graph>* graph = nullptr; 55 56 // Graphs for each partition, if running post-partitioning. Optimization 57 // passes may alter the graphs, but must not add or remove partitions. 58 // Null for pre-partitioning passes. 59 std::unordered_map<string, std::unique_ptr<Graph>>* partition_graphs = 60 nullptr; 61 }; 62 63 // Optimization passes are implemented by inheriting from 64 // GraphOptimizationPass. 65 class GraphOptimizationPass { 66 public: ~GraphOptimizationPass()67 virtual ~GraphOptimizationPass() {} 68 virtual Status Run(const GraphOptimizationPassOptions& options) = 0; set_name(const string & name)69 void set_name(const string& name) { name_ = name; } name()70 string name() const { return name_; } 71 72 private: 73 // The name of the opitimization pass, which is the same as the inherited 74 // class name. 75 string name_; 76 }; 77 78 // The key is a 'phase' number. Phases are executed in increasing 79 // order. Within each phase the order of passes is undefined. 80 typedef std::map<int, std::vector<std::unique_ptr<GraphOptimizationPass>>> 81 GraphOptimizationPasses; 82 83 // A global OptimizationPassRegistry is used to hold all passes. 84 class OptimizationPassRegistry { 85 public: 86 // Groups of passes are run at different points in initialization. 87 enum Grouping { 88 PRE_PLACEMENT, // after cost model assignment, before placement. 89 POST_PLACEMENT, // after placement. 90 POST_REWRITE_FOR_EXEC, // after re-write using feed/fetch endpoints. 91 POST_PARTITIONING, // after partitioning 92 }; 93 94 // Add an optimization pass to the registry. 95 void Register(Grouping grouping, int phase, 96 std::unique_ptr<GraphOptimizationPass> pass); 97 groups()98 const std::map<Grouping, GraphOptimizationPasses>& groups() { 99 return groups_; 100 } 101 102 // Run all passes in grouping, ordered by phase, with the same 103 // options. 104 Status RunGrouping(Grouping grouping, 105 const GraphOptimizationPassOptions& options); 106 107 // Returns the global registry of optimization passes. 108 static OptimizationPassRegistry* Global(); 109 110 // Prints registered optimization passes for debugging. 111 void LogGrouping(Grouping grouping, int vlog_level); 112 void LogAllGroupings(int vlog_level); 113 114 private: 115 std::map<Grouping, GraphOptimizationPasses> groups_; 116 }; 117 118 namespace optimization_registration { 119 120 class OptimizationPassRegistration { 121 public: OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping,int phase,std::unique_ptr<GraphOptimizationPass> pass,string optimization_pass_name)122 OptimizationPassRegistration(OptimizationPassRegistry::Grouping grouping, 123 int phase, 124 std::unique_ptr<GraphOptimizationPass> pass, 125 string optimization_pass_name) { 126 pass->set_name(optimization_pass_name); 127 OptimizationPassRegistry::Global()->Register(grouping, phase, 128 std::move(pass)); 129 } 130 }; 131 132 } // namespace optimization_registration 133 134 #define REGISTER_OPTIMIZATION(grouping, phase, optimization) \ 135 REGISTER_OPTIMIZATION_UNIQ_HELPER(__COUNTER__, grouping, phase, optimization) 136 137 #define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \ 138 REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) 139 140 #define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \ 141 static ::tensorflow::optimization_registration::OptimizationPassRegistration \ 142 register_optimization_##ctr( \ 143 grouping, phase, \ 144 ::std::unique_ptr<::tensorflow::GraphOptimizationPass>( \ 145 new optimization()), \ 146 #optimization) 147 148 } // namespace tensorflow 149 150 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_OPTIMIZATION_REGISTRY_H_ 151