1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/util/function_parameter_canonicalizer.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/platform/macros.h"
20 #include "tensorflow/python/lib/core/py_util.h"
21 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
22
23 namespace {
PyUnicodeAsUtf8Compat(PyObject * obj)24 inline const char* PyUnicodeAsUtf8Compat(PyObject* obj) {
25 #if PY_MAJOR_VERSION < 3
26 return PyString_AS_STRING(obj);
27 #else
28 return PyUnicode_AsUTF8(obj);
29 #endif
30 }
31
PyUnicodeInternFromStringCompat(const char * str)32 inline PyObject* PyUnicodeInternFromStringCompat(const char* str) {
33 #if PY_MAJOR_VERSION < 3
34 return PyString_InternFromString(str);
35 #else
36 return PyUnicode_InternFromString(str);
37 #endif
38 }
39
PyUnicodeInternInPlaceCompat(PyObject ** obj)40 inline void PyUnicodeInternInPlaceCompat(PyObject** obj) {
41 #if PY_MAJOR_VERSION < 3
42 PyString_InternInPlace(obj);
43 #else
44 PyUnicode_InternInPlace(obj);
45 #endif
46 }
47
48 } // namespace
49
50 namespace tensorflow {
51
FunctionParameterCanonicalizer(absl::Span<const char * > arg_names,absl::Span<PyObject * > defaults)52 FunctionParameterCanonicalizer::FunctionParameterCanonicalizer(
53 absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults)
54 : positional_args_size_(arg_names.size() - defaults.size()) {
55 DCheckPyGilState();
56 DCHECK_GE(positional_args_size_, 0);
57
58 interned_arg_names_.reserve(arg_names.size());
59 for (const char* obj : arg_names)
60 interned_arg_names_.emplace_back(PyUnicodeInternFromStringCompat(obj));
61
62 DCHECK(AreInternedArgNamesUnique());
63
64 for (PyObject* obj : defaults) Py_INCREF(obj);
65 defaults_ = std::vector<Safe_PyObjectPtr>(defaults.begin(), defaults.end());
66 }
67
Canonicalize(PyObject * args,PyObject * kwargs,absl::Span<PyObject * > result)68 bool FunctionParameterCanonicalizer::Canonicalize(
69 PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) {
70 // TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling.
71
72 DCheckPyGilState();
73 DCHECK(PyTuple_CheckExact(args));
74 DCHECK(PyDict_CheckExact(kwargs));
75 DCHECK_EQ(result.size(), interned_arg_names_.size());
76
77 const int args_size = Py_SIZE(args);
78 int remaining_positional_args_count = positional_args_size_ - args_size;
79
80 // Check if the number of input arguments are too many.
81 if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) {
82 // TODO(kkb): Also report the actual numbers.
83 PyErr_SetString(PyExc_TypeError, "Too many arguments were given");
84 return false;
85 }
86
87 // Fill positional arguments.
88 for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i);
89
90 // Fill default arguments.
91 for (int i = std::max(positional_args_size_, args_size);
92 i < interned_arg_names_.size(); ++i)
93 result[i] = defaults_[i - positional_args_size_].get();
94
95 // Fill keyword arguments.
96 if (kwargs != nullptr) {
97 PyObject *key, *value;
98 Py_ssize_t pos = 0;
99 while (PyDict_Next(kwargs, &pos, &key, &value)) {
100 std::size_t index = InternedArgNameLinearSearch(key);
101
102 // Check if key object(argument name) was found in the pre-built intern
103 // string table.
104 if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
105 // `key` might not be an interend string, so get the interned string
106 // and try again.
107 PyUnicodeInternInPlaceCompat(&key);
108
109 index = InternedArgNameLinearSearch(key);
110
111 // Stil not found, then return an error.
112 if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
113 PyErr_Format(PyExc_TypeError,
114 "Got an unexpected keyword argument '%s'",
115 PyUnicodeAsUtf8Compat(key));
116 return false;
117 }
118 }
119
120 // Check if the keyword argument overlaps with positional arguments.
121 if (TF_PREDICT_FALSE(index < args_size)) {
122 PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%s'",
123 PyUnicodeAsUtf8Compat(key));
124 return false;
125 }
126
127 if (TF_PREDICT_FALSE(index < positional_args_size_))
128 --remaining_positional_args_count;
129
130 result[index] = value;
131 }
132 }
133
134 // Check if all the arguments are filled.
135 // Example failure, not enough number of arguments passed: `matmul(x)`
136 if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) {
137 // TODO(kkb): Report what arguments are missing.
138 PyErr_SetString(PyExc_TypeError, "Missing required positional argument");
139 return false;
140 }
141
142 return true;
143 }
144
145 ABSL_MUST_USE_RESULT
146 ABSL_ATTRIBUTE_HOT
InternedArgNameLinearSearch(PyObject * name)147 inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch(
148 PyObject* name) {
149 std::size_t result = interned_arg_names_.size();
150
151 for (std::size_t i = 0; i < interned_arg_names_.size(); ++i)
152 if (TF_PREDICT_FALSE(name == interned_arg_names_[i].get())) return i;
153
154 return result;
155 }
156
AreInternedArgNamesUnique()157 bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() {
158 absl::flat_hash_set<PyObject*> interned_arg_names_set;
159 for (const Safe_PyObjectPtr& obj : interned_arg_names_)
160 interned_arg_names_set.emplace(obj.get());
161
162 return interned_arg_names_set.size() == interned_arg_names_.size();
163 }
164 } // namespace tensorflow
165