/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
16 #define TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
17 
18 #include <Python.h>
19 
20 #include <map>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/types/span.h"
25 #include "tensorflow/core/framework/op_def.pb.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/platform/status.h"
28 #include "tensorflow/python/framework/op_def_util.h"
29 #include "tensorflow/python/framework/python_tensor_converter.h"
30 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
31 
32 namespace tensorflow {
33 
34 // Precomputed information about a TensorFlow Python API.
35 //
36 // PythonAPIInfo records information about a single TensorFlow Python API,
37 // in order to allow calls to the API to be executed more efficiently.  This
38 // information includes:
39 //
40 // * The name of the API.  (E.g. "tf.math.add")
41 //
42 // * The name of the registered op that implements the API, if applicable
43 //   (e.g. "AddV2").
44 //
45 // * Information about the API's parameters.  Parameters are divided into two
46 //   "kinds": inputs and attributes.  An *input* is a parameter that
47 //   expects a Tensor or list of Tensors, and it is described by an `ArgDef`.
48 //   An *attribute* is a parameter that expects any other value type, and it is
49 //   described by an `AttrDef`.
50 //
51 // * Default values for the API's attribute parameters.
52 //
53 // * Information about "inferred attributes" -- attributes whose values are
54 //   inferred from `input` parameters.  There are two kinds of inferred
55 //   attributes: Tensor dtypes, which are inferred from tensor and list(tensor)
56 //   parameters; and list lengths, which are inferred from list(tensor)
57 //   parameters.
58 class PythonAPIInfo {
59  public:
60   // The index of a parameter in the canonicalized parameter list.  The
61   // canonicalized parameter list includes inputs and attributes (but does
62   // not include inferred attributes).  `-1` is used for inferred attributes.
63   using ParamIndex = int;
64 
65   // Information about a parameter that expects a non-Tensor value.
66   struct Attribute {
67     ParamIndex index;  // -1 if this is an inferred attribute
68     AttributeType type;
69     const char* name;    // Interned python string
70     int inferred_index;  // index to store attribute in InferredAttributes
71   };
72 
73   // Information about a parameter that expects a Tensor or list(Tensor).
74   // Additional information about tensor parameters is stored in types
75   // defined below, in order to simplify dtype/length inference:
76   //   * FixedDTypeInput: inputs with fixed dtypes.
77   //   * InputsWithTypeAttr: groups inputs that use a type_attr for dtype.
78   //   * InputsWithTypeListAttr: groups inputs that use a type_list_attr.
79   //   * InputsWithNumberAttr: groups inputs by a number_attr for length.
80   struct Input {
81     ParamIndex index;
82     bool is_list;
83   };
84 
85   // Information about a Tensor parameter w/ fixed dtype.
86   struct InputWithFixedDType {
87     DataType dtype;
88     ParamIndex index;
89     bool is_list;
90   };
91 
92   // Information about Tensor parameters whose DType is specified by a single
93   // `type_attr` attribute.
94   struct InputsWithTypeAttr {
95     Attribute* type_attr;                        // not owned.
96     DataType default_dtype;                      // DT_INVALID if no default.
97     std::vector<ParamIndex> tensor_params;       // single-tensor inputs.
98     std::vector<ParamIndex> tensor_list_params;  // list(tensor) inputs.
99     std::vector<DataType> ok_dtypes;
100   };
101 
102   // Information about Tensor parameters whose DType is specified by a single
103   // `type_list_attr` attribute.
104   struct InputsWithTypeListAttr {
105     Attribute* type_list_attr;                   // not owned.
106     std::vector<DataType> default_dtypes;        // empty if no default.
107     std::vector<ParamIndex> tensor_list_params;  // list(tensor) inputs.
108     std::vector<DataType> ok_dtypes;
109   };
110 
111   // Information about Tensor-list parameters whose length is specified by a
112   // single `int` attribute.
113   struct InputsWithNumberAttr {
114     Attribute* number_attr;                      // not owned.
115     int64_t default_length;                      // -1 for no default.
116     std::vector<ParamIndex> tensor_list_params;  // list(tensor) inputs.
117   };
118 
119   // Structure used to return inferred attribute values.
120   //   * types[i] is the inferred value for inferred_type_attrs()[i]
121   //   * type_lists[i] is the inferred value for inferred_type_list_attrs()[i]
122   //   * lengths[i] is the inferred value for inferred_length_attrs()[i]
123   struct InferredAttributes {
124     std::vector<DataType> types;
125     std::vector<std::vector<DataType>> type_lists;
126     std::vector<int64_t> lengths;
127   };
128 
129   // Constructs a new PythonAPIInfo.
130   //
131   // Note: One of the `Initialize()` functions must be called before the
132   // `PythonAPIInfo` is used.
133   //
134   // Args:
135   //   api_name: The fully-qualified name of the python API (e.g., tf.math.sum).
136   explicit PythonAPIInfo(const std::string& api_name);
137 
138   // Initializes this PythonAPIInfo.
139   //
140   // Args:
141   //   op_def: Contains information about the parameters.
142   //   param_names: The argument names for the python API, in canonical order.
143   //   defaults_tuple: Tuple containing default values for the parameters,
144   //     right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
145   //     for `param_names[-i]`.
146   Status Initialize(const OpDef& op_def, const std::vector<string> param_names,
147                     PyObject* defaults_tuple);
148 
149   // Initialize this PythonAPIInfo based on the registered OpDef for the given
150   // operation.
151   //
152   // Args:
153   //   op_name: The registered name of the operation (e.g. "AddV2").
154   Status InitializeFromRegisteredOp(const std::string& op_name);
155 
156   // Initializes this PythonAPIInfo based on a set of parameter specifications.
157   //
158   // Args:
159   //   input_specs: Mapping from parameter name to specification string for
160   //     each input (parameter that expects a tensor value).
161   //   attr_specs: Mapping from parameter name to specification string for
162   //     each attribute (parameter that expects a non-tensor value).
163   //   param_names: The argument names for the python API, in canonical order.
164   //   defaults_tuple: Tuple containing default values for the parameters,
165   //     right-aligned with `param_names` -- i.e., `defaults[-i]` is the default
166   //     for `param_names[-i]`.
167   //
168   // Note: the `name` parameter should not be included in `input_specs` or
169   // `attr_specs`.
170   Status InitializeFromParamSpecs(
171       const std::map<std::string, std::string>& input_specs,
172       const std::map<std::string, std::string>& attr_specs,
173       const std::vector<string> param_names, PyObject* defaults_tuple);
174 
175   // The name of the API that is described by this PythonAPIInfo.
api_name()176   const char* api_name() const { return api_name_; }
177 
178   // The ordered names of the canononical parameters that this API expects.
param_names()179   const std::vector<const char*>& param_names() const { return param_names_; }
180 
181   // A Python tuple containing the default values for parameters.  This is
182   // right-aligned with `param_name` -- i.e., `defaults[-i]` is the default
183   // for `param_names[-i]`.
defaults_tuple()184   const PyObject* defaults_tuple() const { return defaults_tuple_.get(); }
185 
186   // Information about the attribute (non-tensor) parameters for this API.
attributes()187   const std::vector<Attribute>& attributes() const { return attributes_; }
188 
189   // Information about the input (tensor) parameters for this API.
inputs()190   const std::vector<Input>& inputs() const { return inputs_; }
inputs_with_fixed_dtype()191   const std::vector<InputWithFixedDType>& inputs_with_fixed_dtype() const {
192     return inputs_with_fixed_dtype_;
193   }
inputs_with_type_attrs()194   const std::vector<InputsWithTypeAttr>& inputs_with_type_attrs() const {
195     return inputs_with_type_attrs_;
196   }
inputs_with_type_list_attrs()197   const std::vector<InputsWithTypeListAttr>& inputs_with_type_list_attrs()
198       const {
199     return inputs_with_type_list_attrs_;
200   }
inputs_with_number_attrs()201   const std::vector<InputsWithNumberAttr>& inputs_with_number_attrs() const {
202     return inputs_with_number_attrs_;
203   }
204 
205   // Names of inferred attributes.
inferred_type_attrs()206   const std::vector<const char*>& inferred_type_attrs() const {
207     return inferred_type_attrs_;
208   }
inferred_type_list_attrs()209   const std::vector<const char*>& inferred_type_list_attrs() const {
210     return inferred_type_list_attrs_;
211   }
inferred_length_attrs()212   const std::vector<const char*>& inferred_length_attrs() const {
213     return inferred_length_attrs_;
214   }
215 
216   // Returns a string summarizing the internal state of this type converter.
217   string DebugInfo() const;
218 
219  private:
220   // Adds an entry to the attributes_ vector based on the given `AttrDef`.
221   //
222   // If `attr_def` describes a type attribute, then adds a value to
223   // inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any
224   // tensor inputs that use this dtype).
225   //
226   // If `attr_def` describes an int attribute, then adds a value to
227   // inputs_with_number_attrs_ (to record any tensor inputs that use this
228   // value as a list length).
229   Status InitializeAttribute(
230       const OpDef::AttrDef& attr_def,
231       const std::map<std::string, ParamIndex>& param_name_to_index);
232 
233   // Adds an entry to the inputs_ vector based on the given `ArgDef`.
234   //
235   // If `arg_def` has a fixed dtype, then adds a value to `fixed_dtype_inputs`.
236   //
237   // If `arg_def`'s dtype is described by a `type` attr, then updates the
238   // appropriate value in `inputs_with_type_attrs_` with information about the
239   // `arg_def`.
240   //
241   // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the
242   // appropriate value in `inputs_with_type_list_attrs_` with information about
243   // the `arg_def`.
244   Status InitializeInput(const OpDef::ArgDef& arg_def,
245                          const std::map<std::string, int>& param_name_to_index);
246 
247   // Checks that the OpDef used to initialize this PythonAPIInfo
248   // had an AttrDef or ArgDef specification for each parameter.
249   Status CheckParamNames() const;
250 
251   // Searches inputs_with_type_attrs_ for an input with the given name.
252   InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name);
253 
254   // Searches inputs_with_type_list_attrs_ for an input with the given name.
255   InputsWithTypeListAttr* FindInputsWithTypeListAttr(const string& name);
256 
257   // Searches inputs_with_type_list_attrs_ for an input with the given name.
258   InputsWithNumberAttr* FindInputsWithNumberAttr(const string& name);
259 
260   ABSL_MUST_USE_RESULT
261   bool InferLengthAttributes(const absl::Span<PyObject*> params,
262                              std::vector<int64_t>& inferred_length_attrs) const;
263 
264   // ==========================================================================
265   // Member Variables
266   // ==========================================================================
267 
268   // The name of the API that is described by this PythonAPIInfo.
269   // (Interned python string).
270   const char* api_name_;
271 
272   // The names of the parameters that this API expects.
273   // (Interned python strings.)
274   std::vector<const char*> param_names_;
275 
276   // Tuple containing default values for the parameters, right-aligned with
277   // `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`.
278   Safe_PyObjectPtr defaults_tuple_;
279 
280   // Information about the non-tensor-valued parameters that this API expects.
281   std::vector<Attribute> attributes_;
282 
283   // Information about the tensor-valued parameters that this API expects.
284   std::vector<Input> inputs_;
285   std::vector<InputWithFixedDType> inputs_with_fixed_dtype_;
286   std::vector<InputsWithTypeAttr> inputs_with_type_attrs_;
287   std::vector<InputsWithTypeListAttr> inputs_with_type_list_attrs_;
288   std::vector<InputsWithNumberAttr> inputs_with_number_attrs_;
289 
290   // Names of inferred attributes.  (Interned python strings.)
291   std::vector<const char*> inferred_type_attrs_;
292   std::vector<const char*> inferred_type_list_attrs_;
293   std::vector<const char*> inferred_length_attrs_;
294 };
295 
296 }  // namespace tensorflow
297 
298 #endif  // TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_
299