1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
18
19 #include <functional>
20 #include <memory>
21
22 #include "absl/types/optional.h"
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/graph_optimizer.h"
26 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/protobuf/config.pb.h"
30
31 namespace tensorflow {
32
// Name of the "_noinline" function attribute. InlineFunctionBodyOptions below
// has an `ignore_noinline` switch that refers to this attribute.
// NOTE(review): presumably a function carrying this attribute is skipped by
// function inlining -- confirm against common_runtime/function.cc.
static constexpr const char* const kNoInlineAttr = "_noinline";
34
// Returns the default customizable kernel creator if one has been set (see
// RegisterDefaultCustomKernelCreator below).
// NOTE(review): presumably returns nullptr when none is registered -- confirm
// against the definition.
const CustomKernelCreator* GetDefaultCustomKernelCreator();
37
// Registers a default customizable kernel creator for a function call.
//
// If c->CanCreateKernel returns false, we still fall back to an executor-based
// interpreter op kernel to execute a function. Otherwise c->CreateKernel() can
// be used to create a kernel that will compile the function with XLA and run
// the resulting program.
//
// NOTE(review): ownership and required lifetime of `c` are not stated here --
// confirm against the definition before passing a short-lived object.
//
// TODO(zhifengc/phawkins): b/32379046
void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c);
47
// Creates a FunctionLibraryRuntime, which instantiates functions
// defined in "lib_def" and executes functions on the "device".
// "device_mgr" must contain the "device". If not nullptr,
// "custom_kernel_creator" is consulted by the returned runtime to
// create kernels.
//
// The returned object does not take ownership of "device" or
// "lib_def". The caller must ensure that "device" and "lib_def" outlive
// the returned object.
//
// The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that
// typically owns the created FunctionLibraryRuntime object. The parent pointer
// is not owned by the FunctionLibraryRuntime object.
std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
    const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
    Device* device, int graph_def_version,
    const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
    const OptimizerOptions& optimizer_options,
    const CustomKernelCreator* custom_kernel_creator,
    const SessionMetadata* session_metadata,
    ProcessFunctionLibraryRuntime* parent);
69
70 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
71 // instantiated function that is represented as a Graph with arg/ret
72 // nodes annotated.
73 struct FunctionBody {
74 FunctionDef fdef;
75 Graph* graph = nullptr; // owned.
76 DataTypeVector arg_types;
77 DataTypeVector ret_types;
78 // arg_nodes[i] contains the i'th function input. In other words,
79 // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i.
80 gtl::InlinedVector<Node*, 4> arg_nodes;
81 // ret_nodes[i] contains the i'th function output. In other words,
82 // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i.
83 gtl::InlinedVector<Node*, 4> ret_nodes;
84 gtl::InlinedVector<Node*, 4> control_ret_nodes;
85
FunctionBodyFunctionBody86 FunctionBody() {}
87 FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
88 DataTypeSlice ret_types, Graph* g);
89 ~FunctionBody();
90 };
91
// Debugging facility. Returns a debug string for a graph
// representing an instantiated function. The graph is taken as a pointer to
// const, so it is not modified through "g".
string DebugString(const Graph* g);
95
// A few hand-crafted optimizations on the instantiated function body
// (a Graph*).
98
// Removes nodes that are
//  1. not stateful; and
//  2. not _Arg; and
//  3. not reachable from _Retval.
//
// This function is triggered by function inlining. Unlike 'PruneFunctionBody',
// it doesn't preserve nodes that are reachable from control returns. Function
// inlining is responsible for connecting control return nodes with the nodes
// that have input control edges from the inlined function call node.
//
// Assuming that automatic control dependency tracking is correct, the absence
// of an outgoing control edge from the function call node means that no one
// needs to observe side effects that might have been generated by the function
// (see documentation in common_runtime/function.cc for details).
//
// Returns true iff any node is removed from "g".
bool RemoveDeadNodes(Graph* g);
116
// Finds the pattern:
//   src -(in)-> node -(out)-> dst, where
// 1) node is an identity node;
// 2) in is the only incoming data edge;
// 3) out is the only outgoing data edge;
//
// Rewrites the above pattern to src->dst with the relevant data
// dependencies updated. Repeats the process until no such pattern
// is left.
//
// NOTE(review): the bool result presumably means "the graph was changed", as
// for RemoveDeadNodes -- confirm against the definition.
bool RemoveIdentityNodes(Graph* g);
127
// Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
//
// NOTE(review): the bool result presumably means "at least one node was
// rewritten" -- confirm against the definition.
bool RemoveListArrayConverter(Graph* g);
130
// Dumps the contents of the graph "g" to log files if the logging level is
// sufficiently high.
// NOTE(review): "label" presumably tags the dump in the log output -- confirm
// against the definition.
void DumpGraph(StringPiece label, const Graph* g);
134
// Applies graph rewrite optimizations such as inlining, dead code
// removal, etc.
//
// **g is a graph constructed based on the runtime library 'lib'.
// OptimizeGraph mutates **g extensively and replaces '*g' with a
// complete copy. Therefore, the caller should not keep any references
// to nodes in **g.
//
// The first overload takes explicit GraphOptimizer options; the second takes
// none. NOTE(review): the second presumably runs with default options --
// confirm against the definition.
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
                   const GraphOptimizer::Options& graph_optimizer_options);
