• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "absl/types/optional.h"
23 #include "tensorflow/core/common_runtime/device.h"
24 #include "tensorflow/core/common_runtime/device_mgr.h"
25 #include "tensorflow/core/common_runtime/graph_optimizer.h"
26 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/graph/graph.h"
29 #include "tensorflow/core/protobuf/config.pb.h"
30 
31 namespace tensorflow {
32 
33 static constexpr const char* const kNoInlineAttr = "_noinline";
34 
35 // Get default customizable kernel creator if set
36 const CustomKernelCreator* GetDefaultCustomKernelCreator();
37 
38 // Registers a default customizable kernel creator for a function call.
39 //
40 // If c->CanCreateKernel returns false, we still fall back to an executor-based
41 // interpreter op kernel to execute a function. Else c->CreateKernel() can be
42 // used to create a kernel that will compile the function with XLA and run the
43 // resulting program.
44 //
45 // TODO(zhifengc/phawkins): b/32379046
46 void RegisterDefaultCustomKernelCreator(CustomKernelCreator* c);
47 
48 // Creates a FunctionLibraryRuntime, which instantiates functions
49 // defined in "lib_def" and executes functions on the "device".
50 // "device_mgr" must contain the "device". If not nullptr,
51 // "custom_kernel_creator" is consulted by the returned runtime to
52 // create kernels.
53 //
// The returned object does not take ownership of "device" or
// "lib_def".  The caller must ensure that "device" and "lib_def" outlive
// the returned object.
57 //
58 // The "parent" is a pointer to the ProcessFunctionLibraryRuntime object that
59 // typically owns the created FunctionLibraryRuntime object. The parent pointer
60 // is not owned by the FunctionLibraryRuntime object.
61 std::unique_ptr<FunctionLibraryRuntime> NewFunctionLibraryRuntime(
62     const DeviceMgr* device_mgr, Env* env, const ConfigProto* config,
63     Device* device, int graph_def_version,
64     const FunctionLibraryDefinition* lib_def, thread::ThreadPool* thread_pool,
65     const OptimizerOptions& optimizer_options,
66     const CustomKernelCreator* custom_kernel_creator,
67     const SessionMetadata* session_metadata,
68     ProcessFunctionLibraryRuntime* parent);
69 
70 // FunctionLibraryRuntime::GetFunctionBody returns a description of an
71 // instantiated function that is represented as a Graph with arg/ret
72 // nodes annotated.
73 struct FunctionBody {
74   FunctionDef fdef;
75   Graph* graph = nullptr;  // owned.
76   DataTypeVector arg_types;
77   DataTypeVector ret_types;
78   // arg_nodes[i] contains the i'th function input. In other words,
79   // GetNodeAttr(arg_nodes[i]->attrs(), "index") == i.
80   gtl::InlinedVector<Node*, 4> arg_nodes;
81   // ret_nodes[i] contains the i'th function output. In other words,
82   // GetNodeAttr(ret_nodes[i]->attrs(), "index") == i.
83   gtl::InlinedVector<Node*, 4> ret_nodes;
84   gtl::InlinedVector<Node*, 4> control_ret_nodes;
85 
FunctionBodyFunctionBody86   FunctionBody() {}
87   FunctionBody(const FunctionDef& f, DataTypeSlice arg_types,
88                DataTypeSlice ret_types, Graph* g);
89   ~FunctionBody();
90 };
91 
92 // Debugging facility.  Returns a debug string for a graph
93 // representing an instantiated function.
94 string DebugString(const Graph* g);
95 
// A few hand-crafted optimizations on the instantiated function body
// (a Graph*).
98 
99 // Removes nodes that are
100 //   1. not stateful; and
101 //   2. not _Arg; and
102 //   3. not reachable from _Retval.
103 //
104 // This function is triggered by function inlining, unlike 'PruneFunctionBody'
105 // it doesn't preserve nodes that are reachable from control returns. Function
106 // inlining is responsible for connecting control return nodes with the nodes
107 // that have input control edges from the inlined function call node.
108 //
109 // Assuming that automatic control dependency tracking is correct, absence of
110 // outgoing control edge from the function call node means that no one needs to
111 // observe side-effect that might have been generated by the function (see
112 // documentation in common_runtime/function.cc for details).
113 //
114 // Returns true iff any node is removed from "g".
115 bool RemoveDeadNodes(Graph* g);
116 
117 // Find a pattern:
118 //   src -(in)-> node -(out)-> dst, where
119 // 1) node is an identity node;
120 // 2) in is the only incoming data edge;
121 // 3) out is the only outgoing data edge;
122 //
// Rewrites the above pattern with src->dst and relevant data
// dependencies updated. Repeats the process until no such pattern
// is left.
126 bool RemoveIdentityNodes(Graph* g);
127 
128 // Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
129 bool RemoveListArrayConverter(Graph* g);
130 
131 // Dump the contents of the "graph" to log files if the logging level is
132 // sufficiently high.
133 void DumpGraph(StringPiece label, const Graph* g);
134 
135 // Applies graph rewrite optimization such as inlining, dead code
136 // removal, etc.
137 //
138 // **g is a graph constructed based on the runtime library 'lib'.
139 // OptimizeGraph mutates **g extensively and replaces '*g' with a
140 // complete copy. Therefore, the caller should not keep any references
141 // to nodes *g.
142 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g,
143                    const GraphOptimizer::Options& graph_optimizer_options);
144 void OptimizeGraph(FunctionLibraryRuntime* lib, std::unique_ptr<Graph>* g);
145 
146 // Convert the Graph of a function to a GraphDef.
147 //
148 // Handles renaming of nodes to avoid duplicate names which may
149 // be present after various rewriting operations.
150 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
151 
152 // Given a numerical function "f", returns another numerical function
153 // "g", such that if "f" takes N inputs and produces M outputs, "g"
154 // takes N + M inputs and produces N outputs. I.e., if
155 //   (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
156 // g is a function which is
157 //   (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
158 //                                     dL/dy1, dL/dy2, ..., dL/dy_M),
// where L is a scalar-valued function of (...x_i...).
//
// TODO(zhifengc): Ask a math expert to rephrase this comment.
162 std::unique_ptr<FunctionBody> SymbolicGradient(const FunctionBody& f);
163 
164 // Optionally override device assignment for nodes added to the graph for
165 // inlined functions:
166 // (1) Identity nodes added in place of function input arguments.
167 // (2) Identity nodes added in place of function return values.
168 // (3) Special NoOp nodes that enforce side-effects execution order.
169 // (4) All nodes inside function body specified in FunctionDef.
170 class InlinedFunctionBodyPlacer {
171  public:
172   virtual ~InlinedFunctionBodyPlacer() = default;
173 
174   virtual absl::optional<string> InputNodeDevice(int input_index) const = 0;
175   virtual absl::optional<string> OutputNodeDevice(int output_index) const = 0;
176   // Returns true if the added input/output identity nodes should be colocated
177   // with the corresponding input/output from the function body.
178   virtual bool ColocateInputOutputIdentities() const = 0;
179   virtual absl::optional<string> ControlNodeDevice() const = 0;
180   virtual absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const = 0;
181 
182   // Place input nodes on the same device as the corresponding caller input
183   // node. Do not specify any placement for all other nodes.
184   static std::unique_ptr<InlinedFunctionBodyPlacer> DefaultPlacer(
185       const Graph& graph, const Node& caller);
186 
187   // Place all nodes on the same device as caller node.
188   static std::unique_ptr<InlinedFunctionBodyPlacer> SingleDevicePlacer(
189       const Graph& graph, const Node& caller);
190 
191   // Place input nodes on the same device as the corresponding caller input
192   // node. Do not place output node. Place control nodes on the same device as
193   // caller node. For all function body nodes overrides job, replica and task
194   // parts of the device assignment to match function caller node.
195   static std::unique_ptr<InlinedFunctionBodyPlacer> MultiDevicePlacer(
196       const Graph& graph, const Node& caller);
197 
198   using Factory = std::function<std::unique_ptr<InlinedFunctionBodyPlacer>(
199       const Graph&, const Node&)>;
200 
201   struct Config {
202     string name;
203     Factory get;
204   };
205 
Default()206   static Config Default() { return {"default", DefaultPlacer}; }
SingleDevice()207   static Config SingleDevice() { return {"single_device", SingleDevicePlacer}; }
MultiDevice()208   static Config MultiDevice() { return {"multi_device", MultiDevicePlacer}; }
209 };
210 
211 struct InlineFunctionBodyOptions {
212   // All nodes that have incoming control edge *from* the function call node,
213   // will be forwarded to the "output control node". There are two options for
214   // choosing which nodes will have a control edge *to* the "output control
215   // node":
216   //   a) control returns            (`control_ret` field in FunctionDef)
217   //   b) data returns               (`ret` field in FunctionDef)
218   enum class OutputControlSource { kDataOutputs, kControlOutputs };
219 
220   // Keep a node in a graph with the same name as the function call node:
221   //
222   // a) DoNotKeep: Function call node is fully inlined, and there is no node in
223   //    a graph with the same name.
224   //
225   // b) Fetchable: Add an IdentityN node to the graph in place of the inlined
226   //    function call node. It will have a control edge from inlined
227   //    'output_control_node' and data edges from function output nodes.
228   //    The IdentityN node will be placed on the same device as the caller node.
229   //
230   //    This is mostly for compatibility with Tensorflow v1 and sessions.
231   //    When we prepare a graph for execution in
232   //    GraphExecutionState::MakeForBaseGraph we don't know what nodes will be
233   //    fetched, so we can't safely remove any of them. When graph executed as a
234   //    function it has 'Retval' nodes for all fetched tensors, and we can
235   //    safely inline function calls.
236   //
237   // c) Targetable: Add a NoOp node to the graph in place of the inlined
238   //    function call node. It will have a control edge from inline
239   //    'output_control_node' and no data edges. NoOp node will be placed on the
240   //    same device as the caller node. This will keep the inlined function call
241   //    node a valid 'session.run' target, and also will keep it a valid control
242   //    output node.
243   enum class KeepCallerNode { kDoNotKeep, kFetchable, kTargetable };
244 
245   // If 'true' function inlining is completely disabled. This allows to control
246   // function inlining for different types of function calls (see
247   // 'ExpandInlineFunctionsOptions' below).
248   bool disable_inlining = false;
249   // Ignore '_noinline' function attribute.
250   bool ignore_noinline = false;
251   // If 'true' function inlining will inline functions in implementation
252   // selection group. Normally those functions should not be inlined; they will
253   // be handled by Grappler.
254   bool inline_impl_selection_group_functions = false;
255   // Controls if we want to keep a node with the name as the function call node
256   // in a graph after function inlining.
257   KeepCallerNode keep_caller_node = KeepCallerNode::kDoNotKeep;
258   // For compatibility with Tensorflow v1 by default we will use data outputs.
259   // Control returns were added to Tensorflow v2 with automatic control
260   // dependencies tracking in Eager mode.
261   OutputControlSource output_control_src = OutputControlSource::kDataOutputs;
262   // Inlined function body placer decides what requested device assignments
263   // should be added to the nodes added to the graph. See documentation above
264   // for available strategies.
265   InlinedFunctionBodyPlacer::Config inlined_function_body_placer =
266       InlinedFunctionBodyPlacer::Default();
267   // If true, frame names in the function body will be
268   // made unique in the resulting graph (e.g. by prepending a unique prefix).
269   // NOTE(mrry): Only set this option to false when there is a single function
270   // call in the graph (e.g. when making a remote function call via
271   // ClusterFunctionLibraryRuntime). This option is provided because the graph
272   // partitioner generates frame names that must remain unmodified across all
273   // partitions of a multi-device function.
274   bool uniquify_frame_names = true;
275 
276   // A human-readable debug string for this options.
277   string DebugString() const;
278 };
279 
280 // Returns 'Status::OK()' iff the function '*fbody' can be inlined at 'node'
281 // based on the type signature of 'node' and 'fbody':
282 //
283 // (1) Caller node has the same number of inputs and outputs as the function.
284 // (2) Caller node inputs and outputs have the same data types as function
285 //     inputs and returns.
286 // (3) Validation rules defined in InlineFunctionBodyOptions.
287 //
288 // If function can't be safely inlined, returns error message with details why
289 // inlining is not possible or safe.
290 Status ValidateInlining(const Node* node, const FunctionBody* fbody,
291                         const InlineFunctionBodyOptions& options);
292 
// Given a "caller" in graph "g", which is a function call of a function
// to "fbody". Replaces the "caller" with fbody->graph and connects
// edges properly. Requested device assignments for the nodes added to the
// graph are decided by "options.inlined_function_body_placer" (this function
// has no separate "override_device" argument).
297 //
298 // Returns 'Status::OK()' if function was successfully inlined into the graph.
299 // If function inlining is not possible returns an error with a reason, and
300 // leaves the graph in unmodified state.
301 Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
302                           Node* caller, const FunctionBody* fbody,
303                           const InlineFunctionBodyOptions& options);
304 
305 // There are three types of function calls that could be invoked during
306 // *Tensorflow graph execution*:
307 //
308 // 1) Native function call (node.type_string() is the function name). These
309 //    functions are always executed on a single-device, which is the device of
310 //    the function call node.
311 //
312 // 2) Multi-device function calls (PartitionedCall or StatefulPartitionedCall
313 //    ops) can execute on multiple devices and accept DT_RESOURCE inputs that
314 //    belong to different devices. This type of functions was added in
315 //    Tensorflow 2.0 Eager mode, and it has control outputs to represent
316 //    side-effects that must always execute (see `control_ret` in FunctionDef).
317 //
318 // 3) SymbolicGradient has been deprecated for a while, but we still keep it and
319 //    use `native` options for inlining for compatibility.
320 //
321 // We need to have distinct inlining rules for compatibility with Tensorflow v1.
322 //
323 // There are few other places in Tensorflow that could execute functions:
324 //
325 // 1) common_runtime/eager/kernel_and_device.{h,cc} - executes "top level"
326 //    functions directly via function library runtime, without going through
327 //    the graph.
328 // 2) tf.data pipelines - also execute functions directly via function library
329 //    runtime with custom executors.
330 struct ExpandInlineFunctionsOptions {
ExpandInlineFunctionsOptionsExpandInlineFunctionsOptions331   ExpandInlineFunctionsOptions() : native_options(), multi_device_options() {
332     using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
333     multi_device_options.output_control_src = OutputControlSrc::kControlOutputs;
334   }
335 
336   InlineFunctionBodyOptions native_options;
337   InlineFunctionBodyOptions multi_device_options;
338 };
339 
340 // WARNING(ezhulenev): PLEASE DO NOT USE THIS FUNCTION. This is a temporary
341 // workaround that will be enabled only during the function inlining unification
342 // (b/126811947). Contact ezhulenev@ if you think you need it.
343 // TODO(ezhulenev): Delete this function.
344 bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
345                            const ExpandInlineFunctionsOptions& options);
346 
347 // For each node in "graph", if "lib" indicates that the node is a
348 // function call, inline the function body. Returns true if at least
349 // one node is inlined.
350 //
351 // This routine goes through "graph" nodes once and applies the
352 // inlining. The caller may decide to apply the inlining on "graph"
353 // multiple times by calling ExpandInlineFunctions a few times.
354 //
355 // Function calls that can't be safely inlined into the graph (ValidateInlining
356 // returns error), are ignored.
357 //
358 // TODO(ezhulenev): We do not FunctionLibraryRuntime for this. We need just the
359 // FunctionLibraryDefinition and FunctionDefToBodyHelper to implement this (see
360 // lower_function_call.cc).
ExpandInlineFunctions(FunctionLibraryRuntime * lib,Graph * graph)361 inline bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph) {
362   return ExpandInlineFunctions(lib, graph, ExpandInlineFunctionsOptions());
363 }
364 
365 // Extracts function name and attributes from `call_def`
366 // `call_def` can be a native function call (where the op type is the function
367 // name) or a call through PartitionedCall/StatefulPartitionedCall.
368 Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
369                                     NameAttrList* function);
370 
371 // Extracts function name and attributes from `call_def` and invokes
372 // flr->Instantiate(name, attrs, handle).
373 // `call_def` can be a native function call (where the op type is the function
374 // name) or a call through PartitionedCall/StatefulPartitionedCall.
375 Status InstantiateFunctionCall(const NodeDef& call_def,
376                                FunctionLibraryRuntime* flr,
377                                FunctionLibraryRuntime::Handle* handle);
378 
379 // Returns true iff `n` represents a function call. `n` can be a native
380 // function call (n.type_string() is the function name),
381 // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
382 // has been deprecated for a while).
383 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
384 
385 // Instantiates FunctionDef into a graph. Set *fbody to point to the
386 // FunctionBody that holds the instantiated FunctionDef.
387 Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs,
388                                const FunctionLibraryDefinition* lib_def,
389                                std::unique_ptr<FunctionBody>* fbody);
390 
391 // Instantiates FunctionDef into a graph. Set *fbody to point to the
392 // FunctionBody that holds the instantiated FunctionDef. Use custom function
393 // signature lookup, in case instantiated function is not in the 'lib_def'.
394 Status FunctionDefToBodyHelper(
395     const FunctionDef& fdef, const AttrSlice& attrs,
396     const FunctionLibraryDefinition* lib_def,
397     const std::function<Status(const string&, const OpDef**)>& get_func_sig,
398     std::unique_ptr<FunctionBody>* fbody);
399 }  // end namespace tensorflow
400 
401 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
402