• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/util/function_parameter_canonicalizer.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/core/platform/macros.h"
20 #include "tensorflow/python/lib/core/py_util.h"
21 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
22 
23 namespace {
PyUnicodeAsUtf8Compat(PyObject * obj)24 inline const char* PyUnicodeAsUtf8Compat(PyObject* obj) {
25 #if PY_MAJOR_VERSION < 3
26   return PyString_AS_STRING(obj);
27 #else
28   return PyUnicode_AsUTF8(obj);
29 #endif
30 }
31 
PyUnicodeInternFromStringCompat(const char * str)32 inline PyObject* PyUnicodeInternFromStringCompat(const char* str) {
33 #if PY_MAJOR_VERSION < 3
34   return PyString_InternFromString(str);
35 #else
36   return PyUnicode_InternFromString(str);
37 #endif
38 }
39 
PyUnicodeInternInPlaceCompat(PyObject ** obj)40 inline void PyUnicodeInternInPlaceCompat(PyObject** obj) {
41 #if PY_MAJOR_VERSION < 3
42   PyString_InternInPlace(obj);
43 #else
44   PyUnicode_InternInPlace(obj);
45 #endif
46 }
47 
48 }  // namespace
49 
50 namespace tensorflow {
51 
FunctionParameterCanonicalizer(absl::Span<const char * > arg_names,absl::Span<PyObject * > defaults)52 FunctionParameterCanonicalizer::FunctionParameterCanonicalizer(
53     absl::Span<const char*> arg_names, absl::Span<PyObject*> defaults)
54     : positional_args_size_(arg_names.size() - defaults.size()) {
55   DCheckPyGilState();
56   DCHECK_GE(positional_args_size_, 0);
57 
58   interned_arg_names_.reserve(arg_names.size());
59   for (const char* obj : arg_names)
60     interned_arg_names_.emplace_back(PyUnicodeInternFromStringCompat(obj));
61 
62   DCHECK(AreInternedArgNamesUnique());
63 
64   for (PyObject* obj : defaults) Py_INCREF(obj);
65   defaults_ = std::vector<Safe_PyObjectPtr>(defaults.begin(), defaults.end());
66 }
67 
Canonicalize(PyObject * args,PyObject * kwargs,absl::Span<PyObject * > result)68 bool FunctionParameterCanonicalizer::Canonicalize(
69     PyObject* args, PyObject* kwargs, absl::Span<PyObject*> result) {
70   // TODO(kkb): Closely follow `Python/ceval.c`'s logic and error handling.
71 
72   DCheckPyGilState();
73   DCHECK(PyTuple_CheckExact(args));
74   DCHECK(PyDict_CheckExact(kwargs));
75   DCHECK_EQ(result.size(), interned_arg_names_.size());
76 
77   const int args_size = Py_SIZE(args);
78   int remaining_positional_args_count = positional_args_size_ - args_size;
79 
80   // Check if the number of input arguments are too many.
81   if (TF_PREDICT_FALSE(args_size > interned_arg_names_.size())) {
82     // TODO(kkb): Also report the actual numbers.
83     PyErr_SetString(PyExc_TypeError, "Too many arguments were given");
84     return false;
85   }
86 
87   // Fill positional arguments.
88   for (int i = 0; i < args_size; ++i) result[i] = PyTuple_GET_ITEM(args, i);
89 
90   // Fill default arguments.
91   for (int i = std::max(positional_args_size_, args_size);
92        i < interned_arg_names_.size(); ++i)
93     result[i] = defaults_[i - positional_args_size_].get();
94 
95   // Fill keyword arguments.
96   if (kwargs != nullptr) {
97     PyObject *key, *value;
98     Py_ssize_t pos = 0;
99     while (PyDict_Next(kwargs, &pos, &key, &value)) {
100       std::size_t index = InternedArgNameLinearSearch(key);
101 
102       // Check if key object(argument name) was found in the pre-built intern
103       // string table.
104       if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
105         // `key` might not be an interend string, so get the interned string
106         // and try again.
107         PyUnicodeInternInPlaceCompat(&key);
108 
109         index = InternedArgNameLinearSearch(key);
110 
111         // Stil not found, then return an error.
112         if (TF_PREDICT_FALSE(index == interned_arg_names_.size())) {
113           PyErr_Format(PyExc_TypeError,
114                        "Got an unexpected keyword argument '%s'",
115                        PyUnicodeAsUtf8Compat(key));
116           return false;
117         }
118       }
119 
120       // Check if the keyword argument overlaps with positional arguments.
121       if (TF_PREDICT_FALSE(index < args_size)) {
122         PyErr_Format(PyExc_TypeError, "Got multiple values for argument '%s'",
123                      PyUnicodeAsUtf8Compat(key));
124         return false;
125       }
126 
127       if (TF_PREDICT_FALSE(index < positional_args_size_))
128         --remaining_positional_args_count;
129 
130       result[index] = value;
131     }
132   }
133 
134   // Check if all the arguments are filled.
135   // Example failure, not enough number of arguments passed: `matmul(x)`
136   if (TF_PREDICT_FALSE(remaining_positional_args_count > 0)) {
137     // TODO(kkb): Report what arguments are missing.
138     PyErr_SetString(PyExc_TypeError, "Missing required positional argument");
139     return false;
140   }
141 
142   return true;
143 }
144 
145 ABSL_MUST_USE_RESULT
146 ABSL_ATTRIBUTE_HOT
InternedArgNameLinearSearch(PyObject * name)147 inline std::size_t FunctionParameterCanonicalizer::InternedArgNameLinearSearch(
148     PyObject* name) {
149   std::size_t result = interned_arg_names_.size();
150 
151   for (std::size_t i = 0; i < interned_arg_names_.size(); ++i)
152     if (TF_PREDICT_FALSE(name == interned_arg_names_[i].get())) return i;
153 
154   return result;
155 }
156 
AreInternedArgNamesUnique()157 bool FunctionParameterCanonicalizer::AreInternedArgNamesUnique() {
158   absl::flat_hash_set<PyObject*> interned_arg_names_set;
159   for (const Safe_PyObjectPtr& obj : interned_arg_names_)
160     interned_arg_names_set.emplace(obj.get());
161 
162   return interned_arg_names_set.size() == interned_arg_names_.size();
163 }
164 }  // namespace tensorflow
165