1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_ 18 19 #include <unordered_map> 20 #include <unordered_set> 21 22 #include "absl/strings/str_cat.h" 23 #include "tensorflow/core/grappler/costs/graph_properties.h" 24 #include "tensorflow/core/grappler/grappler_item.h" 25 #include "tensorflow/core/grappler/utils.h" 26 #include "tensorflow/core/lib/gtl/flatset.h" 27 #include "tensorflow/core/protobuf/rewriter_config.pb.h" 28 29 namespace tensorflow { 30 namespace grappler { 31 32 struct NodeScopeAndName { 33 string scope; 34 string name; 35 }; 36 37 // Parse scope and name: "a/b/c/Add_1" -> {"a/b/c", "Add_1"} 38 const NodeScopeAndName ParseNodeScopeAndName(const string& node_name); 39 40 // Context owned by GraphOptimizer, and passed to every stage at construction 41 // time. Each optimizer stage is responsible for updating it according to the 42 // changes it made to the graph. 43 // 44 // If an optimizer needs access to some helper class that is not present in this 45 // context, consider creating an extension context, specific to that 46 // optimizer (see example of ArithmeticOptimizerContext). GraphOptimizerContext 47 // should only have members that are useful to almost all optimizers. 48 struct GraphOptimizerContext { GraphOptimizerContextGraphOptimizerContext49 GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve, 50 GraphDef* optimized_graph, 51 GraphProperties* graph_properties, NodeMap* node_map, 52 gtl::FlatSet<string>* feed_nodes, 53 RewriterConfig::Toggle opt_level) 54 : nodes_to_preserve(nodes_to_preserve), 55 optimized_graph(optimized_graph), 56 graph_properties(graph_properties), 57 node_map(node_map), 58 feed_nodes(feed_nodes), 59 opt_level(opt_level) {} 60 61 const std::unordered_set<string>* nodes_to_preserve; 62 GraphDef* optimized_graph; 63 GraphProperties* graph_properties; 64 NodeMap* node_map; 65 gtl::FlatSet<string>* feed_nodes; 66 RewriterConfig::Toggle opt_level; 67 }; 68 69 Status GetInputNode(const GraphOptimizerContext& ctx, const string& input, 70 NodeDef** node); 71 Status GetTensorProperties(const GraphOptimizerContext& ctx, 72 const string& tensor, 73 const OpInfo::TensorProperties** properties); 74 75 NodeDef* AddCopyNode(const GraphOptimizerContext& ctx, const string& name, 76 const NodeDef* node_to_copy); 77 NodeDef* AddEmptyNode(const GraphOptimizerContext& ctx, const string& name); 78 79 // WARNING: 80 // Optimizer stage must try to re-use original nodes of a graph and 81 // make all updates in place. This helps to make robust node placement 82 // decisions. Create new nodes only if there is a reason for that. 83 84 // Make a name for a new node obtained by optimizing a single node of the 85 // original graph. The optimized node is placed under the original node scope. 86 // 87 // Node name uniqueness is guaranteed by unique name of an original node in 88 // a same scope. 89 // 90 // Empty sub_scope or prefix ignored. At least one of them must be non-empty. 91 // 92 // Example: a/b/c/Add -> a/b/c/${sub_scope}/${prefix}_Add. 93 const string MakeOptimizedNodeName(const NodeScopeAndName& node, 94 const string& sub_scope, 95 const string& prefix); 96 // Make a name for a new node obtained by optimizing multiple nodes of the 97 // original graph, starting from "root". The optimized node is placed under 98 // the original scope of a "root" node. 99 // 100 // Example: [a/b/c/Add, x/y/z/Mul] -> a/b/c/${sub_scope}/${prefix}_Add_Mul 101 const string MakeOptimizedNodeName(const NodeScopeAndName& root, 102 const std::vector<string> node_names, 103 const string& sub_scope, 104 const string& prefix); 105 106 // Base class for multi-stage GraphOptimizers (ArithmeticOptimizer, etc...). 107 // 108 // If a graph optimizer consists of large number of small independent 109 // rewrites, each of them should be implemented as a separate stage. 110 // 111 // * Result: 112 // Each graph optimizer choose what result is reported by each stage 113 // (e.g. each stage can fill in the name of optimized nodes, or have more 114 // complex result). 115 template <typename Result> 116 class GraphOptimizerStage { 117 public: GraphOptimizerStage(const string & optimizer_name,const string & stage_name,const GraphOptimizerContext & ctx)118 explicit GraphOptimizerStage(const string& optimizer_name, 119 const string& stage_name, 120 const GraphOptimizerContext& ctx) 121 : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {} 122 virtual ~GraphOptimizerStage() = default; 123 stage_name()124 const string& stage_name() const { return stage_name_; } optimizer_name()125 const string& optimizer_name() const { return optimizer_name_; } 126 127 // Check if we should try to simplify node. Returning true doesn't 128 // guarantee that node will be simplified. 129 // 130 // Should implement just a basic sanity check, without any expensive graph 131 // traversals. 132 virtual bool IsSupported(const NodeDef* node) const = 0; 133 134 // Try to simplify the given node. 135 // 136 // Return error status only if some precondition is failed, or got an 137 // incorrect graph. In every other case return Status:OK(), even if didn't 138 // simplify anything. 139 // 140 // Report result using output argument. Each GraphOptimizer can choose it's 141 // own Result type. 142 // TODO(ezhulenev): if it will appear that Result output parameter is not 143 // sufficiently useful (used with a reason by most optimizers), get rid of it, 144 // and remove template parameter. 145 virtual Status TrySimplify(NodeDef* node, Result* result) = 0; 146 147 // Return InvalidArgumentError if node is not supported by the optimizer 148 // stage. 149 // TODO(ezhulenev): make this check part of non-virtual public API 150 // (TrySimplify), and make virtual implementation protected. EnsureNodeIsSupported(const NodeDef * node)151 Status EnsureNodeIsSupported(const NodeDef* node) const { 152 return IsSupported(node) 153 ? Status::OK() 154 : errors::InvalidArgument( 155 "Node ", node->name(), " is not supported by optimizer ", 156 optimizer_name_, " and stage ", stage_name_); 157 } 158 159 // Get a name for a new node, created by this stage, based on one or multiple 160 // nodes of an original graph. OptimizedNodeName(const NodeScopeAndName & node)161 const string OptimizedNodeName(const NodeScopeAndName& node) const { 162 return MakeOptimizedNodeName(node, optimizer_name_, stage_name_); 163 } OptimizedNodeName(const NodeScopeAndName & root,const std::vector<string> & nodes)164 const string OptimizedNodeName(const NodeScopeAndName& root, 165 const std::vector<string>& nodes) const { 166 return MakeOptimizedNodeName(root, nodes, optimizer_name_, stage_name_); 167 } OptimizedNodeName(const NodeScopeAndName & node,const string & rewrite_rule)168 const string OptimizedNodeName(const NodeScopeAndName& node, 169 const string& rewrite_rule) const { 170 const string prefix = strings::StrCat(stage_name_, "_", rewrite_rule); 171 return MakeOptimizedNodeName(node, optimizer_name_, prefix); 172 } 173 UniqueOptimizedNodeName(const NodeScopeAndName & node)174 const string UniqueOptimizedNodeName(const NodeScopeAndName& node) { 175 const string node_name = OptimizedNodeName(node); 176 return UniqueNodeName(node_name); 177 } UniqueOptimizedNodeName(const NodeScopeAndName & node,const string & rewrite_rule)178 const string UniqueOptimizedNodeName(const NodeScopeAndName& node, 179 const string& rewrite_rule) { 180 const string node_name = OptimizedNodeName(node, rewrite_rule); 181 return UniqueNodeName(node_name); 182 } 183 184 // Get a node by input name from a node map. Return an error if node was not 185 // found. GetInputNode(const string & input,NodeDef ** node)186 Status GetInputNode(const string& input, NodeDef** node) const { 187 return ::tensorflow::grappler::GetInputNode(ctx_, input, node); 188 } 189 // Lookup tensor properties by name. Tensor name might have non-zero port 190 // number. Return an error if tensor node doesn't exists in a graph, or it 191 // doesn't have properties defined for requested port. GetTensorProperties(const string & tensor,const OpInfo::TensorProperties ** properties)192 Status GetTensorProperties( 193 const string& tensor, const OpInfo::TensorProperties** properties) const { 194 return ::tensorflow::grappler::GetTensorProperties(ctx_, tensor, 195 properties); 196 } 197 AddCopyNode(const string & name,const NodeDef * node_to_copy)198 NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) { 199 return ::tensorflow::grappler::AddCopyNode(ctx_, name, node_to_copy); 200 } AddEmptyNode(const string & name)201 NodeDef* AddEmptyNode(const string& name) { 202 return ::tensorflow::grappler::AddEmptyNode(ctx_, name); 203 } 204 205 protected: ctx()206 const GraphOptimizerContext& ctx() const { return ctx_; } 207 208 private: UniqueNodeName(absl::string_view name)209 const string UniqueNodeName(absl::string_view name) { 210 string node_name = string(name); 211 while (ctx_.node_map->NodeExists(node_name)) { 212 node_name = absl::StrCat(name, "_unique", 213 optimized_node_name_counter_.fetch_add(1)); 214 } 215 216 return node_name; 217 } 218 219 const string optimizer_name_; 220 const string stage_name_; 221 const GraphOptimizerContext ctx_; 222 std::atomic<int64> optimized_node_name_counter_ = {0}; 223 }; 224 225 template <typename Result> 226 class GraphOptimizerStagePipeline { 227 public: 228 // Break predicate specifies if a pipeline should stop early, and not pass 229 // a node to the next registered optimizer stage, typically that should be the 230 // case when a stage successfully optimized a node, and it wants to yield 231 // control to the optimizer. GraphOptimizerStagePipeline(const std::function<bool (const Result &)> break_predicate)232 explicit GraphOptimizerStagePipeline( 233 const std::function<bool(const Result&)> break_predicate) 234 : break_predicate_(break_predicate) {} 235 236 // Add a stage to the pipeline. It should be called with the arguments for the 237 // stage constructor: 238 // 239 // pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2); 240 // 241 // Returns a reference to the added stage. 242 template <typename T, typename... Args> AddStage(Args &&...args)243 T& AddStage(Args&&... args) { 244 auto stage = new T(std::forward<Args>(args)...); 245 stages_.push_back(std::unique_ptr<T>(stage)); 246 return *stage; 247 } 248 249 // Pass a node through all registered optimizer stages, until break predicate 250 // is true. 251 // 252 // Return true, if pipeline exited after a break predicate was evaluated as 253 // 'true', which typically means that a node was optimized by one of the 254 // registered stages. 255 // 256 // Return false, if node was not optimized by any of registered stages. PassThroughAllStages(NodeDef * node,Result * result)257 bool PassThroughAllStages(NodeDef* node, Result* result) { 258 for (auto& stage : stages_) { 259 if (stage->IsSupported(node)) { 260 const Status stage_status = stage->TrySimplify(node, result); 261 // Each stage must be "error safe" (just like exception safe). In 262 // case of any error it must leave optimized graph unmodified. 263 if (!stage_status.ok()) { 264 VLOG(2) << "Failed to run optimizer " << stage->optimizer_name() 265 << ", stage " << stage->stage_name() << " node " 266 << node->name() 267 << ". Error: " << stage_status.error_message(); 268 } 269 if (break_predicate_(*result)) return true; 270 } 271 } 272 return false; 273 } 274 275 // Pass a node through all registered optimizer stages, until break predicate 276 // is true or a stage fails. 277 // 278 // Returns any stage failure status, or else Status::OK(). PassThroughAllStagesWithStatus(NodeDef * node,Result * result)279 Status PassThroughAllStagesWithStatus(NodeDef* node, Result* result) { 280 for (auto& stage : stages_) { 281 if (!stage->IsSupported(node)) { 282 continue; 283 } 284 const Status stage_status = stage->TrySimplify(node, result); 285 if (!stage_status.ok()) { 286 return stage_status; 287 } else if (break_predicate_(*result)) { 288 break; 289 } 290 } 291 return Status::OK(); 292 } 293 NumStages()294 std::size_t NumStages() { return stages_.size(); } 295 StageNames()296 std::vector<string> StageNames() { 297 std::vector<string> names; 298 for (const auto& stage : stages_) { 299 names.push_back(stage->stage_name()); 300 } 301 return names; 302 } 303 304 private: 305 std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_; 306 std::function<bool(const Result&)> break_predicate_; 307 308 TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline); 309 }; 310 311 } // end namespace grappler 312 } // end namespace tensorflow 313 314 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_ 315