• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_
17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_
18 
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
28 
29 namespace tensorflow {
30 namespace grappler {
31 
32 struct NodeScopeAndName {
33   string scope;
34   string name;
35 };
36 
37 // Parse scope and name: "a/b/c/Add_1" -> {"a/b/c", "Add_1"}
38 const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
39 
40 // Context owned by GraphOptimizer, and passed to every stage at construction
41 // time. Each optimizer stage is responsible for updating it according to the
42 // changes it made to the graph.
43 //
44 // If an optimizer needs access to some helper class that is not present in this
45 // context, consider creating an extension context, specific to that
46 // optimizer (see example of ArithmeticOptimizerContext). GraphOptimizerContext
47 // should only have members that are useful to almost all optimizers.
48 struct GraphOptimizerContext {
GraphOptimizerContextGraphOptimizerContext49   GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
50                         GraphDef* optimized_graph,
51                         GraphProperties* graph_properties, NodeMap* node_map,
52                         gtl::FlatSet<string>* feed_nodes,
53                         RewriterConfig::Toggle opt_level)
54       : nodes_to_preserve(nodes_to_preserve),
55         optimized_graph(optimized_graph),
56         graph_properties(graph_properties),
57         node_map(node_map),
58         feed_nodes(feed_nodes),
59         opt_level(opt_level) {}
60 
61   const std::unordered_set<string>* nodes_to_preserve;
62   GraphDef* optimized_graph;
63   GraphProperties* graph_properties;
64   NodeMap* node_map;
65   gtl::FlatSet<string>* feed_nodes;
66   RewriterConfig::Toggle opt_level;
67 };
68 
69 Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
70                     NodeDef** node);
71 Status GetTensorProperties(const GraphOptimizerContext& ctx,
72                            const string& tensor,
73                            OpInfo::TensorProperties* properties);
74 
75 NodeDef* AddCopyNode(const GraphOptimizerContext& ctx, const string& name,
76                      const NodeDef* node_to_copy);
77 NodeDef* AddEmptyNode(const GraphOptimizerContext& ctx, const string& name);
78 
// WARNING:
// Optimizer stages must try to reuse the original nodes of a graph and make
// all updates in place. This helps the optimizer make robust node placement
// decisions. Create new nodes only when there is a good reason to.
83 
84 // Make a name for a new node obtained by optimizing a single node of the
85 // original graph. The optimized node is placed under the original node scope.
86 //
87 // Node name uniqueness is guaranteed by unique name of an original node in
88 // a same scope.
89 //
90 // Empty sub_scope or prefix ignored. At least one of them must be non-empty.
91 //
92 // Example: a/b/c/Add -> a/b/c/${sub_scope}/${prefix}_Add.
93 const string MakeOptimizedNodeName(const NodeScopeAndName& node,
94                                    const string& sub_scope,
95                                    const string& prefix);
96 // Make a name for a new node obtained by optimizing multiple nodes of the
97 // original graph, starting from "root". The optimized node is placed under
98 // the original scope of a "root" node.
99 //
100 // Example: [a/b/c/Add, x/y/z/Mul] -> a/b/c/${sub_scope}/${prefix}_Add_Mul
101 const string MakeOptimizedNodeName(const NodeScopeAndName& root,
102                                    const std::vector<string> node_names,
103                                    const string& sub_scope,
104                                    const string& prefix);
105 
106 // Base class for multi-stage GraphOptimizers (ArithmeticOptimizer, etc...).
107 //
108 // If a graph optimizer consists of large number of small independent
109 // rewrites, each of them should be implemented as a separate stage.
110 //
111 // * Result:
112 // Each graph optimizer choose what result is reported by each stage
113 // (e.g. each stage can fill in the name of optimized nodes, or have more
114 // complex result).
115 template <typename Result>
116 class GraphOptimizerStage {
117  public:
GraphOptimizerStage(const string & optimizer_name,const string & stage_name,const GraphOptimizerContext & ctx)118   explicit GraphOptimizerStage(const string& optimizer_name,
119                                const string& stage_name,
120                                const GraphOptimizerContext& ctx)
121       : optimizer_name_(optimizer_name), stage_name_(stage_name), ctx_(ctx) {}
122   virtual ~GraphOptimizerStage() = default;
123 
stage_name()124   const string& stage_name() const { return stage_name_; }
optimizer_name()125   const string& optimizer_name() const { return optimizer_name_; }
126 
127   // Check if we should try to simplify node. Returning true doesn't
128   // guarantee that node will be simplified.
129   //
130   // Should implement just a basic sanity check, without any expensive graph
131   // traversals.
132   virtual bool IsSupported(const NodeDef* node) const = 0;
133 
134   // Try to simplify the given node.
135   //
136   // Return error status only if some precondition is failed, or got an
137   // incorrect graph. In every other case return Status:OK(), even if didn't
138   // simplify anything.
139   //
140   // Report result using output argument. Each GraphOptimizer can choose it's
141   // own Result type.
142   // TODO(ezhulenev): if it will appear that Result output parameter is not
143   // sufficiently useful (used with a reason by most optimizers), get rid of it,
144   // and remove template parameter.
145   virtual Status TrySimplify(NodeDef* node, Result* result) = 0;
146 
147   // Return InvalidArgumentError if node is not supported by the optimizer
148   // stage.
149   // TODO(ezhulenev): make this check part of non-virtual public API
150   // (TrySimplify), and make virtual implementation protected.
EnsureNodeIsSupported(const NodeDef * node)151   Status EnsureNodeIsSupported(const NodeDef* node) const {
152     return IsSupported(node)
153                ? Status::OK()
154                : errors::InvalidArgument(
155                      "Node ", node->name(), " is not supported by optimizer ",
156                      optimizer_name_, " and stage ", stage_name_);
157   }
158 
159   // Get a name for a new node, created by this stage, based on one or multiple
160   // nodes of an original graph.
OptimizedNodeName(const NodeScopeAndName & node)161   const string OptimizedNodeName(const NodeScopeAndName& node) const {
162     return MakeOptimizedNodeName(node, optimizer_name_, stage_name_);
163   }
OptimizedNodeName(const NodeScopeAndName & root,const std::vector<string> & nodes)164   const string OptimizedNodeName(const NodeScopeAndName& root,
165                                  const std::vector<string>& nodes) const {
166     return MakeOptimizedNodeName(root, nodes, optimizer_name_, stage_name_);
167   }
OptimizedNodeName(const NodeScopeAndName & node,const string & rewrite_rule)168   const string OptimizedNodeName(const NodeScopeAndName& node,
169                                  const string& rewrite_rule) const {
170     const string prefix = strings::StrCat(stage_name_, "_", rewrite_rule);
171     return MakeOptimizedNodeName(node, optimizer_name_, prefix);
172   }
173 
UniqueOptimizedNodeName(const NodeScopeAndName & node)174   const string UniqueOptimizedNodeName(const NodeScopeAndName& node) {
175     const string node_name = OptimizedNodeName(node);
176     return UniqueNodeName(node_name);
177   }
UniqueOptimizedNodeName(const NodeScopeAndName & node,const string & rewrite_rule)178   const string UniqueOptimizedNodeName(const NodeScopeAndName& node,
179                                        const string& rewrite_rule) {
180     const string node_name = OptimizedNodeName(node, rewrite_rule);
181     return UniqueNodeName(node_name);
182   }
183 
184   // Get a node by input name from a node map. Return an error if node was not
185   // found.
GetInputNode(const string & input,NodeDef ** node)186   Status GetInputNode(const string& input, NodeDef** node) const {
187     return ::tensorflow::grappler::GetInputNode(ctx_, input, node);
188   }
189   // Lookup tensor properties by name. Tensor name might have non-zero port
190   // number. Return an error if tensor node doesn't exists in a graph, or it
191   // doesn't have properties defined for requested port.
GetTensorProperties(const string & tensor,OpInfo::TensorProperties * properties)192   Status GetTensorProperties(const string& tensor,
193                              OpInfo::TensorProperties* properties) const {
194     return ::tensorflow::grappler::GetTensorProperties(ctx_, tensor,
195                                                        properties);
196   }
197 
AddCopyNode(const string & name,const NodeDef * node_to_copy)198   NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
199     return ::tensorflow::grappler::AddCopyNode(ctx_, name, node_to_copy);
200   }
AddEmptyNode(const string & name)201   NodeDef* AddEmptyNode(const string& name) {
202     return ::tensorflow::grappler::AddEmptyNode(ctx_, name);
203   }
204 
205  protected:
ctx()206   const GraphOptimizerContext& ctx() const { return ctx_; }
207 
208  private:
UniqueNodeName(absl::string_view name)209   const string UniqueNodeName(absl::string_view name) {
210     string node_name = string(name);
211     while (ctx_.node_map->NodeExists(node_name)) {
212       node_name = absl::StrCat(name, "_unique",
213                                optimized_node_name_counter_.fetch_add(1));
214     }
215 
216     return node_name;
217   }
218 
219   const string optimizer_name_;
220   const string stage_name_;
221   const GraphOptimizerContext ctx_;
222   std::atomic<int64> optimized_node_name_counter_ = {0};
223 };
224 
225 template <typename Result>
226 class GraphOptimizerStagePipeline {
227  public:
228   // Break predicate specifies if a pipeline should stop early, and not pass
229   // a node to the next registered optimizer stage, typically that should be the
230   // case when a stage successfully optimized a node, and it wants to yield
231   // control to the optimizer.
GraphOptimizerStagePipeline(const std::function<bool (const Result &)> break_predicate)232   explicit GraphOptimizerStagePipeline(
233       const std::function<bool(const Result&)> break_predicate)
234       : break_predicate_(break_predicate) {}
235 
236   // Add a stage to the pipeline. It should be called with the arguments for the
237   // stage constructor:
238   //
239   //   pipeline.AddStage<FooStage>(constructor_arg1, constructor_arg2);
240   //
241   // Returns a reference to the added stage.
242   template <typename T, typename... Args>
AddStage(Args &&...args)243   T& AddStage(Args&&... args) {
244     auto stage = new T(std::forward<Args>(args)...);
245     stages_.push_back(std::unique_ptr<T>(stage));
246     return *stage;
247   }
248 
249   // Pass a node through all registered optimizer stages, until break predicate
250   // is true.
251   //
252   // Return true, if pipeline exited after a break predicate was evaluated as
253   // 'true', which typically means that a node was optimized by one of the
254   // registered stages.
255   //
256   // Return false, if node was not optimized by any of registered stages.
PassThroughAllStages(NodeDef * node,Result * result)257   bool PassThroughAllStages(NodeDef* node, Result* result) {
258     for (auto& stage : stages_) {
259       if (stage->IsSupported(node)) {
260         const Status stage_status = stage->TrySimplify(node, result);
261         // Each stage must be "error safe" (just like exception safe). In
262         // case of any error it must leave optimized graph unmodified.
263         if (!stage_status.ok()) {
264           VLOG(2) << "Failed to run optimizer " << stage->optimizer_name()
265                   << ", stage " << stage->stage_name() << " node "
266                   << node->name()
267                   << ". Error: " << stage_status.error_message();
268         }
269         if (break_predicate_(*result)) return true;
270       }
271     }
272     return false;
273   }
274 
275   // Pass a node through all registered optimizer stages, until break predicate
276   // is true or a stage fails.
277   //
278   // Returns any stage failure status, or else Status::OK().
PassThroughAllStagesWithStatus(NodeDef * node,Result * result)279   Status PassThroughAllStagesWithStatus(NodeDef* node, Result* result) {
280     for (auto& stage : stages_) {
281       if (!stage->IsSupported(node)) {
282         continue;
283       }
284       const Status stage_status = stage->TrySimplify(node, result);
285       if (!stage_status.ok()) {
286         return stage_status;
287       } else if (break_predicate_(*result)) {
288         break;
289       }
290     }
291     return Status::OK();
292   }
293 
NumStages()294   std::size_t NumStages() { return stages_.size(); }
295 
StageNames()296   std::vector<string> StageNames() {
297     std::vector<string> names;
298     for (const auto& stage : stages_) {
299       names.push_back(stage->stage_name());
300     }
301     return names;
302   }
303 
304  private:
305   std::vector<std::unique_ptr<GraphOptimizerStage<Result>>> stages_;
306   std::function<bool(const Result&)> break_predicate_;
307 
308   TF_DISALLOW_COPY_AND_ASSIGN(GraphOptimizerStagePipeline);
309 };
310 
311 }  // end namespace grappler
312 }  // end namespace tensorflow
313 
314 #endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_STAGE_H_
315