• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
18 
19 #include <memory>
20 #include <string>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/container/inlined_vector.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/function.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/grappler/grappler_item.h"
31 #include "tensorflow/core/lib/gtl/flatset.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 
// WARNING(ezhulenev): Currently we do not support functions with inputs or
// outputs instantiated into multiple tensors. This can happen if the
// input/output type is 'T*N' or 'list(type)'. This is enforced by multiple
// checks across this file and also function_optimizer.cc. InputArgExpansion and
// OutputArgExpansion already support lists of tensors, but that's pretty much
// it: all other code is written with the assumption that expansions are always
// of size 1. MakeGrapplerFunctionItem will gracefully fail with a Status error.
//
// This is a low priority feature, because in practice we see few (if any)
// functions with such arguments. TensorFlow Eager always produces functions
// with plain input/output arguments.

// TODO(ezhulenev): Support inputs and outputs of type 'T*N'.
// TODO(ezhulenev): Support inputs and outputs of type 'list(type)'.
50 
51 // Depending on the function instantiation attributes, input argument to the
52 // function might be a single tensor, list of tensors of the same type, or a
53 // list of tensors of different types.
54 //
55 // InputArgExpansion keeps track of the placeholders that were added to the
56 // function body in place of function inputs and a resolved input data type.
57 struct InputArgExpansion {
58   string input_name;
59   DataType data_type;
60   bool is_ref;
61   absl::InlinedVector<string, 1> placeholders;
62 };
63 
64 // Depending on the function instantiation attributes, output argument is mapped
65 // to one or more outputs of one of the function body nodes.
66 //
67 // OutputArgExpansion keeps track of the Identity nodes that were added to the
68 // function body to forward output tensors. Adding these output nodes allows
69 // nested function inlining and specialization (see function optimizer).
70 struct OutputArgExpansion {
71   string output_name;
72   DataType data_type;
73   bool is_ref;
74   absl::InlinedVector<string, 1> output_nodes;
75 };
76 
77 // A mapping from control output name to node name in function body graph.
78 struct ControlOutput {
79   string output_name;
80   string node_name;
81 };
82 
83 // FunctionDef uses different connectivity encoding for the function body nodes,
84 // then a GraphDef (see function.proto for details). Input name in FunctionDef
85 // can potentially represent a sequence of tensors (instead just one tensor in
86 // GraphDef), we need to expand it when converting from FunctionDef to GraphDef,
87 // and fold it back when doing backward conversion.
88 class GrapplerFunctionConnectivity {
89  public:
90   void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
91   void RegisterFunctionBodyOutputs(const string& node_name,
92                                    tensorflow::NameRangeMap&& outputs);
93 
94   // Expands input encoded in FunctionDef format (name[:output][:position]) into
95   // multiple inputs in GraphDef format (name[:position]).
96   Status ExpandFunctionDefInput(const string& func_def_input,
97                                 std::vector<string>* graph_def_inputs) const;
98 
99   // Updates Node inputs from FunctionDef to GraphDef format.
100   Status ExpandNodeInputs(NodeDef* function_body_node) const;
101 
102   // When expanding inputs in function def format, single input might be
103   // expanded into multiple tensors. When converting back to the function def
104   // format from graph def format, it's always a 1-to-1 relationship.
105   // FunctionDef built from GrapplerFunctionItem is always specialized to its
106   // instantiation attributes and length of input args (and node def outputs) is
107   // known.
108 
109   // Converts input name from GraphDef format (name[:position]) to the
110   // FunctionDef input format (name[:output][:position]) using registered input
111   // arg expansion and function body outputs.
112   Status AsFunctionDefInput(const string& graph_def_input,
113                             string* func_def_input) const;
114 
115   // Updates Node inputs from GraphDef to FunctionDef format.
116   Status AsFunctionDefNode(NodeDef* function_body_node) const;
117 
118  private:
119   // Mapping from input name to input arg expansion.
120   absl::flat_hash_map<string, InputArgExpansion> input_arg_expansions_;
121   // Mapping from function body node name to output names range map.
122   absl::flat_hash_map<string, tensorflow::NameRangeMap> function_body_outputs_;
123 
124   // For each placeholder added to the function instantiation graph, we keep a
125   // mapping back to the function input argument name and index.
126   struct InputArgPlaceholder {
127     string input_name;  // Name of the function input argument.
128     int input_index;    // Index of a tensor in the function input argument
129                         // expansion, it can be greater than `0` if input
130                         // argument is a list of tensors (aka list(type)).
131   };
132   // Mapping from input arg placeholder to the function input tensor.
133   absl::flat_hash_map<string, InputArgPlaceholder> input_arg_placeholders_;
134 };
135 
136 // Get Function type attributes using attributes of a node that instantiated
137 // a function.
138 class GrapplerFunctionItemInstantiation {
139  public:
GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr)140   explicit GrapplerFunctionItemInstantiation(AttrSlice func_instantiation_attr)
141       : func_instantiation_attr_(func_instantiation_attr) {}
142 
143   // Get DataType from attributes by name. Return error if attribute is missing,
144   // or it doesn't define a valid data type.
145   Status GetTypeAttr(const string& type_attr_name, DataType* data_type) const;
146 
147   // Get argument data type. If data type is not explicitly defined, uses
148   // provided attribute name to look it up in function attributes.
149   Status GetArgType(const OpDef::ArgDef& arg, DataType* data_type) const;
150 
151  private:
152   const AttrSlice func_instantiation_attr_;  // do not own
153 };
154 
155 // A special case of GrapplerItem, constructed from a TensorFlow Function.
156 class GrapplerFunctionItem : public GrapplerItem {
157  public:
158   GrapplerFunctionItem() = default;
159 
160   const string& description() const;
161 
162   const std::vector<InputArgExpansion>& inputs() const;
163   const InputArgExpansion& input(int i) const;
164   const std::size_t input_size() const;
165 
166   const std::vector<OutputArgExpansion>& outputs() const;
167   const OutputArgExpansion& output(int i) const;
168   const std::size_t output_size() const;
169 
170   const std::vector<ControlOutput>& control_outputs() const;
171   const std::size_t control_output_size() const;
172 
173   const AttrSlice& func_attr() const;
174   const GraphDef& function_body() const;
175   GraphDef& mutable_function_body();
176 
177   bool is_stateful() const;
178 
179   GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);
180 
181  private:
182   friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&,
183                                          const FunctionLibraryDefinition&, int,
184                                          GrapplerFunctionItem*);
185   friend Status ReplaceInputWithConst(const NodeDef&, int,
186                                       GrapplerFunctionItem*);
187   friend Status RemoveFunctionOutputs(const absl::flat_hash_set<int>&,
188                                       GrapplerFunctionItem*,
189                                       std::vector<std::pair<int, int>>*);
190 
191   GrapplerFunctionItem(string func_name, string description,
192                        AttrSlice func_attr,
193                        std::vector<InputArgExpansion> input_arg_expansions,
194                        std::vector<OutputArgExpansion> output_arg_expansions,
195                        std::vector<ControlOutput> control_outputs,
196                        int graph_def_version, bool is_stateful,
197                        GraphDef&& function_body);
198 
199   string description_;
200   AttrSlice func_attr_;  // Attributes specific to function definition that
201                          // produced this item (FuncDef.attr field).
202 
203   std::vector<InputArgExpansion> input_arg_expansions_;
204   std::vector<OutputArgExpansion> output_arg_expansions_;
205   std::vector<ControlOutput> control_outputs_;
206 
207   bool is_stateful_ = false;
208 };
209 
210 // Check if function input/output types are fully defined only at instantiation
211 // time (parametrized by its instantiation node).
212 bool HasParametrizedType(const FunctionDef& func);
213 
214 // Check if a function body is parametrized by its instantiation node. Function
215 // body is parametrized, if it has at least one node with a 'placeholder'
216 // attribute.
217 bool HasParametrizedBody(const FunctionDef& func);
218 
219 // Check if function has parametrized type or body.
220 bool IsParametrized(const FunctionDef& func);
221 
222 // Resolve function instantiation type parameters from the attributes of the
223 // caller node. Return error if type can't be resolved.
224 Status InstantiationTypeParameters(
225     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
226     absl::flat_hash_map<string, DataType>* type_parameters);
227 
228 // Resolve function instantiation body parameters (values for the function body
229 // attr placeholders) from the attributes of the caller node. Return error if
230 // type can't be resolved.
231 Status InstantiationBodyParameters(
232     const FunctionDef& func, const AttrSlice& func_instantiation_attr,
233     absl::flat_hash_map<string, AttrValue>* body_parameters);
234 
235 // Register GrapplerFunctionItem input arg expansion and function body outputs
236 // in the GrapplerFunctionConnectivity. Use function library definition to
237 // lookup function body nodes output names and ranges.
238 Status RegisterGrapplerFunctionConnectivity(
239     const GrapplerFunctionItem& item, const FunctionLibraryDefinition& flib,
240     GrapplerFunctionConnectivity* connectivity);
241 
242 // Replace one of the function inputs with a constant.
243 Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
244                              GrapplerFunctionItem* item);
245 
246 // Removes outputs from instantiated grappler function item. Function node
247 // outputs use GraphDef output index encoding, and multiple outputs might belong
248 // to the same output argument expansion (in case of tensor list outputs). For
249 // all active function outputs that changed its output index, this function adds
250 // an output mapping (std::pair<old index, new index>).
251 Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
252                              GrapplerFunctionItem* item,
253                              std::vector<std::pair<int, int>>* output_mapping);
254 
255 // TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs.
256 
257 // Make a GrapplerFunctionItem from the function definition and function
258 // instantiation attributes (caller node attributes). Returns error if the given
259 // function def cannot be converted (e.g. not all attributes are defined).
260 Status MakeGrapplerFunctionItem(const FunctionDef& func,
261                                 const AttrSlice& func_instantiation_attr,
262                                 const FunctionLibraryDefinition& flib,
263                                 int graph_def_version,
264                                 GrapplerFunctionItem* item);
265 
266 // Make a GrapplerFunction item from the function definition. Function must be
267 // fully defined (no type or body parametrization).
268 // TODO(ezhulenev): Support parametrized functions without fully defined
269 // instantiation attributes? Do we ever want to optimize parametrized function
270 // without specializing it to its instantiation attributes (at least types)?
271 Status MakeGrapplerFunctionItem(const FunctionDef& func,
272                                 const FunctionLibraryDefinition& flib,
273                                 int graph_def_version,
274                                 GrapplerFunctionItem* item);
275 
276 // Make a FunctionDef from the GrapplerFunctionItem. Use function library
277 // definition to lookup function body nodes output names and ranges.
278 Status MakeFunctionDef(const GrapplerFunctionItem& item,
279                        const FunctionLibraryDefinition& flib,
280                        FunctionDef* func);
281 
282 }  // end namespace grappler
283 }  // end namespace tensorflow
284 
285 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
286