void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
145
// Converts the Graph of a function to a GraphDef, written to "*gdef".
//
// Handles renaming of nodes to avoid duplicate names which may
// be present after various rewriting operations.
//
// "pretty" defaults to false. NOTE(review): its exact effect on the produced
// GraphDef is not documented here -- confirm against the definition.
void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
151
// Given a numerical function "f", returns another numerical function
// "g", such that if "f" takes N inputs and produces M outputs, "g"
// takes N + M inputs and produces N outputs. I.e., if
//   (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
// g is a function which is
//   (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
//                                      dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-valued function of (...x_i...).
//
// TODO(zhifengc): Ask a math expert to review and restate this comment.
std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
163
164 // Optionally override device assignment for nodes added to the graph for
165 // inlined functions:
166 // (1) Identity nodes added in place of function input arguments.
167 // (2) Identity nodes added in place of function return values.
168 // (3) Special NoOp nodes that enforce side-effects execution order.
169 // (4) All nodes inside function body specified in FunctionDef.
170 class InlinedFunctionBodyPlacer {
171 public:
172 virtual ~InlinedFunctionBodyPlacer() = default;
173
174 virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
175 virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
176 // Returns true if the added input/output identity nodes should be colocated
177 // with the corresponding input/output from the function body.
178 virtual bool ColocateInputOutputIdentities() const = 0;
179 virtual absl::optional<string> ControlNodeDevice() const = 0;
180 virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;
181
182 // Place input nodes on the same device as the corresponding caller input
183 // node. Do not specify any placement for all other nodes.
184 static std::unique_ptr<InlinedFunctionBodyPlacer> DefaultPlacer(
185 const Graph& graph, const Node& caller);
186
187 // Place all nodes on the same device as caller node.
188 static std::unique_ptr<InlinedFunctionBodyPlacer> SingleDevicePlacer(
189 const Graph& graph, const Node& caller);
190
191 // Place input nodes on the same device as the corresponding caller input
192 // node. Do not place output node. Place control nodes on the same device as
193 // caller node. For all function body nodes overrides job, replica and task
194 // parts of the device assignment to match function caller node.
195 static std::unique_ptr<InlinedFunctionBodyPlacer> MultiDevicePlacer(
196 const Graph& graph, const Node& caller);
197
198 using Factory = std::function<std::unique_ptr<InlinedFunctionBodyPlacer>(
199 const Graph&, const Node&)>;
200
201 struct Config {
202 string name;
203 Factory get;
204 };
205
Default()206 static Config Default() { return {"default", DefaultPlacer}; }
SingleDevice()207 static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; }
MultiDevice()208 static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; }
209 };
210
// Options that control how a single function call is inlined. Consumed by
// ValidateInlining() and InlineFunctionBody() declared below.
struct InlineFunctionBodyOptions {
  // All nodes that have an incoming control edge *from* the function call
  // node will be forwarded to the "output control node". There are two options
  // for choosing which nodes will have a control edge *to* the "output control
  // node":
  //   a) control returns (`control_ret` field in FunctionDef)
  //   b) data returns (`ret` field in FunctionDef)
  enum class OutputControlSource { kDataOutputs, kControlOutputs };

