• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
18 
// Operations calling functions are becoming ubiquitous in TF 2.0.
// Examples include PartitionedCallOp, functional If/While, and Dataset ops.
// Such operations might require deep inspection - looking at the body of the
// called function - to place them and surrounding ops correctly.

// This file contains some utilities that help Placer correctly place such
// ops, including:
// - PlacerInspectionRequiredOpChecker: A simple class with a single
//   IsPlacerInspectionRequired method.
// - IsolatePlacerInspectionRequiredOps: This function adds Identity ops for
//   each input/output of ops requiring placer inspection. It greatly
//   simplifies the implementation of placing such ops.
31 
32 #include <vector>
33 
34 #include "absl/types/optional.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/graph/graph.h"
37 #include "tensorflow/core/lib/core/status.h"
38 
39 namespace tensorflow {
40 
41 // PlacerInspectionRequiredOpChecker allows one to check if Placer needs to
42 // look deeply into the op to place ops consuming the outputs correctly.
43 //
44 // It is a class instead of a standalone method because checking whether
45 // a function returns a resource takes non-trivial time and we cache the
46 // results.
47 class PlacerInspectionRequiredOpChecker {
48  public:
49   // Calls the constructor below with flib_def = graph->flib_def().
50   explicit PlacerInspectionRequiredOpChecker(const Graph* graph);
51   // Constructs a PlacerInspectionRequiredOpChecker for nodes of `graph`.
52   // The functions referenced by nodes in `graph` will be looked up in
53   // `flib_def`
54   PlacerInspectionRequiredOpChecker(const Graph* graph,
55                                     const FunctionLibraryDefinition* flib_def);
56 
57   // If `node` is considered a deep op, sets `*is_deep` to true and returns
58   // Status::OK(). If an error occurs, returns that error, and the value of
59   // `*is_deep` is undefined.
60   // Currently, an op is considered deep, if it is a calling a function
61   // returning a resource. This definition is driven by Placer's need to
62   // look inside the op.
63   // REQUIRES: `node` is part of `graph` passed into constructor.
64   Status IsPlacerInspectionRequired(const Node& node, bool* is_deep);
65 
66  private:
67   const Graph& graph_;
68   const FunctionLibraryDefinition& flib_def_;
69   // Indexed by the node id.
70   // If cache_[node_id] is empty, the deepness of the node with id `node_id` has
71   // not been computed yet. Else, it contains the value already computed.
72   std::vector<absl::optional<bool>> cache_;
73 };
74 
75 // Extracts `fdef` and `func` from `flib_def` for the function identified
76 // in "f" attribute of `node`.
77 Status GetFunctionDefAndAttrs(const FunctionLibraryDefinition& flib_def,
78                               const Node& node, const FunctionDef** fdef,
79                               NameAttrList* func);
80 
81 // The "call" stack of functions.
82 // Useful for better error messages as well as for detecting recursion.
83 // Stores references to graph nodes. These references must outlive this.
84 class FunctionStack {
85  public:
86   explicit FunctionStack(const string& function_name);
87 
88   // `node_in_current_function` must outlive this.
89   FunctionStack Push(const Node* node_in_current_function,
90                      const string& new_current_function) const;
91 
92   // Returns true iff this stack already includes `function_name`.
93   bool HasFunction(const string& function_name) const;
94 
current_function_name()95   const string& current_function_name() const { return current_function_name_; }
96 
97   // Format's this suitable for error interpolation that retrieves
98   // Python files and line numbers.
99   string FormatForError() const;
100 
101  private:
102   struct Frame {
FrameFrame103     Frame(const string& function, const Node* node)
104         : function_name(function), node(node) {}
105 
106     string function_name;
107     const Node* node;
108   };
109 
110   // The function at the top of the stack. In other words, the function
111   // that is currently being inspected for placement.
112   string current_function_name_;
113 
114   // The stack of frames that got the placement to the current_function_name_.
115   // frames_[0].function_name is the top function that Placer was constructed
116   // with. frames_[0].function_name can be empty if placer was constructed with
117   // a nameless graph, not a function.  frames_[0].node_name is a name of a node
118   // in frames_[0].function_name that required deep inspection (e.g. a
119   // PartitionedCallOp). The function that this node invoked is
120   // frames_[1].function_name, if frames_.size() > 1.  Else, the function that
121   // this node invoked is current_function_name_.
122   std::vector<Frame> frames_;
123 };
124 
125 // Adds Identities for each input and output of function-calling ops in `graph`
126 //
127 // For example, the following graph calling a function on inputs `a` and `b`
128 // and producing output `y` will be rewritten to include identities on all
129 // edges:
130 //
131 //      a             b
132 //      |             |
133 //      v             v
134 //    f (PartitionedCallOp)
135 //         |
136 //         v
137 //         y
138 //
139 // is transformed to
140 //
141 //      a             b
142 //      |             |
143 //  a_f (Identity)   b_f (Identity)
144 //      |             |
145 //      v             v
146 //    f (PartitionedCallOp)
147 //         |
148 //      f_y (Identity)
149 //         |
150 //         v
151 //         y
152 //
153 Status IsolatePlacerInspectionRequiredOps(
154     const FunctionLibraryDefinition& flib_def, Graph* graph);
155 
156 }  // namespace tensorflow
157 
158 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_INSPECTION_REQUIRED_OPS_UTILS_H_
159