1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ 18 19 #include <functional> 20 #include <memory> 21 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/lib/core/status.h" 24 25 namespace tensorflow { 26 27 class AttrSlice; 28 class Graph; 29 class GraphDef; 30 class NameAttrList; 31 class Node; 32 class NodeDef; 33 class OpDef; 34 35 // Debugging facility. Returns a debug string for a graph 36 // representing an instantiated function. 37 string DebugString(const Graph* g); 38 39 // Dump the contents of the "graph" to log files if the logging level is 40 // sufficiently high. 41 void DumpGraph(StringPiece label, const Graph* g); 42 43 // Convert the Graph of a function to a GraphDef. 44 // 45 // Handles renaming of nodes to avoid duplicate names which may 46 // be present after various rewriting operations. 47 void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty = false); 48 49 // Extracts function name and attributes from `call_def` 50 // `call_def` can be a native function call (where the op type is the function 51 // name) or a call through PartitionedCall/StatefulPartitionedCall. 52 Status NameAndAttrsFromFunctionCall(const NodeDef& call_def, 53 NameAttrList* function); 54 55 // A few hand-crafted optimization on the instantiated function body 56 // (a Graph*). 57 58 // Removes nodes that are 59 // 1. not stateful; and 60 // 2. not _Arg; and 61 // 3. not reachable from _Retval. 62 // 63 // This function is triggered by function inlining, unlike 'PruneFunctionBody' 64 // it doesn't preserve nodes that are reachable from control returns. Function 65 // inlining is responsible for connecting control return nodes with the nodes 66 // that have input control edges from the inlined function call node. 67 // 68 // Assuming that automatic control dependency tracking is correct, absence of 69 // outgoing control edge from the function call node means that no one needs to 70 // observe side-effect that might have been generated by the function (see 71 // documentation in common_runtime/function.cc for details). 72 // 73 // Returns true iff any node is removed from "g". 74 bool RemoveDeadNodes(Graph* g); 75 76 // Find a pattern: 77 // src -(in)-> node -(out)-> dst, where 78 // 1) node is an identity node; 79 // 2) in is the only incoming data edge; 80 // 3) out is the only outgoing data edge; 81 // 82 // Rewrites the above pattern with src->dst and relevant data 83 // dependencies updated. Repeat the process until no such pattern 84 // left. 85 bool RemoveIdentityNodes(Graph* g); 86 87 // Rewrites _ListToArray and _ArrayToList to a set of Identity nodes. 88 bool RemoveListArrayConverter(Graph* g); 89 90 // Extracts function name and attributes from `call_def` and invokes 91 // flr->Instantiate(name, attrs, handle). 92 // `call_def` can be a native function call (where the op type is the function 93 // name) or a call through PartitionedCall/StatefulPartitionedCall. 94 Status InstantiateFunctionCall(const NodeDef& call_def, 95 FunctionLibraryRuntime* flr, 96 FunctionLibraryRuntime::Handle* handle); 97 98 // Returns true iff `n` represents a function call. `n` can be a native 99 // function call (n.type_string() is the function name), 100 // a PartitionedCall/StatefulPartitionedCall, or a SymbolicGradient (which 101 // has been deprecated for a while). 102 bool IsFunctionCall(const FunctionLibraryDefinition& lib_def, const Node& n); 103 } // end namespace tensorflow 104 105 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_UTILS_H_ 106