1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ 18 19 #include "tensorflow/core/framework/attr_value.pb.h" 20 #include "tensorflow/core/framework/function.pb.h" 21 #include "tensorflow/core/framework/graph.pb.h" 22 #include "tensorflow/core/framework/node_def.pb.h" 23 #include "tensorflow/core/framework/tensor.pb.h" 24 #include "tensorflow/core/framework/tensor_shape.pb.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/grappler/mutable_graph_view.h" 27 #include "tensorflow/core/grappler/utils.h" 28 #include "tensorflow/core/lib/core/errors.h" 29 30 namespace tensorflow { 31 namespace grappler { 32 namespace function_utils { 33 // This namespace contains utility functions for querying and modifying 34 // FunctionDefs. 35 36 // Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings 37 // have the format node_name:node_output:position (if they derive from nodes), 38 // or input_name (if they derive from an argument). 39 struct FunctionDefTensorDesc { 40 FunctionDefTensorDesc() = default; 41 42 FunctionDefTensorDesc(const string& node_name, const string& output, 43 int position); 44 45 // Parses node_name:node_output:position string into its components. 46 explicit FunctionDefTensorDesc(const string& input); 47 48 // TODO(rachelim): Add provisions to deal with special formats, like how 49 // GrapplerFunctionItem expands node output range if position is not defined 50 string full_str; 51 string node_name; 52 string node_output; 53 int position = -1; 54 }; 55 56 // Replaces all references to `from` tensor in func's nodes' inputs and retvals 57 // to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`. 58 void ReplaceReferences(const string& from, const string& to, FunctionDef* func); 59 60 // Adds a function output to the function def, ensuring that the output key 61 // is unique, and maps to output_tensor_name in the ret dict. 62 void AddFunctionOutputWithUniqueName(StringPiece prefix, 63 StringPiece output_tensor_name, 64 FunctionDef* function, DataType dt); 65 66 // Adds a node to a FunctionDef. 67 NodeDef* AddNode(StringPiece name, StringPiece op, 68 const std::vector<string>& inputs, 69 const std::vector<std::pair<string, AttrValue>>& attributes, 70 FunctionDef* fd); 71 72 // Checks whether the function contains a node with the given name. 73 bool ContainsFunctionNodeWithName(StringPiece name, 74 const FunctionDef& function); 75 76 // Checks whether the function contains a node with the given op. 77 bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); 78 79 // Checks whether the function contains an output with the given name. 80 bool ContainsFunctionOutputWithName(StringPiece name, 81 const FunctionDef& function); 82 83 // Returns the index of the function input with the given name or -1 if the 84 // function node does not exist. 85 int FindFunctionInputWithName(StringPiece name, const FunctionDef& function); 86 87 // Returns the index of the function output with the given name or -1 if the 88 // function node does not exist. 89 int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function); 90 91 // Returns the index of the function node with the given name or -1 if the 92 // function node does not exist. 93 int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); 94 95 // Returns the index of the function node with the given op or -1 if the 96 // function node does not exist. 97 int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); 98 99 // Sets the function node name using the `prefix` as a prefix while guaranteeing 100 // the name is unique across the functions nodes. 101 void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, 102 NodeDef* node); 103 104 // Checks if the function is stateful by checking the function graph for 105 // stateful ops. Because the "If" and "While" ops are conservatively marked as 106 // stateful, the check recurses into their graph to determine whether they are 107 // actually stateful. The `skip_assert` argument determines whether the "Assert" 108 // op should be treated as stateful or not. 109 bool IsFunctionStateful(const FunctionLibraryDefinition& library, 110 const FunctionDef& function_def, 111 bool skip_assert = false); 112 113 // Checks if the node is stateful. Because the "If" or "While" ops are 114 // conservatively marked as stateful, the check recurses into their graph to 115 // determine whether they are actually stateful. The `skip_assert` argument 116 // determines whether the "Assert" op should be treated as stateful or not. 117 bool IsNodeStateful(const FunctionLibraryDefinition& library, 118 const NodeDef& node, bool skip_assert = false); 119 120 } // end namespace function_utils 121 } // end namespace grappler 122 } // end namespace tensorflow 123 124 #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ 125