1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28
29 namespace tensorflow {
30 namespace swig {
31
namespace {
// Forward declaration. PyObjectToString is defined further down in this file
// but is already needed by RegisterPyObject to build its error message.
string PyObjectToString(PyObject* o);
}  // namespace
35
RegisteredPyObjectMap()36 std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37 static auto* m = new std::unordered_map<string, PyObject*>();
38 return m;
39 }
40
GetRegisteredPyObject(const string & name)41 PyObject* GetRegisteredPyObject(const string& name) {
42 const auto* m = RegisteredPyObjectMap();
43 auto it = m->find(name);
44 if (it == m->end()) {
45 PyErr_SetString(PyExc_TypeError,
46 tensorflow::strings::StrCat("No object with name ", name,
47 " has been registered.")
48 .c_str());
49 return nullptr;
50 }
51 return it->second;
52 }
53
RegisterType(PyObject * type_name,PyObject * type)54 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55 if (!PyType_Check(type)) {
56 PyErr_SetString(PyExc_TypeError,
57 tensorflow::strings::StrCat("Expecting a type, got ",
58 Py_TYPE(type)->tp_name)
59 .c_str());
60 return nullptr;
61 }
62 return RegisterPyObject(type_name, type);
63 }
64
RegisterPyObject(PyObject * name,PyObject * value)65 PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66 string key;
67 if (PyBytes_Check(name)) {
68 key = PyBytes_AsString(name);
69 #if PY_MAJOR_VERSION >= 3
70 } else if (PyUnicode_Check(name)) {
71 key = PyUnicode_AsUTF8(name);
72 #endif
73 } else {
74 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75 "Expected name to be a str, got",
76 PyObjectToString(name))
77 .c_str());
78 return nullptr;
79 }
80
81 auto* m = RegisteredPyObjectMap();
82 if (m->find(key) != m->end()) {
83 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84 "Value already registered for ", key)
85 .c_str());
86 return nullptr;
87 }
88
89 Py_INCREF(value);
90 m->emplace(key, value);
91
92 Py_RETURN_NONE;
93 }
94
95 namespace {
// Upper bound on the number of types a single CachedTypeCheck will remember.
constexpr int kMaxItemsInCache = 1024;

// Set once IsSequenceHelper has logged its "sets are not sequences" warning,
// so that the warning is emitted only one time.
bool WarnedThatSetIsNotSequence{false};
99
IsString(PyObject * o)100 bool IsString(PyObject* o) {
101 return PyBytes_Check(o) ||
102 #if PY_MAJOR_VERSION < 3
103 PyString_Check(o) ||
104 #endif
105 PyUnicode_Check(o);
106 }
107
108 // Equivalent to Python's 'o.__class__.__name__'
109 // Note that '__class__' attribute is set only in new-style classes.
110 // A lot of tensorflow code uses __class__ without checks, so it seems like
111 // we only support new-style classes.
GetClassName(PyObject * o)112 StringPiece GetClassName(PyObject* o) {
113 // __class__ is equivalent to type() for new style classes.
114 // type() is equivalent to PyObject_Type()
115 // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
116 // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
117 // we don't need here.
118 PyTypeObject* type = o->ob_type;
119
120 // __name__ is the value of `tp_name` after the last '.'
121 // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
122 StringPiece name(type->tp_name);
123 size_t pos = name.rfind('.');
124 if (pos != StringPiece::npos) {
125 name.remove_prefix(pos + 1);
126 }
127 return name;
128 }
129
PyObjectToString(PyObject * o)130 string PyObjectToString(PyObject* o) {
131 if (o == nullptr) {
132 return "<null object>";
133 }
134 PyObject* str = PyObject_Str(o);
135 if (str) {
136 #if PY_MAJOR_VERSION < 3
137 string s(PyString_AS_STRING(str));
138 #else
139 string s(PyUnicode_AsUTF8(str));
140 #endif
141 Py_DECREF(str);
142 return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
143 } else {
144 return "<failed to execute str() on object>";
145 }
146 }
147
148 class CachedTypeCheck {
149 public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)150 explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
151 : ternary_predicate_(std::move(ternary_predicate)) {}
152
~CachedTypeCheck()153 ~CachedTypeCheck() {
154 mutex_lock l(type_to_sequence_map_mu_);
155 for (const auto& pair : type_to_sequence_map_) {
156 Py_DECREF(pair.first);
157 }
158 }
159
160 // Caches successful executions of the one-argument (PyObject*) callable
161 // "ternary_predicate" based on the type of "o". -1 from the callable
162 // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
163 // does not match the predicate, and 1 indicates that it does. Used to avoid
164 // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)165 int CachedLookup(PyObject* o) {
166 // Try not to return to Python - see if the type has already been seen
167 // before.
168
169 auto* type = Py_TYPE(o);
170
171 {
172 tf_shared_lock l(type_to_sequence_map_mu_);
173 auto it = type_to_sequence_map_.find(type);
174 if (it != type_to_sequence_map_.end()) {
175 return it->second;
176 }
177 }
178
179 int check_result = ternary_predicate_(o);
180
181 if (check_result == -1) {
182 return -1; // Type check error, not cached.
183 }
184
185 // NOTE: This is never decref'd as long as the object lives, which is likely
186 // forever, but we don't want the type to get deleted as long as it is in
187 // the map. This should not be too much of a leak, as there should only be a
188 // relatively small number of types in the map, and an even smaller number
189 // that are eligible for decref. As a precaution, we limit the size of the
190 // map to 1024.
191 {
192 mutex_lock l(type_to_sequence_map_mu_);
193 if (type_to_sequence_map_.size() < kMaxItemsInCache) {
194 Py_INCREF(type);
195 auto insert_result = type_to_sequence_map_.insert({type, check_result});
196 if (!insert_result.second) {
197 // The type was added to the cache by a concurrent thread after we
198 // looked it up above.
199 Py_DECREF(type);
200 }
201 }
202 }
203
204 return check_result;
205 }
206
207 private:
208 std::function<int(PyObject*)> ternary_predicate_;
209 mutex type_to_sequence_map_mu_;
210 std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
211 TF_GUARDED_BY(type_to_sequence_map_mu_);
212 };
213
214 // Returns 1 if 'obj' is an instance of 'type_name'
215 // Returns 0 otherwise.
216 // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
IsInstanceOfRegisteredType(PyObject * obj,const char * type_name)217 int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
218 PyObject* type_obj = GetRegisteredPyObject(type_name);
219 if (TF_PREDICT_FALSE(type_obj == nullptr)) {
220 PyErr_SetString(PyExc_RuntimeError,
221 tensorflow::strings::StrCat(
222 type_name,
223 " type has not been set. "
224 "Please register the type with the identifier \"",
225 type_name, "\" using RegisterType.")
226 .c_str());
227 return -1;
228 }
229 return PyObject_IsInstance(obj, type_obj);
230 }
231
232 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
233 // Returns 0 otherwise.
234 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)235 int IsMappingHelper(PyObject* o) {
236 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
237 return IsInstanceOfRegisteredType(to_check, "Mapping");
238 });
239 if (PyDict_Check(o)) return true;
240 return check_cache->CachedLookup(o);
241 }
242
243 // Returns 1 if `o` is considered a mutable mapping for the purposes of
244 // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
IsMutableMappingHelper(PyObject * o)245 int IsMutableMappingHelper(PyObject* o) {
246 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
247 return IsInstanceOfRegisteredType(to_check, "MutableMapping");
248 });
249 if (PyDict_Check(o)) return true;
250 return check_cache->CachedLookup(o);
251 }
252
253 // Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
254 // Returns 0 otherwise.
255 // Returns -1 if an error occurred.
IsMappingViewHelper(PyObject * o)256 int IsMappingViewHelper(PyObject* o) {
257 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
258 return IsInstanceOfRegisteredType(to_check, "MappingView");
259 });
260 return check_cache->CachedLookup(o);
261 }
262
263 // Returns 1 if `o` is considered an object proxy
264 // Returns 0 otherwise.
265 // Returns -1 if an error occurred.
IsObjectProxy(PyObject * o)266 int IsObjectProxy(PyObject* o) {
267 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
268 return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
269 });
270 return check_cache->CachedLookup(o);
271 }
272
273 // Returns 1 if `o` is an instance of attrs-decorated class.
274 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)275 int IsAttrsHelper(PyObject* o) {
276 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
277 Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
278 if (cls) {
279 return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
280 }
281
282 // PyObject_GetAttrString returns null on error
283 PyErr_Clear();
284 return 0;
285 });
286 return check_cache->CachedLookup(o);
287 }
288
289 // Returns 1 if `o` is an object of type IndexedSlices.
290 // Returns 0 otherwise.
291 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)292 int IsIndexedSlicesHelper(PyObject* o) {
293 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
294 return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
295 });
296 return check_cache->CachedLookup(o);
297 }
298
299 // Returns 1 if `o` is a Tensor.
300 // Returns 0 otherwise.
301 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)302 int IsTensorHelper(PyObject* o) {
303 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
304 return IsInstanceOfRegisteredType(to_check, "Tensor");
305 });
306 return check_cache->CachedLookup(o);
307 }
308
309 // Returns 1 if `o` is an EagerTensor.
310 // Returns 0 otherwise.
311 // Returns -1 if an error occurred.
IsEagerTensorHelper(PyObject * o)312 int IsEagerTensorHelper(PyObject* o) {
313 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
314 return IsInstanceOfRegisteredType(to_check, "EagerTensor");
315 });
316 return check_cache->CachedLookup(o);
317 }
318
319 // Returns 1 if `o` is a ResourceVariable.
320 // Returns 0 otherwise.
321 // Returns -1 if an error occurred.
IsResourceVariableHelper(PyObject * o)322 int IsResourceVariableHelper(PyObject* o) {
323 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
324 return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
325 });
326 return check_cache->CachedLookup(o);
327 }
328
329 // Returns 1 if `o` is a ResourceVariable.
330 // Returns 0 otherwise.
331 // Returns -1 if an error occurred.
IsVariableHelper(PyObject * o)332 int IsVariableHelper(PyObject* o) {
333 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
334 return IsInstanceOfRegisteredType(to_check, "Variable");
335 });
336 return check_cache->CachedLookup(o);
337 }
338
339 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
340 // Returns 0 otherwise.
341 // Returns -1 if an error occurred.
IsSequenceHelper(PyObject * o)342 int IsSequenceHelper(PyObject* o) {
343 // We treat dicts and other mappings as special cases of sequences.
344 if (IsMappingHelper(o)) return true;
345 if (IsMappingViewHelper(o)) return true;
346 if (IsAttrsHelper(o)) return true;
347 if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
348 LOG(WARNING) << "Sets are not currently considered sequences, "
349 "but this may change in the future, "
350 "so consider avoiding using them.";
351 WarnedThatSetIsNotSequence = true;
352 }
353 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
354 int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
355
356 // Don't cache a failed is_instance check.
357 if (is_instance == -1) return -1;
358
359 return static_cast<int>(is_instance != 0 && !IsString(to_check));
360 });
361 return check_cache->CachedLookup(o);
362 }
363
364 // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
365 // Returns 0 otherwise.
IsDispatchableHelper(PyObject * o)366 int IsDispatchableHelper(PyObject* o) {
367 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
368 return PyObject_HasAttrString(
369 reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
370 });
371 return check_cache->CachedLookup(o);
372 }
373
374 // ValueIterator interface
375 class ValueIterator {
376 public:
~ValueIterator()377 virtual ~ValueIterator() {}
378 virtual Safe_PyObjectPtr next() = 0;
379
valid() const380 bool valid() const { return is_valid_; }
381
382 protected:
invalidate()383 void invalidate() { is_valid_ = false; }
384
385 private:
386 bool is_valid_ = true;
387 };
388
389 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
390
391 // Iterate through dictionaries in a deterministic order by sorting the
392 // keys. Notice this means that we ignore the original order of
393 // `OrderedDict` instances. This is intentional, to avoid potential
394 // bugs caused by mixing ordered and plain dicts (e.g., flattening
395 // a dict but using a corresponding `OrderedDict` to pack it back).
396 class DictValueIterator : public ValueIterator {
397 public:
DictValueIterator(PyObject * dict)398 explicit DictValueIterator(PyObject* dict)
399 : dict_(dict), keys_(PyDict_Keys(dict)) {
400 if (PyList_Sort(keys_.get()) == -1) {
401 invalidate();
402 } else {
403 iter_.reset(PyObject_GetIter(keys_.get()));
404 }
405 }
406
next()407 Safe_PyObjectPtr next() override {
408 Safe_PyObjectPtr result;
409 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
410 if (key) {
411 // PyDict_GetItem returns a borrowed reference.
412 PyObject* elem = PyDict_GetItem(dict_, key.get());
413 if (elem) {
414 Py_INCREF(elem);
415 result.reset(elem);
416 } else {
417 PyErr_SetString(PyExc_RuntimeError,
418 "Dictionary was modified during iteration over it");
419 }
420 }
421 return result;
422 }
423
424 private:
425 PyObject* dict_;
426 Safe_PyObjectPtr keys_;
427 Safe_PyObjectPtr iter_;
428 };
429
430 // Iterate over mapping objects by sorting the keys first
431 class MappingValueIterator : public ValueIterator {
432 public:
MappingValueIterator(PyObject * mapping)433 explicit MappingValueIterator(PyObject* mapping)
434 : mapping_(mapping), keys_(MappingKeys(mapping)) {
435 if (!keys_ || PyList_Sort(keys_.get()) == -1) {
436 invalidate();
437 } else {
438 iter_.reset(PyObject_GetIter(keys_.get()));
439 }
440 }
441
next()442 Safe_PyObjectPtr next() override {
443 Safe_PyObjectPtr result;
444 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
445 if (key) {
446 // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
447 PyObject* elem = PyObject_GetItem(mapping_, key.get());
448 if (elem) {
449 result.reset(elem);
450 } else {
451 PyErr_SetString(PyExc_RuntimeError,
452 "Mapping was modified during iteration over it");
453 }
454 }
455 return result;
456 }
457
458 private:
459 PyObject* mapping_;
460 Safe_PyObjectPtr keys_;
461 Safe_PyObjectPtr iter_;
462 };
463
464 // Iterate over a sequence, by index.
465 class SequenceValueIterator : public ValueIterator {
466 public:
SequenceValueIterator(PyObject * iterable)467 explicit SequenceValueIterator(PyObject* iterable)
468 : seq_(PySequence_Fast(iterable, "")),
469 size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
470 index_(0) {}
471
next()472 Safe_PyObjectPtr next() override {
473 Safe_PyObjectPtr result;
474 if (index_ < size_) {
475 // PySequence_Fast_GET_ITEM returns a borrowed reference.
476 PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
477 ++index_;
478 if (elem) {
479 Py_INCREF(elem);
480 result.reset(elem);
481 }
482 }
483
484 return result;
485 }
486
487 private:
488 Safe_PyObjectPtr seq_;
489 const Py_ssize_t size_;
490 Py_ssize_t index_;
491 };
492
493 // Iterator that just returns a single python object.
494 class SingleValueIterator : public ValueIterator {
495 public:
SingleValueIterator(PyObject * x)496 explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
497
next()498 Safe_PyObjectPtr next() override { return std::move(x_); }
499
500 private:
501 Safe_PyObjectPtr x_;
502 };
503
504 // Returns nullptr (to raise an exception) when next() is called. Caller
505 // should have already called PyErr_SetString.
506 class ErrorValueIterator : public ValueIterator {
507 public:
ErrorValueIterator()508 ErrorValueIterator() {}
next()509 Safe_PyObjectPtr next() override { return nullptr; }
510 };
511
512 class AttrsValueIterator : public ValueIterator {
513 public:
AttrsValueIterator(PyObject * nested)514 explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
515 Py_INCREF(nested);
516 cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
517 if (cls_) {
518 attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
519 if (attrs_) {
520 iter_.reset(PyObject_GetIter(attrs_.get()));
521 }
522 }
523 if (!iter_ || PyErr_Occurred()) invalidate();
524 }
525
next()526 Safe_PyObjectPtr next() override {
527 Safe_PyObjectPtr result;
528 Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
529 if (item) {
530 Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
531 result.reset(PyObject_GetAttr(nested_.get(), name.get()));
532 }
533
534 return result;
535 }
536
537 private:
538 Safe_PyObjectPtr nested_;
539 Safe_PyObjectPtr cls_;
540 Safe_PyObjectPtr attrs_;
541 Safe_PyObjectPtr iter_;
542 };
543
IsSparseTensorValueType(PyObject * o)544 bool IsSparseTensorValueType(PyObject* o) {
545 PyObject* sparse_tensor_value_type =
546 GetRegisteredPyObject("SparseTensorValue");
547 if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
548 return false;
549 }
550
551 return PyObject_TypeCheck(
552 o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
553 }
554
555 // Returns 1 if `o` is an instance of CompositeTensor.
556 // Returns 0 otherwise.
557 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)558 bool IsCompositeTensorHelper(PyObject* o) {
559 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
560 return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
561 });
562 return check_cache->CachedLookup(o);
563 }
564
565 // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
566 // VariableSpec.
567 // Returns 0 otherwise.
568 // Returns -1 if an error occurred.
IsTypeSpecHelper(PyObject * o)569 bool IsTypeSpecHelper(PyObject* o) {
570 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
571 int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
572 int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
573 IsInstanceOfRegisteredType(to_check, "VariableSpec"));
574 if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
575 return static_cast<int>(is_type_spec && !is_dense_spec);
576 });
577 return check_cache->CachedLookup(o);
578 }
579
580 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
581 // (non-TensorSpec and non-VariableSpec) TypeSpec.
582 // Returns 0 otherwise.
583 // Returns -1 if an error occurred.
IsSequenceOrCompositeHelper(PyObject * o)584 int IsSequenceOrCompositeHelper(PyObject* o) {
585 int is_sequence = IsSequenceHelper(o);
586 int is_composite = IsCompositeTensorHelper(o);
587 int is_type_spec = IsTypeSpecHelper(o);
588 if ((is_sequence == -1) || (is_composite == -1) || (is_type_spec == -1)) {
589 return -1;
590 }
591 return is_sequence || is_composite || is_type_spec;
592 }
593
IsSequenceForDataHelper(PyObject * o)594 int IsSequenceForDataHelper(PyObject* o) {
595 return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
596 !IsSparseTensorValueType(o);
597 }
598
GetValueIterator(PyObject * nested)599 ValueIteratorPtr GetValueIterator(PyObject* nested) {
600 if (PyDict_Check(nested)) {
601 return absl::make_unique<DictValueIterator>(nested);
602 } else if (IsMappingHelper(nested)) {
603 return absl::make_unique<MappingValueIterator>(nested);
604 } else if (IsAttrsHelper(nested)) {
605 return absl::make_unique<AttrsValueIterator>(nested);
606 } else {
607 return absl::make_unique<SequenceValueIterator>(nested);
608 }
609 }
610
611 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)612 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
613 if (PyDict_Check(nested)) {
614 return absl::make_unique<DictValueIterator>(nested);
615 } else if (IsMappingHelper(nested)) {
616 return absl::make_unique<MappingValueIterator>(nested);
617 } else if (IsAttrsHelper(nested)) {
618 return absl::make_unique<AttrsValueIterator>(nested);
619 } else if (IsSparseTensorValueType(nested)) {
620 return absl::make_unique<SingleValueIterator>(nested);
621 } else {
622 return absl::make_unique<SequenceValueIterator>(nested);
623 }
624 }
625
626 // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
GetValueIteratorForComposite(PyObject * nested)627 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
628 if (IsCompositeTensor(nested)) {
629 Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
630 if (PyErr_Occurred() || !spec) {
631 return absl::make_unique<ErrorValueIterator>();
632 }
633
634 static char to_components[] = "_to_components";
635 static char argspec[] = "(O)";
636 Safe_PyObjectPtr components(
637 PyObject_CallMethod(spec.get(), to_components, argspec, nested));
638 if (PyErr_Occurred() || components == nullptr) {
639 return absl::make_unique<ErrorValueIterator>();
640 }
641 return absl::make_unique<SingleValueIterator>(components.get());
642 }
643
644 if (IsTypeSpec(nested)) {
645 Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
646 if (PyErr_Occurred() || specs == nullptr) {
647 return absl::make_unique<ErrorValueIterator>();
648 }
649 return absl::make_unique<SingleValueIterator>(specs.get());
650 }
651
652 return GetValueIterator(nested);
653 }
654
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)655 bool FlattenHelper(
656 PyObject* nested, PyObject* list,
657 const std::function<int(PyObject*)>& is_sequence_helper,
658 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
659 // if nested is not a sequence, append itself and exit
660 int is_seq = is_sequence_helper(nested);
661 if (is_seq == -1) return false;
662 if (!is_seq) {
663 return PyList_Append(list, nested) != -1;
664 }
665
666 ValueIteratorPtr iter = value_iterator_getter(nested);
667 if (!iter->valid()) return false;
668
669 for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
670 if (Py_EnterRecursiveCall(" in flatten")) {
671 return false;
672 }
673 const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
674 value_iterator_getter);
675 Py_LeaveRecursiveCall();
676 if (!success) {
677 return false;
678 }
679 }
680 return true;
681 }
682
683 // Sets error using keys of 'dict1' and 'dict2'.
684 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)685 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
686 bool* is_type_error) {
687 Safe_PyObjectPtr k1(MappingKeys(dict1));
688 if (PyErr_Occurred() || k1.get() == nullptr) {
689 *error_msg =
690 ("The two dictionaries don't have the same set of keys. Failed to "
691 "fetch keys.");
692 return;
693 }
694 Safe_PyObjectPtr k2(MappingKeys(dict2));
695 if (PyErr_Occurred() || k2.get() == nullptr) {
696 *error_msg =
697 ("The two dictionaries don't have the same set of keys. Failed to "
698 "fetch keys.");
699 return;
700 }
701 *is_type_error = false;
702 *error_msg = tensorflow::strings::StrCat(
703 "The two dictionaries don't have the same set of keys. "
704 "First structure has keys ",
705 PyObjectToString(k1.get()), ", while second structure has keys ",
706 PyObjectToString(k2.get()));
707 }
708
709 // Returns true iff there were no "internal" errors. In other words,
710 // errors that has nothing to do with structure checking.
711 // If an "internal" error occurred, the appropriate Python error will be
712 // set and the caller can propage it directly to the user.
713 //
714 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
715 // be empty.
716 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
717 // with appropriate error and sets `is_type_error` to true iff
718 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter,bool check_composite_tensor_type_spec)719 bool AssertSameStructureHelper(
720 PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
721 bool* is_type_error,
722 const std::function<int(PyObject*)>& is_sequence_helper,
723 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
724 bool check_composite_tensor_type_spec) {
725 DCHECK(error_msg);
726 DCHECK(is_type_error);
727 const bool is_seq1 = is_sequence_helper(o1);
728 const bool is_seq2 = is_sequence_helper(o2);
729 if (PyErr_Occurred()) return false;
730 if (is_seq1 != is_seq2) {
731 string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
732 string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
733 *is_type_error = false;
734 *error_msg = tensorflow::strings::StrCat(
735 "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
736 non_seq_str, "\" is not");
737 return true;
738 }
739
740 // Got to objects that are considered non-sequences. Note that in tf.data
741 // use case lists and sparse_tensors are not considered sequences. So finished
742 // checking, structures are the same.
743 if (!is_seq1) return true;
744
745 if (check_types) {
746 // Treat wrapped tuples as tuples.
747 tensorflow::Safe_PyObjectPtr o1_wrapped;
748 if (IsObjectProxy(o1)) {
749 o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
750 o1 = o1_wrapped.get();
751 }
752 tensorflow::Safe_PyObjectPtr o2_wrapped;
753 if (IsObjectProxy(o2)) {
754 o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
755 o2 = o2_wrapped.get();
756 }
757
758 const PyTypeObject* type1 = o1->ob_type;
759 const PyTypeObject* type2 = o2->ob_type;
760
761 // We treat two different namedtuples with identical name and fields
762 // as having the same type.
763 const PyObject* o1_tuple = IsNamedtuple(o1, false);
764 if (o1_tuple == nullptr) return false;
765 const PyObject* o2_tuple = IsNamedtuple(o2, false);
766 if (o2_tuple == nullptr) {
767 Py_DECREF(o1_tuple);
768 return false;
769 }
770 bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
771 Py_DECREF(o1_tuple);
772 Py_DECREF(o2_tuple);
773
774 if (both_tuples) {
775 const PyObject* same_tuples = SameNamedtuples(o1, o2);
776 if (same_tuples == nullptr) return false;
777 bool not_same_tuples = same_tuples != Py_True;
778 Py_DECREF(same_tuples);
779 if (not_same_tuples) {
780 *is_type_error = true;
781 *error_msg = tensorflow::strings::StrCat(
782 "The two namedtuples don't have the same sequence type. "
783 "First structure ",
784 PyObjectToString(o1), " has type ", type1->tp_name,
785 ", while second structure ", PyObjectToString(o2), " has type ",
786 type2->tp_name);
787 return true;
788 }
789 } else if (type1 != type2
790 /* If both sequences are list types, don't complain. This allows
791 one to be a list subclass (e.g. _ListWrapper used for
792 automatic dependency tracking.) */
793 && !(PyList_Check(o1) && PyList_Check(o2))
794 /* Two mapping types will also compare equal, making _DictWrapper
795 and dict compare equal. */
796 && !(IsMappingHelper(o1) && IsMappingHelper(o2))
797 /* For CompositeTensor & TypeSpec, we check below. */
798 && !(check_composite_tensor_type_spec &&
799 (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
800 (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
801 *is_type_error = true;
802 *error_msg = tensorflow::strings::StrCat(
803 "The two namedtuples don't have the same sequence type. "
804 "First structure ",
805 PyObjectToString(o1), " has type ", type1->tp_name,
806 ", while second structure ", PyObjectToString(o2), " has type ",
807 type2->tp_name);
808 return true;
809 }
810
811 if (PyDict_Check(o1) && PyDict_Check(o2)) {
812 if (PyDict_Size(o1) != PyDict_Size(o2)) {
813 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
814 return true;
815 }
816
817 PyObject* key;
818 Py_ssize_t pos = 0;
819 while (PyDict_Next(o1, &pos, &key, nullptr)) {
820 if (PyDict_GetItem(o2, key) == nullptr) {
821 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
822 return true;
823 }
824 }
825 } else if (IsMappingHelper(o1)) {
826 // Fallback for custom mapping types. Instead of using PyDict methods
827 // which stay in C, we call iter(o1).
828 if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
829 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
830 return true;
831 }
832
833 Safe_PyObjectPtr iter(PyObject_GetIter(o1));
834 PyObject* key;
835 while ((key = PyIter_Next(iter.get())) != nullptr) {
836 if (!PyMapping_HasKey(o2, key)) {
837 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
838 Py_DECREF(key);
839 return true;
840 }
841 Py_DECREF(key);
842 }
843 }
844 }
845
846 if (check_composite_tensor_type_spec &&
847 (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
848 Safe_PyObjectPtr owned_type_spec_1;
849 PyObject* type_spec_1 = o1;
850 if (IsCompositeTensor(o1)) {
851 owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
852 type_spec_1 = owned_type_spec_1.get();
853 }
854
855 Safe_PyObjectPtr owned_type_spec_2;
856 PyObject* type_spec_2 = o2;
857 if (IsCompositeTensor(o2)) {
858 owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
859 type_spec_2 = owned_type_spec_2.get();
860 }
861
862 // Two composite tensors are considered to have the same structure if
863 // there is some type spec that is compatible with both of them. Thus,
864 // we use most_specific_compatible_type(), and check if it raises an
865 // exception. We do *not* use is_compatible_with, since that would
866 // prevent us from e.g. using a cond statement where the two sides have
867 // different shapes.
868 static char compatible_type[] = "most_specific_compatible_type";
869 static char argspec[] = "(O)";
870 Safe_PyObjectPtr struct_compatible(PyObject_CallMethod(
871 type_spec_1, compatible_type, argspec, type_spec_2));
872 if (PyErr_Occurred() || struct_compatible == nullptr) {
873 PyErr_Clear();
874 *is_type_error = false;
875 *error_msg = tensorflow::strings::StrCat(
876 "Incompatible CompositeTensor TypeSpecs: ",
877 PyObjectToString(type_spec_1), " vs. ",
878 PyObjectToString(type_spec_2));
879 return true;
880 }
881 }
882
883 ValueIteratorPtr iter1 = value_iterator_getter(o1);
884 ValueIteratorPtr iter2 = value_iterator_getter(o2);
885
886 if (!iter1->valid() || !iter2->valid()) return false;
887
888 while (true) {
889 Safe_PyObjectPtr v1 = iter1->next();
890 Safe_PyObjectPtr v2 = iter2->next();
891 if (v1 && v2) {
892 if (Py_EnterRecursiveCall(" in assert_same_structure")) {
893 return false;
894 }
895 bool no_internal_errors = AssertSameStructureHelper(
896 v1.get(), v2.get(), check_types, error_msg, is_type_error,
897 is_sequence_helper, value_iterator_getter,
898 check_composite_tensor_type_spec);
899 Py_LeaveRecursiveCall();
900 if (!no_internal_errors) return false;
901 if (!error_msg->empty()) return true;
902 } else if (!v1 && !v2) {
903 // Done with all recursive calls. Structure matched.
904 return true;
905 } else {
906 *is_type_error = false;
907 *error_msg = tensorflow::strings::StrCat(
908 "The two structures don't have the same number of elements. ",
909 "First structure: ", PyObjectToString(o1),
910 ". Second structure: ", PyObjectToString(o2));
911 return true;
912 }
913 }
914 }
915
916 } // namespace
917
IsSequence(PyObject * o)918 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
IsMapping(PyObject * o)919 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsMutableMapping(PyObject * o)920 bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
IsMappingView(PyObject * o)921 bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
IsAttrs(PyObject * o)922 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)923 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsEagerTensorSlow(PyObject * o)924 bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
IsResourceVariable(PyObject * o)925 bool IsResourceVariable(PyObject* o) {
926 return IsResourceVariableHelper(o) == 1;
927 }
IsVariable(PyObject * o)928 bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
IsIndexedSlices(PyObject * o)929 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
IsDispatchable(PyObject * o)930 bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
931
IsTuple(PyObject * o)932 bool IsTuple(PyObject* o) {
933 tensorflow::Safe_PyObjectPtr wrapped;
934 if (IsObjectProxy(o)) {
935 wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
936 o = wrapped.get();
937 }
938 return PyTuple_Check(o);
939 }
940
941 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
942 // and while we're at it give them consistent behavior by making sure the
943 // returned value is a list.
944 //
945 // As with PyMapping_Keys, returns a new reference.
946 //
947 // On failure, returns nullptr.
MappingKeys(PyObject * o)948 PyObject* MappingKeys(PyObject* o) {
949 #if PY_MAJOR_VERSION >= 3
950 return PyMapping_Keys(o);
951 #else
952 static char key_method_name[] = "keys";
953 Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
954 if (PyErr_Occurred() || raw_result.get() == nullptr) {
955 return nullptr;
956 }
957 return PySequence_Fast(
958 raw_result.get(),
959 "The '.keys()' method of a custom mapping returned a non-sequence.");
960 #endif
961 }
962
Flatten(PyObject * nested,bool expand_composites)963 PyObject* Flatten(PyObject* nested, bool expand_composites) {
964 PyObject* list = PyList_New(0);
965 const std::function<int(PyObject*)>& is_sequence_helper =
966 expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
967 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
968 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
969 if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
970 return list;
971 } else {
972 Py_DECREF(list);
973 return nullptr;
974 }
975 }
976
IsSequenceOrComposite(PyObject * o)977 bool IsSequenceOrComposite(PyObject* o) {
978 return IsSequenceOrCompositeHelper(o) == 1;
979 }
980
IsCompositeTensor(PyObject * o)981 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
982
IsTypeSpec(PyObject * o)983 bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
984
IsSequenceForData(PyObject * o)985 bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
986
FlattenForData(PyObject * nested)987 PyObject* FlattenForData(PyObject* nested) {
988 PyObject* list = PyList_New(0);
989 if (FlattenHelper(nested, list, IsSequenceForDataHelper,
990 GetValueIteratorForData)) {
991 return list;
992 } else {
993 Py_DECREF(list);
994 return nullptr;
995 }
996 }
997
IsNamedtuple(PyObject * o,bool strict)998 PyObject* IsNamedtuple(PyObject* o, bool strict) {
999 // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1000 // require some unwrapping if we want to treat them like the objects they're
1001 // wrapping.
1002 tensorflow::Safe_PyObjectPtr o_wrapped;
1003 if (IsObjectProxy(o)) {
1004 o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1005 o = o_wrapped.get();
1006 }
1007
1008 // Must be subclass of tuple
1009 if (!PyTuple_Check(o)) {
1010 Py_RETURN_FALSE;
1011 }
1012
1013 // If strict, o.__class__.__base__ must be tuple
1014 if (strict) {
1015 PyObject* klass = PyObject_GetAttrString(o, "__class__");
1016 if (klass == nullptr) return nullptr;
1017 PyObject* base = PyObject_GetAttrString(klass, "__base__");
1018 Py_DECREF(klass);
1019 if (base == nullptr) return nullptr;
1020
1021 const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1022 // built-in object types are singletons
1023 bool tuple_base = base_type == &PyTuple_Type;
1024 Py_DECREF(base);
1025 if (!tuple_base) {
1026 Py_RETURN_FALSE;
1027 }
1028 }
1029
1030 // o must have attribute '_fields' and every element in
1031 // '_fields' must be a string.
1032 int has_fields = PyObject_HasAttrString(o, "_fields");
1033 if (!has_fields) {
1034 Py_RETURN_FALSE;
1035 }
1036
1037 Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1038 int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1039 if (is_instance == 0) {
1040 Py_RETURN_FALSE;
1041 } else if (is_instance == -1) {
1042 return nullptr;
1043 }
1044
1045 Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1046 const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1047 for (Py_ssize_t i = 0; i < s; ++i) {
1048 // PySequence_Fast_GET_ITEM returns borrowed ref
1049 PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1050 if (!IsString(elem)) {
1051 Py_RETURN_FALSE;
1052 }
1053 }
1054
1055 Py_RETURN_TRUE;
1056 }
1057
SameNamedtuples(PyObject * o1,PyObject * o2)1058 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1059 Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1060 Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1061 if (f1 == nullptr || f2 == nullptr) {
1062 PyErr_SetString(
1063 PyExc_RuntimeError,
1064 "Expected namedtuple-like objects (that have _fields attr)");
1065 return nullptr;
1066 }
1067
1068 if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1069 Py_RETURN_FALSE;
1070 }
1071
1072 if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1073 Py_RETURN_TRUE;
1074 } else {
1075 Py_RETURN_FALSE;
1076 }
1077 }
1078
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)1079 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1080 bool expand_composites) {
1081 const std::function<int(PyObject*)>& is_sequence_helper =
1082 expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
1083 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1084 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1085 const bool check_composite_tensor_type_spec = expand_composites;
1086 string error_msg;
1087 bool is_type_error = false;
1088 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1089 is_sequence_helper, get_value_iterator,
1090 check_composite_tensor_type_spec);
1091 if (PyErr_Occurred()) {
1092 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1093 // from custom mappings).
1094 return nullptr;
1095 }
1096 if (!error_msg.empty()) {
1097 PyErr_SetString(
1098 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1099 tensorflow::strings::StrCat(
1100 "The two structures don't have the same nested structure.\n\n",
1101 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1102 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1103 .c_str());
1104 return nullptr;
1105 }
1106 Py_RETURN_NONE;
1107 }
1108
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)1109 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1110 bool check_types) {
1111 string error_msg;
1112 bool is_type_error = false;
1113 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1114 IsSequenceForDataHelper, GetValueIterator, false);
1115 if (PyErr_Occurred()) {
1116 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1117 // from custom mappings).
1118 return nullptr;
1119 }
1120 if (!error_msg.empty()) {
1121 PyErr_SetString(
1122 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1123 tensorflow::strings::StrCat(
1124 "The two structures don't have the same nested structure.\n\n",
1125 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1126 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1127 .c_str());
1128 return nullptr;
1129 }
1130 Py_RETURN_NONE;
1131 }
1132
1133 } // namespace swig
1134 } // namespace tensorflow
1135