1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 18 19 #include <memory> 20 #include <string> 21 22 #include "absl/container/flat_hash_map.h" 23 #include "absl/container/flat_hash_set.h" 24 #include "absl/container/inlined_vector.h" 25 #include "tensorflow/core/framework/attr_value.pb.h" 26 #include "tensorflow/core/framework/function.h" 27 #include "tensorflow/core/framework/function.pb.h" 28 #include "tensorflow/core/framework/node_def_util.h" 29 #include "tensorflow/core/framework/op_def.pb.h" 30 #include "tensorflow/core/grappler/grappler_item.h" 31 #include "tensorflow/core/lib/gtl/flatset.h" 32 33 namespace tensorflow { 34 namespace grappler { 35 36 // Function input argument instantiated into an '_Arg' node in the function body 37 // graph, with an 'index' attribute corresponding to the input position. 38 struct InputArgInstantiation { InputArgInstantiationInputArgInstantiation39 InputArgInstantiation(string node_name, DataType data_type) 40 : node_name(std::move(node_name)), data_type(data_type) {} 41 string node_name; 42 DataType data_type; 43 }; 44 45 // Function output instantiated into a '_Retval' node in the function body 46 // graph, with an 'index' attribute corresponding to the output position. 47 struct OutputArgInstantiation { OutputArgInstantiationOutputArgInstantiation48 OutputArgInstantiation(string node_name, DataType data_type) 49 : node_name(std::move(node_name)), data_type(data_type) {} 50 string node_name; 51 DataType data_type; 52 }; 53 54 // A mapping from control output name to node name in function body graph. 55 struct ControlOutput { 56 string output_name; 57 string node_name; 58 }; 59 60 // A special case of GrapplerItem, constructed from a TensorFlow Function. 61 class GrapplerFunctionItem : public GrapplerItem { 62 public: 63 GrapplerFunctionItem() = default; 64 65 const string& description() const; 66 67 const std::vector<InputArgInstantiation>& inputs() const; 68 const InputArgInstantiation& input(int i) const; 69 const std::size_t input_size() const; 70 71 const std::vector<OutputArgInstantiation>& outputs() const; 72 const OutputArgInstantiation& output(int i) const; 73 const std::size_t output_size() const; 74 75 const std::vector<ControlOutput>& control_outputs() const; 76 const std::size_t control_output_size() const; 77 78 const AttrSlice& func_attr() const; 79 const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const; 80 const GraphDef& function_body() const; 81 GraphDef& mutable_function_body(); 82 83 bool is_stateful() const; 84 85 GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other); 86 87 private: 88 friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&, 89 const FunctionLibraryDefinition&, int, 90 GrapplerFunctionItem*); 91 friend Status ReplaceInputWithConst(const NodeDef&, int, 92 GrapplerFunctionItem*); 93 friend Status RemoveFunctionOutputs(const absl::flat_hash_set<int>&, 94 GrapplerFunctionItem*, 95 std::vector<std::pair<int, int>>*); 96 97 GrapplerFunctionItem(string func_name, string description, 98 AttrSlice func_attr, 99 std::vector<const FunctionDef::ArgAttrs*> arg_attr, 100 std::vector<InputArgInstantiation> input_args, 101 std::vector<OutputArgInstantiation> output_args, 102 std::vector<ControlOutput> control_outputs, 103 int graph_def_version, bool is_stateful, 104 GraphDef&& function_body); 105 106 string description_; 107 AttrSlice func_attr_; // Attributes specific to function definition that 108 // produced this item (FuncDef.attr field). 109 110 // Attributes of function arguments 111 std::vector<const FunctionDef::ArgAttrs*> arg_attr_; 112 113 std::vector<InputArgInstantiation> input_args_; 114 std::vector<OutputArgInstantiation> output_args_; 115 std::vector<ControlOutput> control_outputs_; 116 117 bool is_stateful_ = false; 118 }; 119 120 // Check if function input/output types are fully defined only at instantiation 121 // time (parametrized by its instantiation node). 122 bool HasParametrizedType(const FunctionDef& func); 123 124 // Check if a function body is parametrized by its instantiation node. Function 125 // body is parametrized, if it has at least one node with a 'placeholder' 126 // attribute. 127 bool HasParametrizedBody(const FunctionDef& func); 128 129 // Check if function has parametrized type or body. 130 bool IsParametrized(const FunctionDef& func); 131 132 // Resolve function instantiation type parameters from the attributes of the 133 // caller node. Return error if type can't be resolved. 134 Status InstantiationTypeParameters( 135 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 136 absl::flat_hash_map<string, DataType>* type_parameters); 137 138 // Resolve function instantiation body parameters (values for the function body 139 // attr placeholders) from the attributes of the caller node. Return error if 140 // type can't be resolved. 141 Status InstantiationBodyParameters( 142 const FunctionDef& func, const AttrSlice& func_instantiation_attr, 143 absl::flat_hash_map<string, AttrValue>* body_parameters); 144 145 // Replace one of the function inputs with a constant. 146 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index, 147 GrapplerFunctionItem* item); 148 149 // Removes outputs from instantiated grappler function item. For all active 150 // function outputs that changed its output index, this function adds an output 151 // mapping (std::pair<old index, new index>). 152 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs, 153 GrapplerFunctionItem* item, 154 std::vector<std::pair<int, int>>* output_mapping); 155 156 // TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs. 157 158 // Make a GrapplerFunctionItem from the function definition and function 159 // instantiation attributes (caller node attributes). Returns error if the given 160 // function def cannot be converted (e.g. not all attributes are defined). 161 Status MakeGrapplerFunctionItem(const FunctionDef& func, 162 const AttrSlice& func_instantiation_attr, 163 const FunctionLibraryDefinition& flib, 164 int graph_def_version, 165 GrapplerFunctionItem* item); 166 167 // Make a GrapplerFunction item from the function definition. Function must be 168 // fully defined (no type or body parametrization). 169 // TODO(ezhulenev): Support parametrized functions without fully defined 170 // instantiation attributes? Do we ever want to optimize parametrized function 171 // without specializing it to its instantiation attributes (at least types)? 172 Status MakeGrapplerFunctionItem(const FunctionDef& func, 173 const FunctionLibraryDefinition& flib, 174 int graph_def_version, 175 GrapplerFunctionItem* item); 176 177 // Make a FunctionDef from the GrapplerFunctionItem. Use function library 178 // definition to lookup function body nodes output names and ranges. 179 Status MakeFunctionDef(const GrapplerFunctionItem& item, 180 const FunctionLibraryDefinition& flib, 181 FunctionDef* func); 182 183 } // end namespace grappler 184 } // end namespace tensorflow 185 186 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_ 187