1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_ptr.h"
28
29 namespace tensorflow {
30 namespace swig {
31
PythonTypesMap()32 std::unordered_map<string, PyObject*>* PythonTypesMap() {
33 static auto* m = new std::unordered_map<string, PyObject*>();
34 return m;
35 }
36
GetRegisteredType(const string & key)37 PyObject* GetRegisteredType(const string& key) {
38 auto* m = PythonTypesMap();
39 auto it = m->find(key);
40 if (it == m->end()) return nullptr;
41 return it->second;
42 }
43
RegisterType(PyObject * type_name,PyObject * type)44 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
45 if (!PyType_Check(type)) {
46 PyErr_SetString(PyExc_TypeError,
47 tensorflow::strings::StrCat("Expecting a type, got ",
48 Py_TYPE(type)->tp_name)
49 .c_str());
50 return nullptr;
51 }
52
53 string key;
54 if (PyBytes_Check(type_name)) {
55 key = PyBytes_AsString(type_name);
56 }
57 #if PY_MAJOR_VERSION >= 3
58 if (PyUnicode_Check(type_name)) {
59 key = PyUnicode_AsUTF8(type_name);
60 }
61 #endif
62
63 if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
64 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
65 "Type already registered for ", key)
66 .c_str());
67 return nullptr;
68 }
69
70 Py_INCREF(type);
71 PythonTypesMap()->emplace(key, type);
72
73 Py_RETURN_NONE;
74 }
75
76 namespace {
// Upper bound on the number of distinct Python types a single CachedTypeCheck
// will remember.
constexpr int kMaxItemsInCache = 1024;

// Set once IsSequenceHelper() has logged its "sets are not sequences" warning,
// so that the warning is normally logged only once.
bool WarnedThatSetIsNotSequence = false;
80
IsString(PyObject * o)81 bool IsString(PyObject* o) {
82 return PyBytes_Check(o) ||
83 #if PY_MAJOR_VERSION < 3
84 PyString_Check(o) ||
85 #endif
86 PyUnicode_Check(o);
87 }
88
89 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
90 // and while we're at it give them consistent behavior by making sure the
91 // returned value is a list.
92 //
93 // As with PyMapping_Keys, returns a new reference.
94 //
95 // On failure, returns nullptr.
MappingKeys(PyObject * o)96 PyObject* MappingKeys(PyObject* o) {
97 #if PY_MAJOR_VERSION >= 3
98 return PyMapping_Keys(o);
99 #else
100 static char key_method_name[] = "keys";
101 Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
102 if (PyErr_Occurred() || raw_result.get() == nullptr) {
103 return nullptr;
104 }
105 return PySequence_Fast(
106 raw_result.get(),
107 "The '.keys()' method of a custom mapping returned a non-sequence.");
108 #endif
109 }
110
111 // Equivalent to Python's 'o.__class__.__name__'
112 // Note that '__class__' attribute is set only in new-style classes.
113 // A lot of tensorflow code uses __class__ without checks, so it seems like
114 // we only support new-style classes.
GetClassName(PyObject * o)115 StringPiece GetClassName(PyObject* o) {
116 // __class__ is equivalent to type() for new style classes.
117 // type() is equivalent to PyObject_Type()
118 // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
119 // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
120 // we don't need here.
121 PyTypeObject* type = o->ob_type;
122
123 // __name__ is the value of `tp_name` after the last '.'
124 // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
125 StringPiece name(type->tp_name);
126 size_t pos = name.rfind('.');
127 if (pos != StringPiece::npos) {
128 name.remove_prefix(pos + 1);
129 }
130 return name;
131 }
132
PyObjectToString(PyObject * o)133 string PyObjectToString(PyObject* o) {
134 if (o == nullptr) {
135 return "<null object>";
136 }
137 PyObject* str = PyObject_Str(o);
138 if (str) {
139 #if PY_MAJOR_VERSION < 3
140 string s(PyString_AS_STRING(str));
141 #else
142 string s(PyUnicode_AsUTF8(str));
143 #endif
144 Py_DECREF(str);
145 return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
146 } else {
147 return "<failed to execute str() on object>";
148 }
149 }
150
151 class CachedTypeCheck {
152 public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)153 explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
154 : ternary_predicate_(std::move(ternary_predicate)) {}
155
~CachedTypeCheck()156 ~CachedTypeCheck() {
157 mutex_lock l(type_to_sequence_map_mu_);
158 for (const auto& pair : type_to_sequence_map_) {
159 Py_DECREF(pair.first);
160 }
161 }
162
163 // Caches successful executions of the one-argument (PyObject*) callable
164 // "ternary_predicate" based on the type of "o". -1 from the callable
165 // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
166 // does not match the predicate, and 1 indicates that it does. Used to avoid
167 // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)168 int CachedLookup(PyObject* o) {
169 // Try not to return to Python - see if the type has already been seen
170 // before.
171
172 auto* type = Py_TYPE(o);
173
174 {
175 tf_shared_lock l(type_to_sequence_map_mu_);
176 auto it = type_to_sequence_map_.find(type);
177 if (it != type_to_sequence_map_.end()) {
178 return it->second;
179 }
180 }
181
182 int check_result = ternary_predicate_(o);
183
184 if (check_result == -1) {
185 return -1; // Type check error, not cached.
186 }
187
188 // NOTE: This is never decref'd as long as the object lives, which is likely
189 // forever, but we don't want the type to get deleted as long as it is in
190 // the map. This should not be too much of a leak, as there should only be a
191 // relatively small number of types in the map, and an even smaller number
192 // that are eligible for decref. As a precaution, we limit the size of the
193 // map to 1024.
194 {
195 mutex_lock l(type_to_sequence_map_mu_);
196 if (type_to_sequence_map_.size() < kMaxItemsInCache) {
197 Py_INCREF(type);
198 auto insert_result = type_to_sequence_map_.insert({type, check_result});
199 if (!insert_result.second) {
200 // The type was added to the cache by a concurrent thread after we
201 // looked it up above.
202 Py_DECREF(type);
203 }
204 }
205 }
206
207 return check_result;
208 }
209
210 private:
211 std::function<int(PyObject*)> ternary_predicate_;
212 mutex type_to_sequence_map_mu_;
213 std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
214 GUARDED_BY(type_to_sequence_map_mu_);
215 };
216
217 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
218 // Returns 0 otherwise.
219 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)220 int IsMappingHelper(PyObject* o) {
221 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
222 PyObject* collections_mapping_type = GetRegisteredType("Mapping");
223 if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) {
224 PyErr_SetString(PyExc_RuntimeError,
225 tensorflow::strings::StrCat(
226 "collections.Mapping type has not been set. "
227 "Please register the type with the identifier "
228 "\"Mapping\" using RegisterType.")
229 .c_str());
230 return -1;
231 }
232 return PyObject_IsInstance(to_check, collections_mapping_type);
233 });
234 if (PyDict_Check(o)) return true;
235 return check_cache->CachedLookup(o);
236 }
237
238 // Returns 1 if `o` is an instance of attrs-decorated class.
239 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)240 int IsAttrsHelper(PyObject* o) {
241 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
242 Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
243 if (cls) {
244 return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
245 }
246
247 // PyObject_GetAttrString returns null on error
248 PyErr_Clear();
249 return 0;
250 });
251 return check_cache->CachedLookup(o);
252 }
253
254 // Returns 1 if `o` is an object of type IndexedSlices.
255 // Returns 0 otherwise.
256 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)257 int IsIndexedSlicesHelper(PyObject* o) {
258 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
259 PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices");
260 if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) {
261 PyErr_SetString(PyExc_RuntimeError,
262 tensorflow::strings::StrCat(
263 "IndexedSlices type has not been set. "
264 "Please register the type with the identifier "
265 "\"IndexedSlices\" using RegisterType.")
266 .c_str());
267 return -1;
268 }
269 return PyObject_IsInstance(to_check, indexed_slices_type);
270 });
271 return check_cache->CachedLookup(o);
272 }
273
274 // Returns 1 if `o` is a Tensor.
275 // Returns 0 otherwise.
276 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)277 int IsTensorHelper(PyObject* o) {
278 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
279 PyObject* tensor_type = GetRegisteredType("Tensor");
280 if (TF_PREDICT_FALSE(tensor_type == nullptr)) {
281 PyErr_SetString(PyExc_RuntimeError,
282 tensorflow::strings::StrCat(
283 "Tensor type has not been set. "
284 "Please register the type with the identifier "
285 "\"Tensor\" using RegisterType.")
286 .c_str());
287 return -1;
288 }
289 return PyObject_IsInstance(to_check, tensor_type);
290 });
291 return check_cache->CachedLookup(o);
292 }
293
294 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
295 // Returns 0 otherwise.
296 // Returns -1 if an error occurred.
IsSequenceHelper(PyObject * o)297 int IsSequenceHelper(PyObject* o) {
298 // We treat dicts and other mappings as special cases of sequences.
299 if (IsMappingHelper(o)) return true;
300 if (IsAttrsHelper(o)) return true;
301 if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
302 LOG(WARNING) << "Sets are not currently considered sequences, "
303 "but this may change in the future, "
304 "so consider avoiding using them.";
305 WarnedThatSetIsNotSequence = true;
306 }
307 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
308 PyObject* collections_sequence_type = GetRegisteredType("Sequence");
309 if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
310 PyErr_SetString(PyExc_RuntimeError,
311 tensorflow::strings::StrCat(
312 "collections.Sequence type has not been set. "
313 "Please register the type with the identifier "
314 "\"Sequence\" using RegisterType.")
315 .c_str());
316 return -1;
317 }
318 int is_instance = PyObject_IsInstance(to_check, collections_sequence_type);
319
320 // Don't cache a failed is_instance check.
321 if (is_instance == -1) return -1;
322
323 return static_cast<int>(is_instance != 0 && !IsString(to_check));
324 });
325 return check_cache->CachedLookup(o);
326 }
327
328 // ValueIterator interface
329 class ValueIterator {
330 public:
~ValueIterator()331 virtual ~ValueIterator() {}
332 virtual Safe_PyObjectPtr next() = 0;
333
valid() const334 bool valid() const { return is_valid_; }
335
336 protected:
invalidate()337 void invalidate() { is_valid_ = false; }
338
339 private:
340 bool is_valid_ = true;
341 };
342
343 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
344
345 // Iterate through dictionaries in a deterministic order by sorting the
346 // keys. Notice this means that we ignore the original order of
347 // `OrderedDict` instances. This is intentional, to avoid potential
348 // bugs caused by mixing ordered and plain dicts (e.g., flattening
349 // a dict but using a corresponding `OrderedDict` to pack it back).
350 class DictValueIterator : public ValueIterator {
351 public:
DictValueIterator(PyObject * dict)352 explicit DictValueIterator(PyObject* dict)
353 : dict_(dict), keys_(PyDict_Keys(dict)) {
354 if (PyList_Sort(keys_.get()) == -1) {
355 invalidate();
356 } else {
357 iter_.reset(PyObject_GetIter(keys_.get()));
358 }
359 }
360
next()361 Safe_PyObjectPtr next() override {
362 Safe_PyObjectPtr result;
363 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
364 if (key) {
365 // PyDict_GetItem returns a borrowed reference.
366 PyObject* elem = PyDict_GetItem(dict_, key.get());
367 if (elem) {
368 Py_INCREF(elem);
369 result.reset(elem);
370 } else {
371 PyErr_SetString(PyExc_RuntimeError,
372 "Dictionary was modified during iteration over it");
373 }
374 }
375 return result;
376 }
377
378 private:
379 PyObject* dict_;
380 Safe_PyObjectPtr keys_;
381 Safe_PyObjectPtr iter_;
382 };
383
384 // Iterate over mapping objects by sorting the keys first
385 class MappingValueIterator : public ValueIterator {
386 public:
MappingValueIterator(PyObject * mapping)387 explicit MappingValueIterator(PyObject* mapping)
388 : mapping_(mapping), keys_(MappingKeys(mapping)) {
389 if (!keys_ || PyList_Sort(keys_.get()) == -1) {
390 invalidate();
391 } else {
392 iter_.reset(PyObject_GetIter(keys_.get()));
393 }
394 }
395
next()396 Safe_PyObjectPtr next() override {
397 Safe_PyObjectPtr result;
398 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
399 if (key) {
400 // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
401 PyObject* elem = PyObject_GetItem(mapping_, key.get());
402 if (elem) {
403 result.reset(elem);
404 } else {
405 PyErr_SetString(PyExc_RuntimeError,
406 "Mapping was modified during iteration over it");
407 }
408 }
409 return result;
410 }
411
412 private:
413 PyObject* mapping_;
414 Safe_PyObjectPtr keys_;
415 Safe_PyObjectPtr iter_;
416 };
417
418 // Iterate over a sequence, by index.
419 class SequenceValueIterator : public ValueIterator {
420 public:
SequenceValueIterator(PyObject * iterable)421 explicit SequenceValueIterator(PyObject* iterable)
422 : seq_(PySequence_Fast(iterable, "")),
423 size_(PySequence_Fast_GET_SIZE(seq_.get())),
424 index_(0) {}
425
next()426 Safe_PyObjectPtr next() override {
427 Safe_PyObjectPtr result;
428 if (index_ < size_) {
429 // PySequence_Fast_GET_ITEM returns a borrowed reference.
430 PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
431 ++index_;
432 Py_INCREF(elem);
433 result.reset(elem);
434 }
435
436 return result;
437 }
438
439 private:
440 Safe_PyObjectPtr seq_;
441 const Py_ssize_t size_;
442 Py_ssize_t index_;
443 };
444
445 // Just return itself as a single item.
446 class SparseTensorValueIterator : public ValueIterator {
447 public:
SparseTensorValueIterator(PyObject * tensor)448 explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) {
449 Py_INCREF(tensor);
450 }
451
next()452 Safe_PyObjectPtr next() override { return std::move(tensor_); }
453
454 private:
455 Safe_PyObjectPtr tensor_;
456 };
457
458 // Returns nullptr (to raise an exception) when next() is called. Caller
459 // should have already called PyErr_SetString.
460 class ErrorValueIterator : public ValueIterator {
461 public:
ErrorValueIterator()462 ErrorValueIterator() {}
next()463 Safe_PyObjectPtr next() override { return nullptr; }
464 };
465
466 class AttrsValueIterator : public ValueIterator {
467 public:
AttrsValueIterator(PyObject * nested)468 explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
469 Py_INCREF(nested);
470 cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
471 if (cls_) {
472 attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
473 if (attrs_) {
474 iter_.reset(PyObject_GetIter(attrs_.get()));
475 }
476 }
477 if (!iter_ || PyErr_Occurred()) invalidate();
478 }
479
next()480 Safe_PyObjectPtr next() override {
481 Safe_PyObjectPtr result;
482 Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
483 if (item) {
484 Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
485 result.reset(PyObject_GetAttr(nested_.get(), name.get()));
486 }
487
488 return result;
489 }
490
491 private:
492 Safe_PyObjectPtr nested_;
493 Safe_PyObjectPtr cls_;
494 Safe_PyObjectPtr attrs_;
495 Safe_PyObjectPtr iter_;
496 };
497
IsSparseTensorValueType(PyObject * o)498 bool IsSparseTensorValueType(PyObject* o) {
499 PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
500 if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
501 return false;
502 }
503
504 return PyObject_TypeCheck(
505 o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
506 }
507
508 // Returns 1 if `o` is an instance of CompositeTensor.
509 // Returns 0 otherwise.
510 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)511 bool IsCompositeTensorHelper(PyObject* o) {
512 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
513 PyObject* composite_tensor_type = GetRegisteredType("CompositeTensor");
514 if (TF_PREDICT_FALSE(composite_tensor_type == nullptr)) {
515 PyErr_SetString(PyExc_RuntimeError,
516 tensorflow::strings::StrCat(
517 "CompositeTensor type has not been set. "
518 "Please register the type with the identifier "
519 "\"CompositeTensor\" using RegisterType.")
520 .c_str());
521 return -1;
522 }
523 int is_instance = PyObject_IsInstance(to_check, composite_tensor_type);
524
525 // Don't cache a failed is_instance check.
526 if (is_instance == -1) return -1;
527
528 return static_cast<int>(is_instance != 0);
529 });
530 return check_cache->CachedLookup(o);
531 }
532
IsSequenceOrCompositeHelper(PyObject * o)533 int IsSequenceOrCompositeHelper(PyObject* o) {
534 return IsSequence(o) || IsCompositeTensor(o);
535 }
536
IsSequenceForDataHelper(PyObject * o)537 int IsSequenceForDataHelper(PyObject* o) {
538 return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
539 !IsSparseTensorValueType(o);
540 }
541
GetValueIterator(PyObject * nested)542 ValueIteratorPtr GetValueIterator(PyObject* nested) {
543 if (PyDict_Check(nested)) {
544 return absl::make_unique<DictValueIterator>(nested);
545 } else if (IsMappingHelper(nested)) {
546 return absl::make_unique<MappingValueIterator>(nested);
547 } else if (IsAttrsHelper(nested)) {
548 return absl::make_unique<AttrsValueIterator>(nested);
549 } else {
550 return absl::make_unique<SequenceValueIterator>(nested);
551 }
552 }
553
554 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)555 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
556 if (PyDict_Check(nested)) {
557 return absl::make_unique<DictValueIterator>(nested);
558 } else if (IsMappingHelper(nested)) {
559 return absl::make_unique<MappingValueIterator>(nested);
560 } else if (IsAttrsHelper(nested)) {
561 return absl::make_unique<AttrsValueIterator>(nested);
562 } else if (IsSparseTensorValueType(nested)) {
563 return absl::make_unique<SparseTensorValueIterator>(nested);
564 } else {
565 return absl::make_unique<SequenceValueIterator>(nested);
566 }
567 }
568
569 // Similar to GetValueIterator above, but expands CompositeTensors.
GetValueIteratorForComposite(PyObject * nested)570 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
571 if (IsCompositeTensor(nested)) {
572 static char expand_method_name[] = "_to_components";
573 nested = PyObject_CallMethod(nested, expand_method_name, nullptr);
574 if (PyErr_Occurred() || nested == nullptr) {
575 return absl::make_unique<ErrorValueIterator>();
576 }
577 }
578 return GetValueIterator(nested);
579 }
580
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)581 bool FlattenHelper(
582 PyObject* nested, PyObject* list,
583 const std::function<int(PyObject*)>& is_sequence_helper,
584 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
585 // if nested is not a sequence, append itself and exit
586 int is_seq = is_sequence_helper(nested);
587 if (is_seq == -1) return false;
588 if (!is_seq) {
589 return PyList_Append(list, nested) != -1;
590 }
591
592 ValueIteratorPtr iter = value_iterator_getter(nested);
593 if (!iter->valid()) return false;
594
595 for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
596 if (Py_EnterRecursiveCall(" in flatten")) {
597 return false;
598 }
599 const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
600 value_iterator_getter);
601 Py_LeaveRecursiveCall();
602 if (!success) {
603 return false;
604 }
605 }
606 return true;
607 }
608
609 // Sets error using keys of 'dict1' and 'dict2'.
610 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)611 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
612 bool* is_type_error) {
613 Safe_PyObjectPtr k1(MappingKeys(dict1));
614 if (PyErr_Occurred() || k1.get() == nullptr) {
615 *error_msg =
616 ("The two dictionaries don't have the same set of keys. Failed to "
617 "fetch keys.");
618 return;
619 }
620 Safe_PyObjectPtr k2(MappingKeys(dict2));
621 if (PyErr_Occurred() || k2.get() == nullptr) {
622 *error_msg =
623 ("The two dictionaries don't have the same set of keys. Failed to "
624 "fetch keys.");
625 return;
626 }
627 *is_type_error = false;
628 *error_msg = tensorflow::strings::StrCat(
629 "The two dictionaries don't have the same set of keys. "
630 "First structure has keys ",
631 PyObjectToString(k1.get()), ", while second structure has keys ",
632 PyObjectToString(k2.get()));
633 }
634
635 // Returns true iff there were no "internal" errors. In other words,
636 // errors that has nothing to do with structure checking.
637 // If an "internal" error occurred, the appropriate Python error will be
638 // set and the caller can propage it directly to the user.
639 //
640 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
641 // be empty.
642 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
643 // with appropriate error and sets `is_type_error` to true iff
644 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)645 bool AssertSameStructureHelper(
646 PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
647 bool* is_type_error,
648 const std::function<int(PyObject*)>& is_sequence_helper,
649 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
650 DCHECK(error_msg);
651 DCHECK(is_type_error);
652 const bool is_seq1 = is_sequence_helper(o1);
653 const bool is_seq2 = is_sequence_helper(o2);
654 if (PyErr_Occurred()) return false;
655 if (is_seq1 != is_seq2) {
656 string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
657 string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
658 *is_type_error = false;
659 *error_msg = tensorflow::strings::StrCat(
660 "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
661 non_seq_str, "\" is not");
662 return true;
663 }
664
665 // Got to objects that are considered non-sequences. Note that in tf.data
666 // use case lists and sparse_tensors are not considered sequences. So finished
667 // checking, structures are the same.
668 if (!is_seq1) return true;
669
670 if (check_types) {
671 const PyTypeObject* type1 = o1->ob_type;
672 const PyTypeObject* type2 = o2->ob_type;
673
674 // We treat two different namedtuples with identical name and fields
675 // as having the same type.
676 const PyObject* o1_tuple = IsNamedtuple(o1, true);
677 if (o1_tuple == nullptr) return false;
678 const PyObject* o2_tuple = IsNamedtuple(o2, true);
679 if (o2_tuple == nullptr) {
680 Py_DECREF(o1_tuple);
681 return false;
682 }
683 bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
684 Py_DECREF(o1_tuple);
685 Py_DECREF(o2_tuple);
686
687 if (both_tuples) {
688 const PyObject* same_tuples = SameNamedtuples(o1, o2);
689 if (same_tuples == nullptr) return false;
690 bool not_same_tuples = same_tuples != Py_True;
691 Py_DECREF(same_tuples);
692 if (not_same_tuples) {
693 *is_type_error = true;
694 *error_msg = tensorflow::strings::StrCat(
695 "The two namedtuples don't have the same sequence type. "
696 "First structure ",
697 PyObjectToString(o1), " has type ", type1->tp_name,
698 ", while second structure ", PyObjectToString(o2), " has type ",
699 type2->tp_name);
700 return true;
701 }
702 } else if (type1 != type2
703 /* If both sequences are list types, don't complain. This allows
704 one to be a list subclass (e.g. _ListWrapper used for
705 automatic dependency tracking.) */
706 && !(PyList_Check(o1) && PyList_Check(o2))
707 /* Two mapping types will also compare equal, making _DictWrapper
708 and dict compare equal. */
709 && !(IsMappingHelper(o1) && IsMappingHelper(o2))) {
710 *is_type_error = true;
711 *error_msg = tensorflow::strings::StrCat(
712 "The two namedtuples don't have the same sequence type. "
713 "First structure ",
714 PyObjectToString(o1), " has type ", type1->tp_name,
715 ", while second structure ", PyObjectToString(o2), " has type ",
716 type2->tp_name);
717 return true;
718 }
719
720 if (PyDict_Check(o1) && PyDict_Check(o2)) {
721 if (PyDict_Size(o1) != PyDict_Size(o2)) {
722 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
723 return true;
724 }
725
726 PyObject* key;
727 Py_ssize_t pos = 0;
728 while (PyDict_Next(o1, &pos, &key, nullptr)) {
729 if (PyDict_GetItem(o2, key) == nullptr) {
730 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
731 return true;
732 }
733 }
734 } else if (IsMappingHelper(o1)) {
735 // Fallback for custom mapping types. Instead of using PyDict methods
736 // which stay in C, we call iter(o1).
737 if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
738 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
739 return true;
740 }
741
742 Safe_PyObjectPtr iter(PyObject_GetIter(o1));
743 PyObject* key;
744 while ((key = PyIter_Next(iter.get())) != nullptr) {
745 if (!PyMapping_HasKey(o2, key)) {
746 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
747 Py_DECREF(key);
748 return true;
749 }
750 Py_DECREF(key);
751 }
752 }
753 }
754
755 ValueIteratorPtr iter1 = value_iterator_getter(o1);
756 ValueIteratorPtr iter2 = value_iterator_getter(o2);
757
758 if (!iter1->valid() || !iter2->valid()) return false;
759
760 while (true) {
761 Safe_PyObjectPtr v1 = iter1->next();
762 Safe_PyObjectPtr v2 = iter2->next();
763 if (v1 && v2) {
764 if (Py_EnterRecursiveCall(" in assert_same_structure")) {
765 return false;
766 }
767 bool no_internal_errors = AssertSameStructureHelper(
768 v1.get(), v2.get(), check_types, error_msg, is_type_error,
769 is_sequence_helper, value_iterator_getter);
770 Py_LeaveRecursiveCall();
771 if (!no_internal_errors) return false;
772 if (!error_msg->empty()) return true;
773 } else if (!v1 && !v2) {
774 // Done with all recursive calls. Structure matched.
775 return true;
776 } else {
777 *is_type_error = false;
778 *error_msg = tensorflow::strings::StrCat(
779 "The two structures don't have the same number of elements. ",
780 "First structure: ", PyObjectToString(o1),
781 ". Second structure: ", PyObjectToString(o2));
782 return true;
783 }
784 }
785 }
786
787 } // namespace
788
IsSequence(PyObject * o)789 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
IsMapping(PyObject * o)790 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsAttrs(PyObject * o)791 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)792 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsIndexedSlices(PyObject * o)793 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
794
Flatten(PyObject * nested,bool expand_composites)795 PyObject* Flatten(PyObject* nested, bool expand_composites) {
796 PyObject* list = PyList_New(0);
797 const std::function<int(PyObject*)>& is_sequence_helper =
798 expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
799 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
800 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
801 if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
802 return list;
803 } else {
804 Py_DECREF(list);
805 return nullptr;
806 }
807 }
808
IsSequenceOrComposite(PyObject * o)809 bool IsSequenceOrComposite(PyObject* o) {
810 return IsSequenceOrCompositeHelper(o) == 1;
811 }
812
IsCompositeTensor(PyObject * o)813 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
814
IsSequenceForData(PyObject * o)815 bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
816
FlattenForData(PyObject * nested)817 PyObject* FlattenForData(PyObject* nested) {
818 PyObject* list = PyList_New(0);
819 if (FlattenHelper(nested, list, IsSequenceForDataHelper,
820 GetValueIteratorForData)) {
821 return list;
822 } else {
823 Py_DECREF(list);
824 return nullptr;
825 }
826 }
827
IsNamedtuple(PyObject * o,bool strict)828 PyObject* IsNamedtuple(PyObject* o, bool strict) {
829 // Must be subclass of tuple
830 if (!PyTuple_Check(o)) {
831 Py_RETURN_FALSE;
832 }
833
834 // If strict, o.__class__.__base__ must be tuple
835 if (strict) {
836 PyObject* klass = PyObject_GetAttrString(o, "__class__");
837 if (klass == nullptr) return nullptr;
838 PyObject* base = PyObject_GetAttrString(klass, "__base__");
839 Py_DECREF(klass);
840 if (base == nullptr) return nullptr;
841
842 const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
843 // built-in object types are singletons
844 bool tuple_base = base_type == &PyTuple_Type;
845 Py_DECREF(base);
846 if (!tuple_base) {
847 Py_RETURN_FALSE;
848 }
849 }
850
851 PyObject* collections_sequence_type = GetRegisteredType("Sequence");
852
853 if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
854 PyErr_SetString(PyExc_RuntimeError,
855 tensorflow::strings::StrCat(
856 "collections.Sequence type has not been set. "
857 "Please register the type with the identifier "
858 "\"Sequence\" using RegisterType.")
859 .c_str());
860 return nullptr;
861 }
862
863 // o must have attribute '_fields' and every element in
864 // '_fields' must be a string.
865 int has_fields = PyObject_HasAttrString(o, "_fields");
866 if (!has_fields) {
867 Py_RETURN_FALSE;
868 }
869
870 Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
871 int is_instance =
872 PyObject_IsInstance(fields.get(), collections_sequence_type);
873 if (is_instance == 0) {
874 Py_RETURN_FALSE;
875 } else if (is_instance == -1) {
876 return nullptr;
877 }
878
879 Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
880 const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
881 for (Py_ssize_t i = 0; i < s; ++i) {
882 // PySequence_Fast_GET_ITEM returns borrowed ref
883 PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
884 if (!IsString(elem)) {
885 Py_RETURN_FALSE;
886 }
887 }
888
889 Py_RETURN_TRUE;
890 }
891
SameNamedtuples(PyObject * o1,PyObject * o2)892 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
893 Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
894 Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
895 if (f1 == nullptr || f2 == nullptr) {
896 PyErr_SetString(
897 PyExc_RuntimeError,
898 "Expected namedtuple-like objects (that have _fields attr)");
899 return nullptr;
900 }
901
902 if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
903 Py_RETURN_FALSE;
904 }
905
906 if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
907 Py_RETURN_TRUE;
908 } else {
909 Py_RETURN_FALSE;
910 }
911 }
912
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)913 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
914 bool expand_composites) {
915 const std::function<int(PyObject*)>& is_sequence_helper =
916 expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
917 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
918 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
919 string error_msg;
920 bool is_type_error = false;
921 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
922 is_sequence_helper, get_value_iterator);
923 if (PyErr_Occurred()) {
924 // Don't hide Python exceptions while checking (e.g. errors fetching keys
925 // from custom mappings).
926 return nullptr;
927 }
928 if (!error_msg.empty()) {
929 PyErr_SetString(
930 is_type_error ? PyExc_TypeError : PyExc_ValueError,
931 tensorflow::strings::StrCat(
932 "The two structures don't have the same nested structure.\n\n",
933 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
934 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
935 .c_str());
936 return nullptr;
937 }
938 Py_RETURN_NONE;
939 }
940
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)941 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
942 bool check_types) {
943 string error_msg;
944 bool is_type_error = false;
945 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
946 IsSequenceForDataHelper, GetValueIterator);
947 if (PyErr_Occurred()) {
948 // Don't hide Python exceptions while checking (e.g. errors fetching keys
949 // from custom mappings).
950 return nullptr;
951 }
952 if (!error_msg.empty()) {
953 PyErr_SetString(
954 is_type_error ? PyExc_TypeError : PyExc_ValueError,
955 tensorflow::strings::StrCat(
956 "The two structures don't have the same nested structure.\n\n",
957 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
958 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
959 .c_str());
960 return nullptr;
961 }
962 Py_RETURN_NONE;
963 }
964
965 } // namespace swig
966 } // namespace tensorflow
967