1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ 18 19 #include "tensorflow/core/framework/attr_value.pb.h" 20 #include "tensorflow/core/framework/function.pb.h" 21 #include "tensorflow/core/framework/graph.pb.h" 22 #include "tensorflow/core/framework/node_def.pb.h" 23 #include "tensorflow/core/framework/tensor.pb.h" 24 #include "tensorflow/core/framework/tensor_shape.pb.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/grappler/mutable_graph_view.h" 27 #include "tensorflow/core/grappler/utils.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 namespace function_utils { 33 // This namespace contains utility functions for querying and modifying 34 // FunctionDefs. 35 36 // Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings 37 // have the format node_name:node_output:position (if they derive from nodes), 38 // or input_name (if they derive from an argument). 39 struct FunctionDefTensorDesc { 40 FunctionDefTensorDesc() = default; 41 42 FunctionDefTensorDesc(const string& node_name, const string& output, 43 int position); 44 45 // Parses node_name:node_output:position string into its components. 46 explicit FunctionDefTensorDesc(const string& input); 47 48 // TODO(rachelim): Add provisions to deal with special formats, like how 49 // GrapplerFunctionItem expands node output range if position is not defined 50 string full_str; 51 string node_name; 52 string node_output; 53 int position = -1; 54 }; 55 56 // Replaces all references to `from` tensor in func's nodes' inputs and retvals 57 // to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`. 58 void ReplaceReferences(const string& from, const string& to, FunctionDef* func); 59 60 // Adds a function output to the function def, ensuring that the output key 61 // is unique, and maps to output_tensor_name in the ret dict. 62 void AddFunctionOutputWithUniqueName(StringPiece prefix, 63 StringPiece output_tensor_name, 64 FunctionDef* fdef, DataType dtype); 65 66 // Adds an input to a FunctionDef. 67 OpDef_ArgDef* AddFunctionInput(const string& name, FunctionDef* fdef, 68 DataType dtype); 69 70 // Adds a node to a FunctionDef. 71 NodeDef* AddNode(StringPiece name, StringPiece op, 72 const std::vector<string>& inputs, 73 const std::vector<std::pair<string, AttrValue>>& attributes, 74 FunctionDef* fd); 75 76 // Checks whether the function contains a node with the given name. 77 bool ContainsFunctionNodeWithName(StringPiece name, 78 const FunctionDef& function); 79 80 // Checks whether the function contains a node with the given op. 81 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); 82 83 // Checks whether the function contains an output with the given name. 84 bool ContainsFunctionOutputWithName(StringPiece name, 85 const FunctionDef& function); 86 87 // Returns the index of the function input with the given name or -1 if the 88 // function node does not exist. 89 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function); 90 91 // Returns the index of the function output with the given name or -1 if the 92 // function node does not exist. 93 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function); 94 95 // Returns the index of the function node with the given name or -1 if the 96 // function node does not exist. 97 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); 98 99 // Returns the index of the function node with the given op or -1 if the 100 // function node does not exist. 101 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); 102 103 // Sets the function node name using the `prefix` as a prefix while guaranteeing 104 // the name is unique across the functions nodes. 105 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, 106 NodeDef* node); 107 108 // Checks if the function is stateful by checking the function graph for 109 // stateful ops. Because the "If" and "While" ops are conservatively marked as 110 // stateful, the check recurses into their graph to determine whether they are 111 // actually stateful. The `skip_assert` argument determines whether the "Assert" 112 // op should be treated as stateful or not. 113 bool IsFunctionStateful(const FunctionLibraryDefinition& library, 114 const FunctionDef& function_def, 115 bool skip_assert = false); 116 117 // Checks if the node is stateful. Because the "If" or "While" ops are 118 // conservatively marked as stateful, the check recurses into their graph to 119 // determine whether they are actually stateful. The `skip_assert` argument 120 // determines whether the "Assert" op should be treated as stateful or not. 121 bool IsNodeStateful(const FunctionLibraryDefinition& library, 122 const NodeDef& node, bool skip_assert = false); 123 124 } // end namespace function_utils 125 } // end namespace grappler 126 } // end namespace tensorflow 127 128 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ 129