  // Keep a node in a graph with the same name as the function call node:
  //
  // a) DoNotKeep: Function call node is fully inlined, and there is no node in
  //    a graph with the same name.
  //
  // b) Fetchable: Add an IdentityN node to the graph in place of the inlined
  //    function call node. It will have a control edge from inlined
  //    'output_control_node' and data edges from function output nodes.
  //    The IdentityN node will be placed on the same device as the caller node.
  //
  //    This is mostly for compatibility with Tensorflow v1 and sessions.
  //    When we prepare a graph for execution in
  //    GraphExecutionState::MakeForBaseGraph we don't know what nodes will be
  //    fetched, so we can't safely remove any of them. When graph executed as a
  //    function it has 'Retval' nodes for all fetched tensors, and we can
  //    safely inline function calls.
  //
  // c) Targetable: Add a NoOp node to the graph in place of the inlined
  //    function call node. It will have a control edge from inline
  //    'output_control_node' and no data edges. NoOp node will be placed on the
  //    same device as the caller node. This will keep the inlined function call
  //    node a valid 'session.run' target, and also will keep it a valid control
  //    output node.
  enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable };

  // If 'true', function inlining is completely disabled. This allows
  // controlling function inlining for different types of function calls (see
  // 'ExpandInlineFunctionsOptions' below).
  bool disable_inlining = false;
  // Ignore '_noinline' function attribute.
  bool ignore_noinline = false;
  // If 'true', function inlining will inline functions in implementation
  // selection group. Normally those functions should not be inlined; they will
  // be handled by Grappler.
  bool inline_impl_selection_group_functions = false;
  // Controls if we want to keep a node with the name as the function call node
  // in a graph after function inlining.
  KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
  // For compatibility with Tensorflow v1 by default we will use data outputs.
  // Control returns were added to Tensorflow v2 with automatic control
  // dependencies tracking in Eager mode.
  OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
  // Inlined function body placer decides what requested device assignments
  // should be added to the nodes added to the graph. See documentation above
  // for available strategies.
  InlinedFunctionBodyPlacer::Config inlined_function_body_placer =
      InlinedFunctionBodyPlacer::Default();
  // If true, frame names in the function body will be
  // made unique in the resulting graph (e.g. by prepending a unique prefix).
  // NOTE(mrry): Only set this option to false when there is a single function
  // call in the graph (e.g. when making a remote function call via
  // ClusterFunctionLibraryRuntime). This option is provided because the graph
  // partitioner generates frame names that must remain unmodified across all
  // partitions of a multi-device function.
  bool uniquify_frame_names = true;

  // A human-readable debug string for these options.
  string DebugString() const;
};
279
// Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
// based on the type signature of 'node' and 'fbody':
//
// (1) Caller node has the same number of inputs and outputs as the function.
// (2) Caller node inputs and outputs have the same data types as function
//     inputs and returns.
// (3) Validation rules defined in InlineFunctionBodyOptions.
//
// If the function can't be safely inlined, returns an error status whose
// message explains why inlining is not possible or safe.
Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options);
292
// Given a "caller" node in graph "g" that is a call of the function described
// by "fbody", replaces the "caller" with fbody->graph and connects edges
// properly.
//
// There is no "override_device" parameter in this signature. The requested
// device assignment for the nodes added to the graph is decided by the placer
// configured in "options.inlined_function_body_placer" (see
// InlinedFunctionBodyPlacer above).
// NOTE(review): the role of "flib_def" is not documented here -- confirm
// against the definition.
//
// Returns 'Status::OK()' if the function was successfully inlined into the
// graph. If function inlining is not possible, returns an error with a reason
// and leaves the graph unmodified.
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options);
304
305 // There are three types of function calls that could be invoked during
306 // *Tensorflow graph execution*:
307 //
308 // 1) Native function call (node.type_string() is the function name). These
309 // functions are always executed on a single-device, which is the device of
310 // the function call node.
311 //
312 // 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
313 // ops) can execute on multiple devices and accept DT_RESOURCE inputs that
314 // belong to different devices. This type of functions was added in
315 // Tensorflow 2.0 Eager mode, and it has control outputs to represent
316 // side-effects that must always execute (see `control_ret` in FunctionDef).
317 //
318 // 3) SymbolicGradient has been deprecated for a while, but we still keep it and
319 // use `native` options for inlining for compatibility.
320 //
321 // We need to have distinct inlining rules for compatibility with Tensorflow v1.
322 //
323 // There are few other places in Tensorflow that could execute functions:
324 //
325 // 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
326 // functions directly via function library runtime, without going through
327 // the graph.
328 // 2) tf.data pipelines - also execute functions directly via function library
329 // runtime with custom executors.
330 struct ExpandInlineFunctionsOptions {
ExpandInlineFunctionsOptionsExpandInlineFunctionsOptions331 ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
332 using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
333 multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
334 }
335
336 InlineFunctionBodyOptions native_options;
337 InlineFunctionBodyOptions multi_device_options;
338 };
339
// Overload of ExpandInlineFunctions that takes explicit inlining options for
// each type of function call (see ExpandInlineFunctionsOptions above).
//
// WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
// workaround that will be enabled only during the function inlining unification
// (b/126811947). Contact ezhulenev@ if you think you need it.
// TODO(ezhulenev): Delete this function.
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options);
346
347 // For each node in "graph", if "lib" indicates that the node is a
348 // function call, inline the function body. Returns true if at least
349 // one node is inlined.
350 //
351 // This routine goes through "graph" nodes once and applies the
352 // inlining. The caller may decide to apply the inlining on "graph"
353 // multiple times by calling ExpandInlineFunctions a few times.
354 //
355 // Function calls that can't be safely inlined into the graph (ValidateInlining
356 // returns error), are ignored.
357 //
358 // TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the
359 // FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see
360 // lower_function_call.cc).
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph)361 inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
362 return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
363 }
364
// Extracts the function name and attributes from `call_def` into `*function`.
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
// NOTE(review): the contents of `*function` on a non-OK Status are not
// specified here -- confirm against the definition.
Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
                                    NameAttrList* function);
