• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
18 
19 #include <memory>
20 #include <string>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/container/flat_hash_set.h"
24 #include "absl/container/inlined_vector.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/function.pb.h"
28 #include "tensorflow/core/framework/node_def_util.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/grappler/grappler_item.h"
31 #include "tensorflow/core/lib/gtl/flatset.h"
32 
33 namespace tensorflow {
34 namespace grappler {
35 
36 // Function input argument instantiated into an '_Arg' node in the function body
37 // graph, with an 'index' attribute corresponding to the input position.
38 struct InputArgInstantiation {
InputArgInstantiationInputArgInstantiation39   InputArgInstantiation(string node_name, DataType data_type)
40       : node_name(std::move(node_name)), data_type(data_type) {}
41   string node_name;
42   DataType data_type;
43 };
44 
45 // Function output instantiated into a '_Retval' node in the function body
46 // graph, with an 'index' attribute corresponding to the output position.
47 struct OutputArgInstantiation {
OutputArgInstantiationOutputArgInstantiation48   OutputArgInstantiation(string node_name, DataType data_type)
49       : node_name(std::move(node_name)), data_type(data_type) {}
50   string node_name;
51   DataType data_type;
52 };
53 
54 // A mapping from control output name to node name in function body graph.
55 struct ControlOutput {
56   string output_name;
57   string node_name;
58 };
59 
// A special case of GrapplerItem, constructed from a TensorFlow Function.
//
// The fully-parameterized constructor is private. MakeGrapplerFunctionItem
// (the overload taking instantiation attributes) is a friend and can call it.
// ReplaceInputWithConst and RemoveFunctionOutputs are also friends, presumably
// so they can edit the private fields of an existing item in place.
class GrapplerFunctionItem : public GrapplerItem {
 public:
  GrapplerFunctionItem() = default;

  const string& description() const;

  // Instantiated function inputs ('_Arg' nodes); see InputArgInstantiation.
  const std::vector<InputArgInstantiation>& inputs() const;
  const InputArgInstantiation& input(int i) const;
  // NOTE(review): the top-level `const` on this and the other *_size() return
  // types has no effect (-Wignored-qualifiers). It is part of the declared
  // return type, so dropping it must be done together with the out-of-line
  // definitions.
  const std::size_t input_size() const;

  // Instantiated function outputs ('_Retval' nodes); see OutputArgInstantiation.
  const std::vector<OutputArgInstantiation>& outputs() const;
  const OutputArgInstantiation& output(int i) const;
  const std::size_t output_size() const;

  // Control output name -> function body node name mappings.
  const std::vector<ControlOutput>& control_outputs() const;
  const std::size_t control_output_size() const;

  // Attributes of the function definition that produced this item.
  const AttrSlice& func_attr() const;
  // Attributes of the function arguments.
  const std::vector<const FunctionDef::ArgAttrs*>& arg_attr() const;
  // Read-only and mutable access to the function body graph.
  const GraphDef& function_body() const;
  GraphDef& mutable_function_body();

  bool is_stateful() const;

  // Replaces the function body graph with `other` and returns a reference to
  // an item (presumably *this, to allow chaining — TODO confirm in the .cc).
  GrapplerFunctionItem& SwapFunctionBody(GraphDef&& other);

 private:
  friend Status MakeGrapplerFunctionItem(const FunctionDef&, const AttrSlice&,
                                         const FunctionLibraryDefinition&, int,
                                         GrapplerFunctionItem*);
  friend Status ReplaceInputWithConst(const NodeDef&, int,
                                      GrapplerFunctionItem*);
  friend Status RemoveFunctionOutputs(const absl::flat_hash_set<int>&,
                                      GrapplerFunctionItem*,
                                      std::vector<std::pair<int, int>>*);

  // Only reachable through the friends above. `graph_def_version` and
  // `func_name` have no matching member here; presumably they are forwarded to
  // the GrapplerItem base — TODO confirm.
  GrapplerFunctionItem(string func_name, string description,
                       AttrSlice func_attr,
                       std::vector<const FunctionDef::ArgAttrs*> arg_attr,
                       std::vector<InputArgInstantiation> input_args,
                       std::vector<OutputArgInstantiation> output_args,
                       std::vector<ControlOutput> control_outputs,
                       int graph_def_version, bool is_stateful,
                       GraphDef&& function_body);

  string description_;
  AttrSlice func_attr_;  // Attributes specific to function definition that
                         // produced this item (FuncDef.attr field).

  // Attributes of function arguments.
  // NOTE(review): these are raw pointers. Presumably they (and func_attr_)
  // refer into the FunctionDef the item was made from, which would then have
  // to outlive this item — TODO confirm lifetime requirements.
  std::vector<const FunctionDef::ArgAttrs*> arg_attr_;

  std::vector<InputArgInstantiation> input_args_;
  std::vector<OutputArgInstantiation> output_args_;
  std::vector<ControlOutput> control_outputs_;

  // Returned by is_stateful(); false unless set by the private constructor.
  bool is_stateful_ = false;
};
119 
// Returns true if the function's input/output types are fully defined only at
// instantiation time, i.e. they are parametrized by the instantiation node.
bool HasParametrizedType(const FunctionDef& func);

// Returns true if the function body is parametrized by its instantiation node.
// A function body is parametrized if it has at least one node with a
// 'placeholder' attribute, whose value comes from the caller's attributes.
bool HasParametrizedBody(const FunctionDef& func);

// Returns true if the function has a parametrized type or a parametrized body
// (see HasParametrizedType and HasParametrizedBody).
bool IsParametrized(const FunctionDef& func);
131 
// Resolves the function instantiation type parameters from the attributes of
// the caller node and stores them in `type_parameters` (presumably keyed by
// the type attribute name — TODO confirm). Returns an error if a type can't be
// resolved.
Status InstantiationTypeParameters(
    const FunctionDef& func, const AttrSlice& func_instantiation_attr,
    absl::flat_hash_map<string, DataType>* type_parameters);

// Resolves the function instantiation body parameters (values for the function
// body attr placeholders) from the attributes of the caller node and stores
// them in `body_parameters`. Returns an error if a body parameter can't be
// resolved.
Status InstantiationBodyParameters(
    const FunctionDef& func, const AttrSlice& func_instantiation_attr,
    absl::flat_hash_map<string, AttrValue>* body_parameters);
144 
// Replaces the function input at position `input_index` of `item` with the
// constant node `input_const`.
Status ReplaceInputWithConst(const NodeDef& input_const, int input_index,
                             GrapplerFunctionItem* item);

// Removes the outputs listed in `remove_outputs` from an instantiated grappler
// function item. For every remaining function output whose output index
// changed, adds an output mapping (std::pair<old index, new index>) to
// `output_mapping`.
Status RemoveFunctionOutputs(const absl::flat_hash_set<int>& remove_outputs,
                             GrapplerFunctionItem* item,
                             std::vector<std::pair<int, int>>* output_mapping);
155 
156 // TODO(ezhulenev, b/120103818): Add RemoveFunctionInputs.
157 
// Makes a GrapplerFunctionItem from the function definition and the function
// instantiation attributes (caller node attributes), writing the result to
// `item`. Returns an error if the given function def cannot be converted (e.g.
// not all attributes are defined).
Status MakeGrapplerFunctionItem(const FunctionDef& func,
                                const AttrSlice& func_instantiation_attr,
                                const FunctionLibraryDefinition& flib,
                                int graph_def_version,
                                GrapplerFunctionItem* item);

// Makes a GrapplerFunctionItem from the function definition alone. The
// function must be fully defined (no type or body parametrization; see
// IsParametrized).
// TODO(ezhulenev): Support parametrized functions without fully defined
// instantiation attributes? Do we ever want to optimize parametrized function
// without specializing it to its instantiation attributes (at least types)?
Status MakeGrapplerFunctionItem(const FunctionDef& func,
                                const FunctionLibraryDefinition& flib,
                                int graph_def_version,
                                GrapplerFunctionItem* item);
176 
// Makes a FunctionDef from the GrapplerFunctionItem and writes it to `func`.
// Uses the function library definition to look up the output names and
// ranges of the function body nodes.
Status MakeFunctionDef(const GrapplerFunctionItem& item,
                       const FunctionLibraryDefinition& flib,
                       FunctionDef* func);
182 
183 }  // end namespace grappler
184 }  // end namespace tensorflow
185 
186 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_FUNCTIONS_H_
187