1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/container/flat_hash_set.h" 24 #include "absl/container/inlined_vector.h" 25 #include "tensorflow/core/framework/attr_value.pb.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/framework/function.pb.h" 28 #include "tensorflow/core/framework/node_def_util.h" 29 #include "tensorflow/core/framework/op_def.pb.h" 30 #include "tensorflow/core/grappler/grappler_item.h" 31 #include "tensorflow/core/lib/gtl/flatset.h" 32 33 namespace tensorflow { 34 namespace grappler { 35 36 // WARNING(ezhulenev): Currently we do not support functions with inputs or 37 // outputs instantiated into multiple tensors. This can happen if the 38 // input/output type is 'T*N' or 'list(type)'. This is enforced by multiple 39 // checks across this file and also function_optimizer.cc. InputArgExpansion and 40 // OutputArgExpansion already support lists of tensors, but that's pretty much 41 // it, all other code is written with assumption that expansions are always of 42 // size 1. MakeGrapplerFunctionItem will gracefully fail with Status error. 43 // 44 // This is a low priority feature, because in practice we don't see a lot (any 45 // at all?) functions with such arguments. Tensorflow-Eager always produces 46 // functions with plain input/output arguments. 47 48 // TODO(ezhulenev): Support inputs and outputs of type 'T*N'. 49 // TODO(ezhulenev): Support inputs and outputs of type 'list(type)'. 50 51 // Depending on the function instantiation attributes, input argument to the 52 // function might be a single tensor, list of tensors of the same type, or a 53 // list of tensors of different types. 54 // 55 // InputArgExpansion keeps track of the placeholders that were added to the 56 // function body in place of function inputs and a resolved input data type. 57 struct InputArgExpansion { 58 string input_name; 59 DataType data_type; 60 bool is_ref; 61 absl::InlinedVector<string, 1> placeholders; 62 }; 63 64 // Depending on the function instantiation attributes, output argument is mapped 65 // to one or more outputs of one of the function body nodes. 66 // 67 // OutputArgExpansion keeps track of the Identity nodes that were added to the 68 // function body to forward output tensors. Adding these output nodes allows 69 // nested function inlining and specialization (see function optimizer). 70 struct OutputArgExpansion { 71 string output_name; 72 DataType data_type; 73 bool is_ref; 74 absl::InlinedVector<string, 1> output_nodes; 75 }; 76 77 // A mapping from control output name to node name in function body graph. 78 struct ControlOutput { 79 string output_name; 80 string node_name; 81 }; 82 83 // FunctionDef uses different connectivity encoding for the function body nodes, 84 // then a GraphDef (see function.proto for details). Input name in FunctionDef 85 // can potentially represent a sequence of tensors (instead just one tensor in 86 // GraphDef), we need to expand it when converting from FunctionDef to GraphDef, 87 // and fold it back when doing backward conversion. 88 class GrapplerFunctionConnectivity { 89 public: 90 void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion); 91 void RegisterFunctionBodyOutputs(const string& node_name, 92 tensorflow::NameRangeMap&& outputs); 93 94 // Expands input encoded in FunctionDef format (name[:output][:position]) into 95 // multiple inputs in GraphDef format (name[:position]). 96 Status ExpandFunctionDefInput(const string& func_def_input, 97 std::vector<string>* graph_def_inputs) const; 98 99 // Updates Node inputs from FunctionDef to GraphDef format. 100 Status ExpandNodeInputs(NodeDef* function_body_node) const; 101 102 // When expanding inputs in function def format, single input might be 103 // expanded into multiple tensors. When converting back to the function def 104 // format from graph def format, it's always a 1-to-1 relationship. 105 // FunctionDef built from GrapplerFunctionItem is always specialized to its 106 // instantiation attributes and length of input args (and node def outputs) is 107 // known. 108 109 // Converts input name from GraphDef format (name[:position]) to the 110 // FunctionDef input format (name[:output][:position]) using registered input 111 // arg expansion and function body outputs. 112 Status AsFunctionDefInput(const string& graph_def_input, 113 string* func_def_input) const; 114 115 // Updates Node inputs from GraphDef to FunctionDef format. 116 Status AsFunctionDefNode(NodeDef* function_body_node) const; 117 118 private: 119 // Mapping from input name to input arg expansion. 120 absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_; 121 // Mapping from function body node name to output names range map. 122 absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_; 123 124 // For each placeholder added to the function instantiation graph, we keep a 125 // mapping back to the function input argument name and index. 126 struct InputArgPlaceholder { 127 string input_name; // Name of the function input argument. 128 int input_index; // Index of a tensor in the function input argument 129 // expansion, it can be greater than `0` if input 130 // argument is a list of tensors (aka list(type)). 131 }; 132 // Mapping from input arg placeholder to the function input tensor. 133 absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_; 134 }; 135 136 // Get Function type attributes using attributes of a node that instantiated 137 // a function. 138 class GrapplerFunctionItemInstantiation { 139 public: GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr)140 explicit GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr) 141 : func_instantiation_attr_(func_instantiation_attr) {} 142 143 // Get DataType from attributes by name. Return error if attribute is missing, 144 // or it doesn't define a valid data type. 145 Status GetTypeAttr(const string& type_attr_name, DataType* data_type) const; 146 147 // Get argument data type. If data type is not explicitly defined, uses 148 // provided attribute name to look it up in function attributes. 149 Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const; 150 151 private: 152 const AttrSlice func_instantiation_attr_; // do not own 153 }; 154 155 // A special case of GrapplerItem, constructed from a TensorFlow Function. 156 class GrapplerFunctionItem : public GrapplerItem { 157 public: 158 GrapplerFunctionItem() = default; 159 160 const string& description() const; 161 162 const std::vector<InputArgExpansion>& inputs() const; 163 const InputArgExpansion& input(int i) const; 164 const std::size_t input_size() const; 165 166 const std::vector<OutputArgExpansion>& outputs() const; 167 const OutputArgExpansion& output(int i) const; 168 const std::size_t output_size() const; 169 170 const std::vector<ControlOutput>& control_outputs() const; 171 const std::size_t control_output_size() const; 172 173 const AttrSlice& func_attr() const; 174 const GraphDef& function_body() const; 175 GraphDef& mutable_function_body(); 176 177 bool is_stateful() const; 178 179 GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other); 180 181 private: 182 friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&, 183 const FunctionLibraryDefinition&, int, 184 GrapplerFunctionItem*); 185 friend Status ReplaceInputWithConst(const NodeDef&, int, 186 GrapplerFunctionItem*); 187 friend Status RemoveFunctionOutputs(const absl::flat_hash_set<int>&, 188 GrapplerFunctionItem*, 189 std::vector<std::pair<int, int>>*); 190 191 GrapplerFunctionItem(string func_name, string description, 192 AttrSlice func_attr, 193 std::vector<InputArgExpansion> input_arg_expansions, 194 std::vector<OutputArgExpansion> output_arg_expansions, 195 std::vector<ControlOutput> control_outputs, 196 int graph_def_version, bool is_stateful, 197 GraphDef&& function_body); 198 199 string description_; 200 AttrSlice func_attr_; // Attributes specific to function definition that 201 // produced this item (FuncDef.attr field). 202 203 std::vector<InputArgExpansion> input_arg_expansions_; 204 std::vector<OutputArgExpansion> output_arg_expansions_; 205 std::vector<ControlOutput> control_outputs_; 206 207 bool is_stateful_ = false; 208 }; 209 210 // Check if function input/output types are fully defined only at instantiation 211 // time (parametrized by its instantiation node). 212 bool HasParametrizedType(const FunctionDef& func); 213 214 // Check if a function body is parametrized by its instantiation node. Function 215 // body is parametrized, if it has at least one node with a 'placeholder' 216 // attribute. 217 bool HasParametrizedBody(const FunctionDef& func); 218 219 // Check if function has parametrized type or body. 220 bool IsParametrized(const FunctionDef& func); 221 222 // Resolve function instantiation type parameters from the attributes of the 223 // caller node. Return error if type can't be resolved. 224 Status InstantiationTypeParameters( 225 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 226 absl::flat_hash_map<string, DataType>* type_parameters); 227 228 // Resolve function instantiation body parameters (values for the function body 229 // attr placeholders) from the attributes of the caller node. Return error if 230 // type can't be resolved. 231 Status InstantiationBodyParameters( 232 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 233 absl::flat_hash_map<string, AttrValue>* body_parameters); 234 235 // Register GrapplerFunctionItem input arg expansion and function body outputs 236 // in the GrapplerFunctionConnectivity. Use function library definition to 237 // lookup function body nodes output names and ranges. 238 Status RegisterGrapplerFunctionConnectivity( 239 const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib, 240 GrapplerFunctionConnectivity* connectivity); 241 242 // Replace one of the function inputs with a constant. 243 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, 244 GrapplerFunctionItem* item); 245 246 // Removes outputs from instantiated grappler function item. Function node 247 // outputs use GraphDef output index encoding, and multiple outputs might belong 248 // to the same output argument expansion (in case of tensor list outputs). For 249 // all active function outputs that changed its output index, this function adds 250 // an output mapping (std::pair<old index, new index>). 251 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs, 252 GrapplerFunctionItem* item, 253 std::vector<std::pair<int, int>>* output_mapping); 254 255 // TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs. 256 257 // Make a GrapplerFunctionItem from the function definition and function 258 // instantiation attributes (caller node attributes). Returns error if the given 259 // function def cannot be converted (e.g. not all attributes are defined). 260 Status MakeGrapplerFunctionItem(const FunctionDef& func, 261 const AttrSlice& func_instantiation_attr, 262 const FunctionLibraryDefinition& flib, 263 int graph_def_version, 264 GrapplerFunctionItem* item); 265 266 // Make a GrapplerFunction item from the function definition. Function must be 267 // fully defined (no type or body parametrization). 268 // TODO(ezhulenev): Support parametrized functions without fully defined 269 // instantiation attributes? Do we ever want to optimize parametrized function 270 // without specializing it to its instantiation attributes (at least types)? 271 Status MakeGrapplerFunctionItem(const FunctionDef& func, 272 const FunctionLibraryDefinition& flib, 273 int graph_def_version, 274 GrapplerFunctionItem* item); 275 276 // Make a FunctionDef from the GrapplerFunctionItem. Use function library 277 // definition to lookup function body nodes output names and ranges. 278 Status MakeFunctionDef(const GrapplerFunctionItem& item, 279 const FunctionLibraryDefinition& flib, 280 FunctionDef* func); 281 282 } // end namespace grappler 283 } // end namespace tensorflow 284 285 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 286