• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28 
29 namespace tensorflow {
30 namespace swig {
31 
32 namespace {
33 string PyObjectToString(PyObject* o);
34 }  // namespace
35 
RegisteredPyObjectMap()36 std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37   static auto* m = new std::unordered_map<string, PyObject*>();
38   return m;
39 }
40 
GetRegisteredPyObject(const string & name)41 PyObject* GetRegisteredPyObject(const string& name) {
42   const auto* m = RegisteredPyObjectMap();
43   auto it = m->find(name);
44   if (it == m->end()) {
45     PyErr_SetString(PyExc_TypeError,
46                     tensorflow::strings::StrCat("No object with name ", name,
47                                                 " has been registered.")
48                         .c_str());
49     return nullptr;
50   }
51   return it->second;
52 }
53 
RegisterType(PyObject * type_name,PyObject * type)54 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55   if (!PyType_Check(type)) {
56     PyErr_SetString(PyExc_TypeError,
57                     tensorflow::strings::StrCat("Expecting a type, got ",
58                                                 Py_TYPE(type)->tp_name)
59                         .c_str());
60     return nullptr;
61   }
62   return RegisterPyObject(type_name, type);
63 }
64 
RegisterPyObject(PyObject * name,PyObject * value)65 PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66   string key;
67   if (PyBytes_Check(name)) {
68     key = PyBytes_AsString(name);
69 #if PY_MAJOR_VERSION >= 3
70   } else if (PyUnicode_Check(name)) {
71     key = PyUnicode_AsUTF8(name);
72 #endif
73   } else {
74     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75                                          "Expected name to be a str, got",
76                                          PyObjectToString(name))
77                                          .c_str());
78     return nullptr;
79   }
80 
81   auto* m = RegisteredPyObjectMap();
82   if (m->find(key) != m->end()) {
83     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84                                          "Value already registered for ", key)
85                                          .c_str());
86     return nullptr;
87   }
88 
89   Py_INCREF(value);
90   m->emplace(key, value);
91 
92   Py_RETURN_NONE;
93 }
94 
95 namespace {
// Maximum number of distinct Python types a single CachedTypeCheck will store.
constexpr int kMaxItemsInCache = 1024;
97 
IsString(PyObject * o)98 bool IsString(PyObject* o) {
99   return PyBytes_Check(o) ||
100 #if PY_MAJOR_VERSION < 3
101          PyString_Check(o) ||
102 #endif
103          PyUnicode_Check(o);
104 }
105 
106 // Equivalent to Python's 'o.__class__.__name__'
107 // Note that '__class__' attribute is set only in new-style classes.
108 // A lot of tensorflow code uses __class__ without checks, so it seems like
109 // we only support new-style classes.
GetClassName(PyObject * o)110 StringPiece GetClassName(PyObject* o) {
111   // __class__ is equivalent to type() for new style classes.
112   // type() is equivalent to PyObject_Type()
113   // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
114   // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
115   // we don't need here.
116   PyTypeObject* type = o->ob_type;
117 
118   // __name__ is the value of `tp_name` after the last '.'
119   // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
120   StringPiece name(type->tp_name);
121   size_t pos = name.rfind('.');
122   if (pos != StringPiece::npos) {
123     name.remove_prefix(pos + 1);
124   }
125   return name;
126 }
127 
PyObjectToString(PyObject * o)128 string PyObjectToString(PyObject* o) {
129   if (o == nullptr) {
130     return "<null object>";
131   }
132   PyObject* str = PyObject_Str(o);
133   if (str) {
134 #if PY_MAJOR_VERSION < 3
135     string s(PyString_AS_STRING(str));
136 #else
137     string s(PyUnicode_AsUTF8(str));
138 #endif
139     Py_DECREF(str);
140     return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
141   } else {
142     return "<failed to execute str() on object>";
143   }
144 }
145 
146 class CachedTypeCheck {
147  public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)148   explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
149       : ternary_predicate_(std::move(ternary_predicate)) {}
150 
~CachedTypeCheck()151   ~CachedTypeCheck() {
152     mutex_lock l(type_to_sequence_map_mu_);
153     for (const auto& pair : type_to_sequence_map_) {
154       Py_DECREF(pair.first);
155     }
156   }
157 
158   // Caches successful executions of the one-argument (PyObject*) callable
159   // "ternary_predicate" based on the type of "o". -1 from the callable
160   // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
161   // does not match the predicate, and 1 indicates that it does. Used to avoid
162   // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)163   int CachedLookup(PyObject* o) {
164     // Try not to return to Python - see if the type has already been seen
165     // before.
166 
167     auto* type = Py_TYPE(o);
168 
169     {
170       tf_shared_lock l(type_to_sequence_map_mu_);
171       auto it = type_to_sequence_map_.find(type);
172       if (it != type_to_sequence_map_.end()) {
173         return it->second;
174       }
175     }
176 
177     int check_result = ternary_predicate_(o);
178 
179     if (check_result == -1) {
180       return -1;  // Type check error, not cached.
181     }
182 
183     // NOTE: This is never decref'd as long as the object lives, which is likely
184     // forever, but we don't want the type to get deleted as long as it is in
185     // the map. This should not be too much of a leak, as there should only be a
186     // relatively small number of types in the map, and an even smaller number
187     // that are eligible for decref. As a precaution, we limit the size of the
188     // map to 1024.
189     {
190       mutex_lock l(type_to_sequence_map_mu_);
191       if (type_to_sequence_map_.size() < kMaxItemsInCache) {
192         Py_INCREF(type);
193         auto insert_result = type_to_sequence_map_.insert({type, check_result});
194         if (!insert_result.second) {
195           // The type was added to the cache by a concurrent thread after we
196           // looked it up above.
197           Py_DECREF(type);
198         }
199       }
200     }
201 
202     return check_result;
203   }
204 
205  private:
206   std::function<int(PyObject*)> ternary_predicate_;
207   mutex type_to_sequence_map_mu_;
208   std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
209       TF_GUARDED_BY(type_to_sequence_map_mu_);
210 };
211 
212 // Returns 1 if 'obj' is an instance of 'type_name'
213 // Returns 0 otherwise.
214 // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
IsInstanceOfRegisteredType(PyObject * obj,const char * type_name)215 int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
216   PyObject* type_obj = GetRegisteredPyObject(type_name);
217   if (TF_PREDICT_FALSE(type_obj == nullptr)) {
218     PyErr_SetString(PyExc_RuntimeError,
219                     tensorflow::strings::StrCat(
220                         type_name,
221                         " type has not been set. "
222                         "Please register the type with the identifier \"",
223                         type_name, "\" using RegisterType.")
224                         .c_str());
225     return -1;
226   }
227   return PyObject_IsInstance(obj, type_obj);
228 }
229 
230 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
231 // Returns 0 otherwise.
232 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)233 int IsMappingHelper(PyObject* o) {
234   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
235     return IsInstanceOfRegisteredType(to_check, "Mapping");
236   });
237   if (PyDict_Check(o)) return true;
238   return check_cache->CachedLookup(o);
239 }
240 
241 // Returns 1 if `o` is considered a mutable mapping for the purposes of
242 // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
IsMutableMappingHelper(PyObject * o)243 int IsMutableMappingHelper(PyObject* o) {
244   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
245     return IsInstanceOfRegisteredType(to_check, "MutableMapping");
246   });
247   if (PyDict_Check(o)) return true;
248   return check_cache->CachedLookup(o);
249 }
250 
251 // Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
252 // Returns 0 otherwise.
253 // Returns -1 if an error occurred.
IsMappingViewHelper(PyObject * o)254 int IsMappingViewHelper(PyObject* o) {
255   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
256     return IsInstanceOfRegisteredType(to_check, "MappingView");
257   });
258   return check_cache->CachedLookup(o);
259 }
260 
261 // Returns 1 if `o` is considered an object proxy
262 // Returns 0 otherwise.
263 // Returns -1 if an error occurred.
IsObjectProxy(PyObject * o)264 int IsObjectProxy(PyObject* o) {
265   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
266     return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
267   });
268   return check_cache->CachedLookup(o);
269 }
270 
271 // Returns 1 if `o` is an instance of attrs-decorated class.
272 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)273 int IsAttrsHelper(PyObject* o) {
274   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
275     Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
276     if (cls) {
277       return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
278     }
279 
280     // PyObject_GetAttrString returns null on error
281     PyErr_Clear();
282     return 0;
283   });
284   return check_cache->CachedLookup(o);
285 }
286 
287 // Returns 1 if `o` is an object of type IndexedSlices.
288 // Returns 0 otherwise.
289 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)290 int IsIndexedSlicesHelper(PyObject* o) {
291   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
292     return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
293   });
294   return check_cache->CachedLookup(o);
295 }
296 
297 // Returns 1 if `o` is a Tensor.
298 // Returns 0 otherwise.
299 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)300 int IsTensorHelper(PyObject* o) {
301   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
302     return IsInstanceOfRegisteredType(to_check, "Tensor");
303   });
304   return check_cache->CachedLookup(o);
305 }
306 
307 // Returns 1 if `o` is a TensorSpec.
308 // Returns 0 otherwise.
309 // Returns -1 if an error occurred.
IsTensorSpecHelper(PyObject * o)310 int IsTensorSpecHelper(PyObject* o) {
311   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
312     return IsInstanceOfRegisteredType(to_check, "TensorSpec");
313   });
314   return check_cache->CachedLookup(o);
315 }
316 
317 // Returns 1 if `o` is an EagerTensor.
318 // Returns 0 otherwise.
319 // Returns -1 if an error occurred.
IsEagerTensorHelper(PyObject * o)320 int IsEagerTensorHelper(PyObject* o) {
321   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
322     return IsInstanceOfRegisteredType(to_check, "EagerTensor");
323   });
324   return check_cache->CachedLookup(o);
325 }
326 
327 // Returns 1 if `o` is a ResourceVariable.
328 // Returns 0 otherwise.
329 // Returns -1 if an error occurred.
IsResourceVariableHelper(PyObject * o)330 int IsResourceVariableHelper(PyObject* o) {
331   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
332     return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
333   });
334   return check_cache->CachedLookup(o);
335 }
336 
337 // Returns 1 if `o` is a OwnedIterator.
338 // Returns 0 otherwise.
339 // Returns -1 if an error occurred.
IsOwnedIteratorHelper(PyObject * o)340 int IsOwnedIteratorHelper(PyObject* o) {
341   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
342     return IsInstanceOfRegisteredType(to_check, "OwnedIterator");
343   });
344   return check_cache->CachedLookup(o);
345 }
346 
347 // Returns 1 if `o` is a ResourceVariable.
348 // Returns 0 otherwise.
349 // Returns -1 if an error occurred.
IsVariableHelper(PyObject * o)350 int IsVariableHelper(PyObject* o) {
351   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
352     return IsInstanceOfRegisteredType(to_check, "Variable");
353   });
354   return check_cache->CachedLookup(o);
355 }
356 
357 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
358 // Returns 0 otherwise.
359 // Returns -1 if an error occurred.
IsNestedHelper(PyObject * o)360 int IsNestedHelper(PyObject* o) {
361   // We treat dicts and other mappings as special cases of sequences.
362   if (IsMappingHelper(o)) return true;
363   if (IsMappingViewHelper(o)) return true;
364   if (IsAttrsHelper(o)) return true;
365 
366   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
367     int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
368 
369     // Don't cache a failed is_instance check.
370     if (is_instance == -1) return -1;
371 
372     return static_cast<int>(is_instance != 0 && !IsString(to_check));
373   });
374   return check_cache->CachedLookup(o);
375 }
376 
377 // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
378 // Returns 0 otherwise.
IsDispatchableHelper(PyObject * o)379 int IsDispatchableHelper(PyObject* o) {
380   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
381     return PyObject_HasAttrString(
382         reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
383   });
384   return check_cache->CachedLookup(o);
385 }
386 
387 // ValueIterator interface
388 class ValueIterator {
389  public:
~ValueIterator()390   virtual ~ValueIterator() {}
391   virtual Safe_PyObjectPtr next() = 0;
392 
valid() const393   bool valid() const { return is_valid_; }
394 
395  protected:
invalidate()396   void invalidate() { is_valid_ = false; }
397 
398  private:
399   bool is_valid_ = true;
400 };
401 
402 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
403 
404 // Iterate through dictionaries in a deterministic order by sorting the
405 // keys. Notice this means that we ignore the original order of
406 // `OrderedDict` instances. This is intentional, to avoid potential
407 // bugs caused by mixing ordered and plain dicts (e.g., flattening
408 // a dict but using a corresponding `OrderedDict` to pack it back).
409 class DictValueIterator : public ValueIterator {
410  public:
DictValueIterator(PyObject * dict)411   explicit DictValueIterator(PyObject* dict)
412       : dict_(dict), keys_(PyDict_Keys(dict)) {
413     if (PyList_Sort(keys_.get()) == -1) {
414       invalidate();
415     } else {
416       iter_.reset(PyObject_GetIter(keys_.get()));
417     }
418   }
419 
next()420   Safe_PyObjectPtr next() override {
421     Safe_PyObjectPtr result;
422     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
423     if (key) {
424       // PyDict_GetItem returns a borrowed reference.
425       PyObject* elem = PyDict_GetItem(dict_, key.get());
426       if (elem) {
427         Py_INCREF(elem);
428         result.reset(elem);
429       } else {
430         PyErr_SetString(PyExc_RuntimeError,
431                         "Dictionary was modified during iteration over it");
432       }
433     }
434     return result;
435   }
436 
437  private:
438   PyObject* dict_;
439   Safe_PyObjectPtr keys_;
440   Safe_PyObjectPtr iter_;
441 };
442 
443 // Iterate over mapping objects by sorting the keys first
444 class MappingValueIterator : public ValueIterator {
445  public:
MappingValueIterator(PyObject * mapping)446   explicit MappingValueIterator(PyObject* mapping)
447       : mapping_(mapping), keys_(MappingKeys(mapping)) {
448     if (!keys_ || PyList_Sort(keys_.get()) == -1) {
449       invalidate();
450     } else {
451       iter_.reset(PyObject_GetIter(keys_.get()));
452     }
453   }
454 
next()455   Safe_PyObjectPtr next() override {
456     Safe_PyObjectPtr result;
457     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
458     if (key) {
459       // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
460       PyObject* elem = PyObject_GetItem(mapping_, key.get());
461       if (elem) {
462         result.reset(elem);
463       } else {
464         PyErr_SetString(PyExc_RuntimeError,
465                         "Mapping was modified during iteration over it");
466       }
467     }
468     return result;
469   }
470 
471  private:
472   PyObject* mapping_;
473   Safe_PyObjectPtr keys_;
474   Safe_PyObjectPtr iter_;
475 };
476 
477 // Iterate over a sequence, by index.
478 class SequenceValueIterator : public ValueIterator {
479  public:
SequenceValueIterator(PyObject * iterable)480   explicit SequenceValueIterator(PyObject* iterable)
481       : seq_(PySequence_Fast(iterable, "")),
482         size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
483         index_(0) {}
484 
next()485   Safe_PyObjectPtr next() override {
486     Safe_PyObjectPtr result;
487     if (index_ < size_) {
488       // PySequence_Fast_GET_ITEM returns a borrowed reference.
489       PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
490       ++index_;
491       if (elem) {
492         Py_INCREF(elem);
493         result.reset(elem);
494       }
495     }
496 
497     return result;
498   }
499 
500  private:
501   Safe_PyObjectPtr seq_;
502   const Py_ssize_t size_;
503   Py_ssize_t index_;
504 };
505 
506 // Iterator that just returns a single python object.
507 class SingleValueIterator : public ValueIterator {
508  public:
SingleValueIterator(PyObject * x)509   explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
510 
next()511   Safe_PyObjectPtr next() override { return std::move(x_); }
512 
513  private:
514   Safe_PyObjectPtr x_;
515 };
516 
517 // Returns nullptr (to raise an exception) when next() is called.  Caller
518 // should have already called PyErr_SetString.
519 class ErrorValueIterator : public ValueIterator {
520  public:
ErrorValueIterator()521   ErrorValueIterator() {}
next()522   Safe_PyObjectPtr next() override { return nullptr; }
523 };
524 
525 class AttrsValueIterator : public ValueIterator {
526  public:
AttrsValueIterator(PyObject * nested)527   explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
528     Py_INCREF(nested);
529     cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
530     if (cls_) {
531       attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
532       if (attrs_) {
533         iter_.reset(PyObject_GetIter(attrs_.get()));
534       }
535     }
536     if (!iter_ || PyErr_Occurred()) invalidate();
537   }
538 
next()539   Safe_PyObjectPtr next() override {
540     Safe_PyObjectPtr result;
541     Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
542     if (item) {
543       Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
544       result.reset(PyObject_GetAttr(nested_.get(), name.get()));
545     }
546 
547     return result;
548   }
549 
550  private:
551   Safe_PyObjectPtr nested_;
552   Safe_PyObjectPtr cls_;
553   Safe_PyObjectPtr attrs_;
554   Safe_PyObjectPtr iter_;
555 };
556 
IsSparseTensorValueType(PyObject * o)557 bool IsSparseTensorValueType(PyObject* o) {
558   PyObject* sparse_tensor_value_type =
559       GetRegisteredPyObject("SparseTensorValue");
560   if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
561     return false;
562   }
563 
564   return PyObject_TypeCheck(
565              o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
566 }
567 
568 // Returns 1 if `o` is an instance of CompositeTensor.
569 // Returns 0 otherwise.
570 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)571 bool IsCompositeTensorHelper(PyObject* o) {
572   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
573     return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
574   });
575   return check_cache->CachedLookup(o);
576 }
577 
578 // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
579 // VariableSpec.
580 // Returns 0 otherwise.
581 // Returns -1 if an error occurred.
IsTypeSpecHelper(PyObject * o)582 bool IsTypeSpecHelper(PyObject* o) {
583   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
584     int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
585     int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
586                          IsInstanceOfRegisteredType(to_check, "VariableSpec"));
587     if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
588     return static_cast<int>(is_type_spec && !is_dense_spec);
589   });
590   return check_cache->CachedLookup(o);
591 }
592 
593 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
594 // (non-TensorSpec and non-VariableSpec) TypeSpec.
595 // Returns 0 otherwise.
596 // Returns -1 if an error occurred.
IsNestedOrCompositeHelper(PyObject * o)597 int IsNestedOrCompositeHelper(PyObject* o) {
598   int is_nested = IsNestedHelper(o);
599   int is_composite = IsCompositeTensorHelper(o);
600   int is_type_spec = IsTypeSpecHelper(o);
601   if ((is_nested == -1) || (is_composite == -1) || (is_type_spec == -1)) {
602     return -1;
603   }
604   return is_nested || is_composite || is_type_spec;
605 }
606 
IsNestedForDataHelper(PyObject * o)607 int IsNestedForDataHelper(PyObject* o) {
608   return IsNestedHelper(o) == 1 && !PyList_Check(o) &&
609          !IsSparseTensorValueType(o);
610 }
611 
GetValueIterator(PyObject * nested)612 ValueIteratorPtr GetValueIterator(PyObject* nested) {
613   if (PyDict_Check(nested)) {
614     return absl::make_unique<DictValueIterator>(nested);
615   } else if (IsMappingHelper(nested)) {
616     return absl::make_unique<MappingValueIterator>(nested);
617   } else if (IsAttrsHelper(nested)) {
618     return absl::make_unique<AttrsValueIterator>(nested);
619   } else {
620     return absl::make_unique<SequenceValueIterator>(nested);
621   }
622 }
623 
624 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)625 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
626   if (PyDict_Check(nested)) {
627     return absl::make_unique<DictValueIterator>(nested);
628   } else if (IsMappingHelper(nested)) {
629     return absl::make_unique<MappingValueIterator>(nested);
630   } else if (IsAttrsHelper(nested)) {
631     return absl::make_unique<AttrsValueIterator>(nested);
632   } else if (IsSparseTensorValueType(nested)) {
633     return absl::make_unique<SingleValueIterator>(nested);
634   } else {
635     return absl::make_unique<SequenceValueIterator>(nested);
636   }
637 }
638 
639 // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
GetValueIteratorForComposite(PyObject * nested)640 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
641   if (IsCompositeTensor(nested)) {
642     Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
643     if (PyErr_Occurred() || !spec) {
644       return absl::make_unique<ErrorValueIterator>();
645     }
646 
647     static char to_components[] = "_to_components";
648     static char argspec[] = "(O)";
649     Safe_PyObjectPtr components(
650         PyObject_CallMethod(spec.get(), to_components, argspec, nested));
651     if (PyErr_Occurred() || components == nullptr) {
652       return absl::make_unique<ErrorValueIterator>();
653     }
654     return absl::make_unique<SingleValueIterator>(components.get());
655   }
656 
657   if (IsTypeSpec(nested)) {
658     Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
659     if (PyErr_Occurred() || specs == nullptr) {
660       return absl::make_unique<ErrorValueIterator>();
661     }
662     return absl::make_unique<SingleValueIterator>(specs.get());
663   }
664 
665   return GetValueIterator(nested);
666 }
667 
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_nested_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)668 bool FlattenHelper(
669     PyObject* nested, PyObject* list,
670     const std::function<int(PyObject*)>& is_nested_helper,
671     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
672   // if nested is not a sequence, append itself and exit
673   int is_nested = is_nested_helper(nested);
674   if (is_nested == -1) return false;
675   if (!is_nested) {
676     return PyList_Append(list, nested) != -1;
677   }
678 
679   ValueIteratorPtr iter = value_iterator_getter(nested);
680   if (!iter->valid()) return false;
681 
682   for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
683     if (Py_EnterRecursiveCall(" in flatten")) {
684       return false;
685     }
686     const bool success = FlattenHelper(item.get(), list, is_nested_helper,
687                                        value_iterator_getter);
688     Py_LeaveRecursiveCall();
689     if (!success) {
690       return false;
691     }
692   }
693   return true;
694 }
695 
696 // Sets error using keys of 'dict1' and 'dict2'.
697 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)698 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
699                            bool* is_type_error) {
700   Safe_PyObjectPtr k1(MappingKeys(dict1));
701   if (PyErr_Occurred() || k1.get() == nullptr) {
702     *error_msg =
703         ("The two dictionaries don't have the same set of keys. Failed to "
704          "fetch keys.");
705     return;
706   }
707   Safe_PyObjectPtr k2(MappingKeys(dict2));
708   if (PyErr_Occurred() || k2.get() == nullptr) {
709     *error_msg =
710         ("The two dictionaries don't have the same set of keys. Failed to "
711          "fetch keys.");
712     return;
713   }
714   *is_type_error = false;
715   *error_msg = tensorflow::strings::StrCat(
716       "The two dictionaries don't have the same set of keys. "
717       "First structure has keys ",
718       PyObjectToString(k1.get()), ", while second structure has keys ",
719       PyObjectToString(k2.get()));
720 }
721 
722 // Returns true iff there were no "internal" errors. In other words,
723 // errors that has nothing to do with structure checking.
724 // If an "internal" error occurred, the appropriate Python error will be
725 // set and the caller can propage it directly to the user.
726 //
727 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
728 // be empty.
729 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
730 // with appropriate error and sets `is_type_error` to true iff
731 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_nested_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter,bool check_composite_tensor_type_spec)732 bool AssertSameStructureHelper(
733     PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
734     bool* is_type_error, const std::function<int(PyObject*)>& is_nested_helper,
735     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
736     bool check_composite_tensor_type_spec) {
737   DCHECK(error_msg);
738   DCHECK(is_type_error);
739   const bool is_nested1 = is_nested_helper(o1);
740   const bool is_nested2 = is_nested_helper(o2);
741   if (PyErr_Occurred()) return false;
742   if (is_nested1 != is_nested2) {
743     string seq_str = is_nested1 ? PyObjectToString(o1) : PyObjectToString(o2);
744     string non_seq_str =
745         is_nested1 ? PyObjectToString(o2) : PyObjectToString(o1);
746     *is_type_error = false;
747     *error_msg = tensorflow::strings::StrCat(
748         "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
749         non_seq_str, "\" is not");
750     return true;
751   }
752 
753   // Got to objects that are considered non-sequences. Note that in tf.data
754   // use case lists and sparse_tensors are not considered sequences. So finished
755   // checking, structures are the same.
756   if (!is_nested1) return true;
757 
758   if (check_types) {
759     // Treat wrapped tuples as tuples.
760     tensorflow::Safe_PyObjectPtr o1_wrapped;
761     if (IsObjectProxy(o1)) {
762       o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
763       o1 = o1_wrapped.get();
764     }
765     tensorflow::Safe_PyObjectPtr o2_wrapped;
766     if (IsObjectProxy(o2)) {
767       o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
768       o2 = o2_wrapped.get();
769     }
770 
771     const PyTypeObject* type1 = o1->ob_type;
772     const PyTypeObject* type2 = o2->ob_type;
773 
774     // We treat two different namedtuples with identical name and fields
775     // as having the same type.
776     const PyObject* o1_tuple = IsNamedtuple(o1, false);
777     if (o1_tuple == nullptr) return false;
778     const PyObject* o2_tuple = IsNamedtuple(o2, false);
779     if (o2_tuple == nullptr) {
780       Py_DECREF(o1_tuple);
781       return false;
782     }
783     bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
784     Py_DECREF(o1_tuple);
785     Py_DECREF(o2_tuple);
786 
787     if (both_tuples) {
788       const PyObject* same_tuples = SameNamedtuples(o1, o2);
789       if (same_tuples == nullptr) return false;
790       bool not_same_tuples = same_tuples != Py_True;
791       Py_DECREF(same_tuples);
792       if (not_same_tuples) {
793         *is_type_error = true;
794         *error_msg = tensorflow::strings::StrCat(
795             "The two namedtuples don't have the same sequence type. "
796             "First structure ",
797             PyObjectToString(o1), " has type ", type1->tp_name,
798             ", while second structure ", PyObjectToString(o2), " has type ",
799             type2->tp_name);
800         return true;
801       }
802     } else if (type1 != type2
803                /* If both sequences are list types, don't complain. This allows
804                   one to be a list subclass (e.g. _ListWrapper used for
805                   automatic dependency tracking.) */
806                && !(PyList_Check(o1) && PyList_Check(o2))
807                /* Two mapping types will also compare equal, making _DictWrapper
808                   and dict compare equal. */
809                && !(IsMappingHelper(o1) && IsMappingHelper(o2))
810                /* For CompositeTensor & TypeSpec, we check below. */
811                && !(check_composite_tensor_type_spec &&
812                     (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
813                     (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
814       *is_type_error = true;
815       *error_msg = tensorflow::strings::StrCat(
816           "The two namedtuples don't have the same sequence type. "
817           "First structure ",
818           PyObjectToString(o1), " has type ", type1->tp_name,
819           ", while second structure ", PyObjectToString(o2), " has type ",
820           type2->tp_name);
821       return true;
822     }
823 
824     if (PyDict_Check(o1) && PyDict_Check(o2)) {
825       if (PyDict_Size(o1) != PyDict_Size(o2)) {
826         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
827         return true;
828       }
829 
830       PyObject* key;
831       Py_ssize_t pos = 0;
832       while (PyDict_Next(o1, &pos, &key, nullptr)) {
833         if (PyDict_GetItem(o2, key) == nullptr) {
834           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
835           return true;
836         }
837       }
838     } else if (IsMappingHelper(o1)) {
839       // Fallback for custom mapping types. Instead of using PyDict methods
840       // which stay in C, we call iter(o1).
841       if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
842         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
843         return true;
844       }
845 
846       Safe_PyObjectPtr iter(PyObject_GetIter(o1));
847       PyObject* key;
848       while ((key = PyIter_Next(iter.get())) != nullptr) {
849         if (!PyMapping_HasKey(o2, key)) {
850           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
851           Py_DECREF(key);
852           return true;
853         }
854         Py_DECREF(key);
855       }
856     }
857   }
858 
859   if (check_composite_tensor_type_spec &&
860       (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
861     Safe_PyObjectPtr owned_type_spec_1;
862     PyObject* type_spec_1 = o1;
863     if (IsCompositeTensor(o1)) {
864       owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
865       type_spec_1 = owned_type_spec_1.get();
866     }
867 
868     Safe_PyObjectPtr owned_type_spec_2;
869     PyObject* type_spec_2 = o2;
870     if (IsCompositeTensor(o2)) {
871       owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
872       type_spec_2 = owned_type_spec_2.get();
873     }
874 
875     // Two composite tensors are considered to have the same structure if
876     // they share a type spec that is a supertype of both of them. We do *not*
877     // use is_subtype_of, since that would prevent us from e.g. using a
878     // cond statement where the two sides have different shapes.
879 
880     // TODO(b/206014848): We have to explicitly remove the names.
881     Safe_PyObjectPtr owned_nameless_type_spec_1(
882         PyObject_CallMethod(type_spec_1, "_without_tensor_names", nullptr));
883     Safe_PyObjectPtr owned_nameless_type_spec_2(
884         PyObject_CallMethod(type_spec_2, "_without_tensor_names", nullptr));
885     // TODO(b/222123181): Reconsider most_specific_common_supertype usage.
886     static char compatible_type[] = "most_specific_common_supertype";
887     static char argspec[] = "([O])";
888     Safe_PyObjectPtr struct_compatible(
889         PyObject_CallMethod(owned_nameless_type_spec_1.get(), compatible_type,
890                             argspec, owned_nameless_type_spec_2.get()));
891     if (PyErr_Occurred()) {
892       return false;
893     }
894     if (struct_compatible.get() == Py_None) {
895       *is_type_error = false;
896       *error_msg = tensorflow::strings::StrCat(
897           "Incompatible CompositeTensor TypeSpecs: ",
898           PyObjectToString(type_spec_1), " vs. ",
899           PyObjectToString(type_spec_2));
900       return true;
901     }
902   }
903 
904   ValueIteratorPtr iter1 = value_iterator_getter(o1);
905   ValueIteratorPtr iter2 = value_iterator_getter(o2);
906 
907   if (!iter1->valid() || !iter2->valid()) return false;
908 
909   while (true) {
910     Safe_PyObjectPtr v1 = iter1->next();
911     Safe_PyObjectPtr v2 = iter2->next();
912     if (v1 && v2) {
913       if (Py_EnterRecursiveCall(" in assert_same_structure")) {
914         return false;
915       }
916       bool no_internal_errors = AssertSameStructureHelper(
917           v1.get(), v2.get(), check_types, error_msg, is_type_error,
918           is_nested_helper, value_iterator_getter,
919           check_composite_tensor_type_spec);
920       Py_LeaveRecursiveCall();
921       if (!no_internal_errors) return false;
922       if (!error_msg->empty()) return true;
923     } else if (!v1 && !v2) {
924       // Done with all recursive calls. Structure matched.
925       return true;
926     } else {
927       *is_type_error = false;
928       *error_msg = tensorflow::strings::StrCat(
929           "The two structures don't have the same number of elements. ",
930           "First structure: ", PyObjectToString(o1),
931           ". Second structure: ", PyObjectToString(o2));
932       return true;
933     }
934   }
935 }
936 
937 }  // namespace
938 
IsNested(PyObject * o)939 bool IsNested(PyObject* o) { return IsNestedHelper(o) == 1; }
IsMapping(PyObject * o)940 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsMutableMapping(PyObject * o)941 bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
IsMappingView(PyObject * o)942 bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
IsAttrs(PyObject * o)943 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)944 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsTensorSpec(PyObject * o)945 bool IsTensorSpec(PyObject* o) { return IsTensorSpecHelper(o) == 1; }
IsEagerTensorSlow(PyObject * o)946 bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
IsResourceVariable(PyObject * o)947 bool IsResourceVariable(PyObject* o) {
948   return IsResourceVariableHelper(o) == 1;
949 }
IsOwnedIterator(PyObject * o)950 bool IsOwnedIterator(PyObject* o) { return IsOwnedIteratorHelper(o) == 1; }
IsVariable(PyObject * o)951 bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
IsIndexedSlices(PyObject * o)952 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
IsDispatchable(PyObject * o)953 bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
954 
IsTuple(PyObject * o)955 bool IsTuple(PyObject* o) {
956   tensorflow::Safe_PyObjectPtr wrapped;
957   if (IsObjectProxy(o)) {
958     wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
959     o = wrapped.get();
960   }
961   return PyTuple_Check(o);
962 }
963 
964 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
965 // and while we're at it give them consistent behavior by making sure the
966 // returned value is a list.
967 //
968 // As with PyMapping_Keys, returns a new reference.
969 //
970 // On failure, returns nullptr.
MappingKeys(PyObject * o)971 PyObject* MappingKeys(PyObject* o) {
972 #if PY_MAJOR_VERSION >= 3
973   return PyMapping_Keys(o);
974 #else
975   static char key_method_name[] = "keys";
976   Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
977   if (PyErr_Occurred() || raw_result.get() == nullptr) {
978     return nullptr;
979   }
980   return PySequence_Fast(
981       raw_result.get(),
982       "The '.keys()' method of a custom mapping returned a non-sequence.");
983 #endif
984 }
985 
Flatten(PyObject * nested,bool expand_composites)986 PyObject* Flatten(PyObject* nested, bool expand_composites) {
987   PyObject* list = PyList_New(0);
988   const std::function<int(PyObject*)>& is_nested_helper =
989       expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper;
990   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
991       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
992   if (FlattenHelper(nested, list, is_nested_helper, get_value_iterator)) {
993     return list;
994   } else {
995     Py_DECREF(list);
996     return nullptr;
997   }
998 }
999 
IsNestedOrComposite(PyObject * o)1000 bool IsNestedOrComposite(PyObject* o) {
1001   return IsNestedOrCompositeHelper(o) == 1;
1002 }
1003 
IsCompositeTensor(PyObject * o)1004 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
1005 
IsTypeSpec(PyObject * o)1006 bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
1007 
IsNestedForData(PyObject * o)1008 bool IsNestedForData(PyObject* o) { return IsNestedForDataHelper(o) == 1; }
1009 
FlattenForData(PyObject * nested)1010 PyObject* FlattenForData(PyObject* nested) {
1011   PyObject* list = PyList_New(0);
1012   if (FlattenHelper(nested, list, IsNestedForDataHelper,
1013                     GetValueIteratorForData)) {
1014     return list;
1015   } else {
1016     Py_DECREF(list);
1017     return nullptr;
1018   }
1019 }
1020 
IsNamedtuple(PyObject * o,bool strict)1021 PyObject* IsNamedtuple(PyObject* o, bool strict) {
1022   // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1023   // require some unwrapping if we want to treat them like the objects they're
1024   // wrapping.
1025   tensorflow::Safe_PyObjectPtr o_wrapped;
1026   if (IsObjectProxy(o)) {
1027     o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1028     o = o_wrapped.get();
1029   }
1030 
1031   // Must be subclass of tuple
1032   if (!PyTuple_Check(o)) {
1033     Py_RETURN_FALSE;
1034   }
1035 
1036   // If strict, o.__class__.__base__ must be tuple
1037   if (strict) {
1038     PyObject* klass = PyObject_GetAttrString(o, "__class__");
1039     if (klass == nullptr) return nullptr;
1040     PyObject* base = PyObject_GetAttrString(klass, "__base__");
1041     Py_DECREF(klass);
1042     if (base == nullptr) return nullptr;
1043 
1044     const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1045     // built-in object types are singletons
1046     bool tuple_base = base_type == &PyTuple_Type;
1047     Py_DECREF(base);
1048     if (!tuple_base) {
1049       Py_RETURN_FALSE;
1050     }
1051   }
1052 
1053   // o must have attribute '_fields' and every element in
1054   // '_fields' must be a string.
1055   int has_fields = PyObject_HasAttrString(o, "_fields");
1056   if (!has_fields) {
1057     Py_RETURN_FALSE;
1058   }
1059 
1060   Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1061   int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1062   if (is_instance == 0) {
1063     Py_RETURN_FALSE;
1064   } else if (is_instance == -1) {
1065     return nullptr;
1066   }
1067 
1068   Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1069   const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1070   for (Py_ssize_t i = 0; i < s; ++i) {
1071     // PySequence_Fast_GET_ITEM returns borrowed ref
1072     PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1073     if (!IsString(elem)) {
1074       Py_RETURN_FALSE;
1075     }
1076   }
1077 
1078   Py_RETURN_TRUE;
1079 }
1080 
SameNamedtuples(PyObject * o1,PyObject * o2)1081 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1082   Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1083   Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1084   if (f1 == nullptr || f2 == nullptr) {
1085     PyErr_SetString(
1086         PyExc_RuntimeError,
1087         "Expected namedtuple-like objects (that have _fields attr)");
1088     return nullptr;
1089   }
1090 
1091   if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1092     Py_RETURN_FALSE;
1093   }
1094 
1095   if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1096     Py_RETURN_TRUE;
1097   } else {
1098     Py_RETURN_FALSE;
1099   }
1100 }
1101 
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)1102 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1103                               bool expand_composites) {
1104   const std::function<int(PyObject*)>& is_nested_helper =
1105       expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper;
1106   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1107       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1108   const bool check_composite_tensor_type_spec = expand_composites;
1109   string error_msg;
1110   bool is_type_error = false;
1111   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1112                             is_nested_helper, get_value_iterator,
1113                             check_composite_tensor_type_spec);
1114   if (PyErr_Occurred()) {
1115     // Don't hide Python exceptions while checking (e.g. errors fetching keys
1116     // from custom mappings).
1117     return nullptr;
1118   }
1119   if (!error_msg.empty()) {
1120     PyErr_SetString(
1121         is_type_error ? PyExc_TypeError : PyExc_ValueError,
1122         tensorflow::strings::StrCat(
1123             "The two structures don't have the same nested structure.\n\n",
1124             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1125             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1126             .c_str());
1127     return nullptr;
1128   }
1129   Py_RETURN_NONE;
1130 }
1131 
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)1132 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1133                                      bool check_types) {
1134   string error_msg;
1135   bool is_type_error = false;
1136   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1137                             IsNestedForDataHelper, GetValueIterator, false);
1138   if (PyErr_Occurred()) {
1139     // Don't hide Python exceptions while checking (e.g. errors fetching keys
1140     // from custom mappings).
1141     return nullptr;
1142   }
1143   if (!error_msg.empty()) {
1144     PyErr_SetString(
1145         is_type_error ? PyExc_TypeError : PyExc_ValueError,
1146         tensorflow::strings::StrCat(
1147             "The two structures don't have the same nested structure.\n\n",
1148             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1149             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1150             .c_str());
1151     return nullptr;
1152   }
1153   Py_RETURN_NONE;
1154 }
1155 
1156 }  // namespace swig
1157 }  // namespace tensorflow
1158