1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_PYTHON_UTIL_PYTHON_API_PARAMETER_CONVERTER_H_ 16 #define TENSORFLOW_PYTHON_UTIL_PYTHON_API_PARAMETER_CONVERTER_H_ 17 18 #include <Python.h> 19 20 #include <map> 21 #include <string> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/core/framework/op_def.pb.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/platform/status.h" 28 #include "tensorflow/python/framework/op_def_util.h" 29 #include "tensorflow/python/framework/python_api_info.h" 30 #include "tensorflow/python/framework/python_tensor_converter.h" 31 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" 32 33 namespace tensorflow { 34 35 // Converts the canoncialized parameters to the expected types (in place). 36 // 37 // * Input parameters (i.e., parameters that expect tensor values) are 38 // converted to tensors (or lists of tensors) using 39 // `tensor_converter.Convert`. 40 // * Attribute parameters are converted to the expected type. 41 // * Inferred attributes are written to `inferred_attrs`. (Can be 42 // nullptr if inferred attributes are not needed.) 43 // * If there's a "name" parameter, then its value is not modified. 44 // 45 // Note: for list-of-tensor parameters, the elements of the list will be 46 // converted in-place. Therefore, any list-of-tensor parameters should have 47 // their values copied to new lists before calling this method. (See 48 // `CopyPythonAPITensorLists`.) 49 // 50 // Any values that are removed from `params` have their reference count 51 // decremented, and any objects added to `params` are new references. 52 // 53 // Returns true on success, or sets an exception and returns false on error. 54 ABSL_MUST_USE_RESULT 55 bool ConvertPythonAPIParameters( 56 const PythonAPIInfo& api_info, 57 const PythonTensorConverter& tensor_converter, absl::Span<PyObject*> params, 58 PythonAPIInfo::InferredAttributes* inferred_attrs); 59 60 // Copies any parameters that expect a list of tensors to a new list. 61 // This ensures that any iterable value can be used, and also ensures that 62 // `ConvertPythonAPIParameters` can safely convert tensors in-place. 63 // 64 // Any values that are removed from `params` have their reference count 65 // decremented, and any objects added to `params` are new references. 66 // 67 // Returns true on success, or sets an exception and returns false on error. 68 ABSL_MUST_USE_RESULT 69 bool CopyPythonAPITensorLists(const PythonAPIInfo& api_info, 70 absl::Span<PyObject*> params); 71 72 } // namespace tensorflow 73 74 #endif // TENSORFLOW_PYTHON_UTIL_PYTHON_API_PARAMETER_CONVERTER_H_ 75