• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "tensorflow/core/framework/function.h"
23 #include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 
27 class AttrSlice;
28 class Graph;
29 class GraphDef;
30 class NameAttrList;
31 class Node;
32 class NodeDef;
33 class OpDef;
34 
35 // Debugging facility.  Returns a debug string for a graph
36 // representing an instantiated function.
37 string DebugString(const Graph* g);
38 
39 // Dump the contents of the "graph" to log files if the logging level is
40 // sufficiently high.
41 void DumpGraph(StringPiece label, const Graph* g);
42 
43 // Convert the Graph of a function to a GraphDef.
44 //
45 // Handles renaming of nodes to avoid duplicate names which may
46 // be present after various rewriting operations.
47 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false);
48 
49 // Extracts function name and attributes from `call_def`
50 // `call_def` can be a native function call (where the op type is the function
51 // name) or a call through PartitionedCall/StatefulPartitionedCall.
52 Status NameAndAttrsFromFunctionCall(const NodeDef& call_def,
53                                     NameAttrList* function);
54 
55 // A few hand-crafted optimization on the instantiated function body
56 // (a Graph*).
57 
58 // Removes nodes that are
59 //   1. not stateful; and
60 //   2. not _Arg; and
61 //   3. not reachable from _Retval.
62 //
63 // This function is triggered by function inlining, unlike 'PruneFunctionBody'
64 // it doesn't preserve nodes that are reachable from control returns. Function
65 // inlining is responsible for connecting control return nodes with the nodes
66 // that have input control edges from the inlined function call node.
67 //
68 // Assuming that automatic control dependency tracking is correct, absence of
69 // outgoing control edge from the function call node means that no one needs to
70 // observe side-effect that might have been generated by the function (see
71 // documentation in common_runtime/function.cc for details).
72 //
73 // Returns true iff any node is removed from "g".
74 bool RemoveDeadNodes(Graph* g);
75 
76 // Find a pattern:
77 //   src -(in)-> node -(out)-> dst, where
78 // 1) node is an identity node;
79 // 2) in is the only incoming data edge;
80 // 3) out is the only outgoing data edge;
81 //
82 // Rewrites the above pattern with src->dst and relevant data
83 // dependencies updated. Repeat the process until no such pattern
84 // left.
85 bool RemoveIdentityNodes(Graph* g);
86 
87 // Rewrites _ListToArray and _ArrayToList to a set of Identity nodes.
88 bool RemoveListArrayConverter(Graph* g);
89 
90 // Extracts function name and attributes from `call_def` and invokes
91 // flr->Instantiate(name, attrs, handle).
92 // `call_def` can be a native function call (where the op type is the function
93 // name) or a call through PartitionedCall/StatefulPartitionedCall.
94 Status InstantiateFunctionCall(const NodeDef& call_def,
95                                FunctionLibraryRuntime* flr,
96                                FunctionLibraryRuntime::Handle* handle);
97 
98 // Returns true iff `n` represents a function call. `n` can be a native
99 // function call (n.type_string() is the function name),
100 // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which
101 // has been deprecated for a while).
102 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n);
103 }  // end namespace tensorflow
104 
105 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_
106