370
// Extracts the function name and attributes from `call_def` and invokes
// flr->Instantiate(name, attrs, handle).
// `call_def` can be a native function call (where the op type is the function
// name) or a call through PartitionedCall/StatefulPartitionedCall.
Status InstantiateFunctionCall(const NodeDef& call_def,
                               FunctionLibraryRuntime* flr,
                               FunctionLibraryRuntime::Handle* handle);
378
// Returns true iff `n` represents a function call. `n` can be one of:
//  - a native function call (n.type_string() is the function name);
//  - a PartitionedCall/StatefulPartitionedCall;
//  - a SymbolicGradient (which has been deprecated for a while).
bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
384
// Instantiates a FunctionDef into a graph. Sets *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef.
// NOTE(review): the state of *fbody on a non-OK Status is not specified here
// -- confirm against the definition.
Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
                               const FunctionLibraryDefinition* lib_def,
                               std::unique_ptr<FunctionBody>* fbody);
390
// Instantiates a FunctionDef into a graph. Sets *fbody to point to the
// FunctionBody that holds the instantiated FunctionDef. Uses the custom
// function signature lookup `get_func_sig`, in case the instantiated function
// is not in 'lib_def'.
// NOTE(review): `get_func_sig` presumably maps an op/function name to its
// OpDef via the out-parameter -- confirm against the definition.
Status FunctionDefToBodyHelper(
    const FunctionDef& fdef, const AttrSlice& attrs,
    const FunctionLibraryDefinition* lib_def,
    const std::function<Status(const string&, const OpDef**)>& get_func_sig,
    std::unique_ptr<FunctionBody>* fbody);
399 } // end namespace tensorflow
400
401 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
402