• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/graph_optimizer.h"
25 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/graph/graph.h"
28 #include "tensorflow/core/protobuf/config.pb.h"
29 
30 namespace tensorflow {
31 
// Name of the function attribute that marks a function as not to be inlined.
// Function inlining honors it unless
// InlineFunctionBodyOptions::ignore_noinline is set.
// (`constexpr` implies `const`, so this namespace-scope constant already has
// internal linkage; an explicit `static` is redundant.)
constexpr const char* const kNoInlineAttr = "_noinline";
33 
34 // Registers a default customizable kernel creator for a function call.
35 //
36 // If 'cb()' returns a non-OK, we still fall back to an executor-based
37 // interpreter op kernel to execute a function. If 'cb()' returns OK,
38 // takes ownership of the returned OpKernel.
39 //
40 // TODO(zhifengc/phawkins): b/32379046
41 void RegisterDefaultCustomKernelCreator(CustomKernelCreator cb);
42 
43 // Creates a FunctionLibraryRuntime, which instantiates functions
44 // defined in "lib_def" and executes functions on the "device".
45 // "device_mgr" must contain the "device". If not nullptr,
46 // "custom_kernel_creator" is consulted by the returned runtime to
47 // create kernels.
48 //
49 // The returned object does not take ownerships of "device" or
50 // "lib_def".  The caller must ensure "device" and "lib_def" outlives
51 // the returned object.
52 //
53 // The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that
54 // typically owns the created FunctionLibraryRuntime object. The parent pointer
55 // is not owned by the FunctionLibraryRuntime object.
56 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
57     const DeviceMgr* device_mgr, Env* env, Device* device,
58     int graph_def_version, const FunctionLibraryDefinition* lib_def,
59     thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
60     CustomKernelCreator custom_kernel_creator,
61     ProcessFunctionLibraryRuntime* parent);
62 
63 // Same as above except that the returned runtime consults with the
64 // global default custom kernel creator registered by
65 // RegisterDefaultCustomKernelCreator.
66 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
67     const DeviceMgr* device_mgr, Env* env, Device* device,
68     int graph_def_version, const FunctionLibraryDefinition* lib_def,
69     thread::ThreadPool* thread_pool, const OptimizerOptions& optimizer_options,
70     ProcessFunctionLibraryRuntime* parent);
71 
72 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
73 // instantiated function that is represented as a Graph with arg/ret
74 // nodes annotated.
75 struct FunctionBody {
76   FunctionDef fdef;
77   Graph* graph = nullptr;  // owned.
78   DataTypeVector arg_types;
79   DataTypeVector ret_types;
80   gtl::InlinedVector<Node*, 4> arg_nodes;
81   gtl::InlinedVector<Node*, 4> ret_nodes;
82   gtl::InlinedVector<Node*, 4> control_ret_nodes;
83 
FunctionBodyFunctionBody84   FunctionBody() {}
85   FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
86                DataTypeSlice ret_types, Graph* g);
87   ~FunctionBody();
88 };
89 
90 // Debugging facility.  Returns a debug string for a graph
91 // representing an instantiated function.
92 string DebugString(const Graph* instantiated_func_graph);
93 
94 // A few hand-crafted optimization on the instantiated function body
95 // (a Graph*).
96 
97 // Removes nodes that are
98 //   1. not stateful; and
99 //   2. not _Arg; and
100 //   3. not reachable from _Retval.
101 //
102 // This function is triggered by function inlining, unlike 'PruneFunctionBody'
103 // it doesn't preserve nodes that are reachable from control returns. Function
104 // inlining is responsible for connecting control return nodes with the nodes
105 // that have input control edges from the inlined function call node.
106 //
107 // Assuming that automatic control dependency tracking is correct, absence of
108 // outgoing control edge from the function call node means that no one needs to
109 // observe side-effect that might have been generated by the function (see
110 // documentation in common_runtime/function.cc for details).
111 //
112 // Returns true iff any node is removed from "g".
113 bool RemoveDeadNodes(Graph* g);
114 
115 // Find a pattern:
116 //   src -(in)-> node -(out)-> dst, where
117 // 1) node is an identity node;
118 // 2) in is the only incoming data edge;
119 // 3) out is the only outgoing data edge;
120 //
121 // Rewrites the above pattern with src->dst and relevant data
122 // dependencies updated. Repeat the process until no such pattern
123 // left.
124 bool RemoveIdentityNodes(Graph* g);
125 
126 // Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
127 bool RemoveListArrayConverter(Graph* g);
128 
129 // Dump the contents of the "graph" to log files if the logging level is
130 // sufficiently high.
131 void DumpGraph(StringPiece label, const Graph* g);
132 
133 // Applies graph rewrite optimization such as inlining, dead code
134 // removal, etc.
135 //
136 // **g is a graph constructed based on the runtime library 'lib'.
137 // OptimizeGraph mutates **g extensively and replaces '*g' with a
138 // complete copy. Therefore, the caller should not keep any references
139 // to nodes *g.
140 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
141                    const GraphOptimizer::Options& graph_optimizer_options);
142 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
143 
144 // Convert the Graph of a function to a GraphDef.
145 //
146 // Handles renaming of nodes to avoid duplicate names which may
147 // be present after various rewriting operations.
148 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
149 
150 // Given a numerical function "f", returns another numerical function
151 // "g", such that if "f" takes N inputs and produces M outputs, "g"
152 // takes N + M inputs and produces N outputs. I.e., if
153 //   (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
154 // g is a function which is
155 //   (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
156 //                                     dL/dy1, dL/dy2, ..., dL/dy_M),
157 // where L is a scalar-value function of (...x_i...).
158 //
159 // TODO(zhifengc): Asks math expert to say the comment again.
160 FunctionBody* SymbolicGradient(const FunctionBody& f);
161 
162 struct InlineFunctionBodyOptions {
163   // All nodes that have incoming control edge *from* the function call node,
164   // will be forwarded to the "output control node". There are two options for
165   // choosing which nodes will have a control edge *to* the "output control
166   // node":
167   //   a) control returns            (`control_ret` field in FunctionDef)
168   //   b) data returns               (`ret` field in FunctionDef)
169   enum class OutputControlSource { kDataOutputs, kControlOutputs };
170 
171   // Ignore '_noinline' function attribute.
172   bool ignore_noinline = false;
173   // If 'true' function inlining will override explicitly specified devices
174   // inside function body with the caller node device.
175   bool override_device = false;
176   // For compatibility with Tensorflow v1 by default we will use data outputs.
177   // Control returns were added to Tensorflow v2 with automatic control
178   // dependencies tracking in Eager mode.
179   OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
180 
181   // A human-readable debug string for this options.
182   string DebugString() const;
183 };
184 
185 // Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
186 // based on the type signature of 'node' and 'fbody':
187 //
188 // (1) Caller node has the same number of inputs and outputs as the function.
189 // (2) Caller node inputs and outputs have the same data types as function
190 //     inputs and returns.
191 // (3) Validation rules defined in InlineFunctionBodyOptions.
192 //
193 // If function can't be safely inlined, returns error message with details why
194 // inlining is not possible or safe.
195 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
196                         const InlineFunctionBodyOptions& options);
197 
198 // Given a "caller" in graph "g", which is a function call of a function
199 // to "fbody". Replaces the "caller" with fbody->graph and connects
200 // edges properly. "override_device" specifies whether inlining should replace
201 // explicitly specified devices inside fbody with the callee's device.
202 //
203 // Returns 'Status::OK()' if function was successfully inlined into the graph.
204 // If function inlining is not possible returns a error with a reason, and
205 // leaves the graph in unmodified state.
206 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
207                           Node* caller, const FunctionBody* fbody,
208                           const InlineFunctionBodyOptions& options);
209 
210 // There are three types of function calls that could be invoked during
211 // *Tensorflow graph execution*:
212 //
213 // 1) Native function call (node.type_string() is the function name). These
214 //    functions are always executed on a single-device, which is the device of
215 //    the function call node.
216 //
217 // 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
218 //    ops) can execute on multiple devices and accept DT_RESOURCE inputs that
219 //    belong to different devices. This type of functions was added in
220 //    Tensorflow 2.0 Eager mode, and it has control outputs to represent
221 //    side-effects that must always execute (see `control_ret` in FunctionDef).
222 //
223 // 3) SymbolicGradient has been deprecated for a while, but we still keep it and
224 //    use `native` options for inlining for compatibility.
225 //
226 // We need to have distinct inlining rules for compatibility with Tensorflow v1.
227 //
228 // There are few other places in Tensorflow that could execute functions:
229 //
230 // 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
231 //    functions directly via function library runtime, without going through
232 //    the graph.
233 // 2) tf.data pipelines - also execute functions directly via function library
234 //    runtime with custom executors.
235 struct ExpandInlineFunctionsOptions {
ExpandInlineFunctionsOptionsExpandInlineFunctionsOptions236   ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
237     using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
238     multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
239   }
240 
241   InlineFunctionBodyOptions native_options;
242   InlineFunctionBodyOptions multi_device_options;
243 };
244 
245 // WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
246 // workaround that will be enabled only during the function inlining unification
247 // (b/126811947). Contact ezhulenev@ if you think you need it.
248 // TODO(ezhulenev): Delete this function.
249 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
250                            const ExpandInlineFunctionsOptions& options);
251 
252 // For each node in "graph", if "lib" indicates that the node is a
253 // function call, inline the function body. Returns true if at least
254 // one node is inlined.
255 //
256 // This routine goes through "graph" nodes once and applies the
257 // inlining. The caller may decide to apply the inlining on "graph"
258 // multiple times by calling ExpandInlineFunctions a few times.
259 //
260 // Function calls that can't be safely inlined into the graph (ValidateInlining
261 // returns error), are ignored.
262 //
263 // TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the
264 // FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see
265 // lower_function_call.cc).
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph)266 inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
267   return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
268 }
269 
270 // Extracts function name and attributes from `call_def` and invokes
271 // flr->Instantiate(name, attrs, handle).
272 // `call_def` can be a native function call (where the op type is the function
273 // name) or a call through PartitionedCall/StatefulPartitionedCall.
274 Status InstantiateFunctionCall(const NodeDef& call_def,
275                                FunctionLibraryRuntime& flr,
276                                FunctionLibraryRuntime::Handle* handle);
277 
278 // Returns true iff `n` represents a function call. `n` can be a native
279 // function call (n.type_string() is the function name),
280 // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
281 // has been deprecated for a while).
282 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
283 
284 // Instantiates FunctionDef into a graph. Set *fbody to point to the
285 // FunctionBody that holds the instantiated FunctionDef.
286 Status FunctionDefToBodyHelper(
287     const FunctionDef& fdef, const AttrSlice& attrs,
288     const FunctionLibraryDefinition* const lib_def,
289     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
290     FunctionBody** fbody);
291 }  // end namespace tensorflow
292 
293 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
294