/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_

#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
#include "tensorflow/core/grappler/verifiers/graph_verifier.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
#include "tensorflow/core/protobuf/verifier_config.pb.h"

namespace tensorflow {
namespace grappler {

// Run the other grappler optimizers based on the specified rewriter config.
35 class MetaOptimizer : public GraphOptimizer { 36 public: 37 MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg); 38 ~MetaOptimizer() override = default; 39 name()40 string name() const override { return "meta_optimizer"; }; 41 UsesFunctionLibrary()42 bool UsesFunctionLibrary() const override { return true; } 43 Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)44 Status Optimize(Cluster* cluster, const GrapplerItem& item, 45 GraphDef* optimized_graph) override { 46 GrapplerItem copy(item); 47 return OptimizeConsumeItem(cluster, std::move(copy), optimized_graph); 48 } 49 50 Status OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item, 51 GraphDef* optimized_graph); 52 53 string GetResultString() const; 54 55 void PrintResult(); 56 Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)57 void Feedback(Cluster* cluster, const GrapplerItem& item, 58 const GraphDef& optimized_graph, double result) override {} 59 60 private: 61 std::unique_ptr<GraphOptimizer> MakeNewOptimizer( 62 const string& optimizer) const; 63 64 // When grappler should lower control flow to V1 switch/merge style nodes. 65 bool LowerControlFlow() const; 66 67 // Initialize active optimizers from RewriterConfig toggles. 68 Status InitializeOptimizers( 69 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; 70 // Initialize active optimizers from RewriterConfig optimizer names. 71 Status InitializeOptimizersByName( 72 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; 73 // Initialize active optimizers from RewriterConfig.custom_optimizers. 74 Status InitializeCustomGraphOptimizers( 75 const std::set<string>& pre_initialized_optimizers, 76 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; 77 // Returns the config for a custom graph optimizer. Null if none was found. 
78 const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig( 79 const string& name) const; 80 81 // Initialize active verifiers from the RewriterConfig toggles. 82 void InitializeVerifiers( 83 std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers, 84 std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers) 85 const; 86 87 // Run optimization pass over a single GrapplerItem. Meta optimizer might run 88 // multiple such passes: 1) for the main graph 2) for the function library 89 Status OptimizeGraph(Cluster* cluster, GrapplerItem&& item, 90 GraphDef* optimized_graph); 91 92 DeviceBase* const cpu_device_; // may be NULL 93 ConfigProto config_proto_; 94 RewriterConfig& cfg_; 95 96 struct OptimizerResult { 97 string optimizer_name; 98 string message; 99 Status status; 100 }; 101 102 struct GraphOptimizationResult { GraphOptimizationResultGraphOptimizationResult103 explicit GraphOptimizationResult(const string& id) : id(id) {} 104 string id; 105 std::vector<OptimizerResult> results; 106 }; 107 108 Status RunOptimizer(GraphOptimizer* optimizer, Cluster* cluster, 109 GrapplerItem* optimized_item, GraphDef* optimized_graph, 110 GraphOptimizationResult* optimization_result); 111 112 std::vector<GraphOptimizationResult> optimization_results_; 113 }; 114 115 bool MetaOptimizerEnabled(const ConfigProto& cfg); 116 117 // Run the meta optimizer. 118 // 119 // If <cpu_device> is non-null, it is the device to be used for executing ops 120 // during constant folding; if NULL, a new device is created for doing constant 121 // folding. For performance, it is recommended to pass in an existing cpu_device 122 // when possible. 123 Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg, 124 DeviceBase* cpu_device, Cluster* cluster, 125 GraphDef* optimized_graph); 126 127 // Wrapper around RunMetaOptimizer convenient for optimizing 128 // function graphs. 
129 // 130 // Runs grappler optimizations on `g` based on `config_proto`. 131 // `ret_node_names`: a vector of node names whose outputs are returned, 132 // aka fetches. when `g` represent a function, these are _Retval nodes. 133 // `lib`: function library to use with `g`. 134 // `device_set`: the set of devices that graph can refer to. 135 // `cpu_device`: the CPU device. 136 // `config_proto`: Grapper configuration. 137 // `grappler_item_id': Grappler item id (e.g. optimized function name). 138 // `optimization_options`: Grappler optimization constraints that are known only 139 // at runtime. 140 // 141 // **g is a graph constructed based on the runtime library 'lib'. 142 // OptimizeGraph mutates **g extensively and replaces '*g' with a 143 // complete copy. Therefore, the caller should not keep any references 144 // to nodes *g. 145 Status OptimizeGraph( 146 std::vector<string> ret_node_names, std::vector<string> keep_node_names, 147 FunctionLibraryDefinition* lib, const DeviceSet& device_set, 148 Device* cpu_device, const ConfigProto& config_proto, 149 const string& grappler_item_id, 150 const GrapplerItem::OptimizationOptions& optimization_options, 151 std::unique_ptr<tensorflow::Graph>* g); 152 153 } // namespace grappler 154 } // namespace tensorflow 155 156 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_META_OPTIMIZER_H_ 157