• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16 
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21 
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28 
29 namespace tensorflow {
30 namespace swig {
31 
32 namespace {
33 string PyObjectToString(PyObject* o);
34 }  // namespace
35 
RegisteredPyObjectMap()36 std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37   static auto* m = new std::unordered_map<string, PyObject*>();
38   return m;
39 }
40 
GetRegisteredPyObject(const string & name)41 PyObject* GetRegisteredPyObject(const string& name) {
42   const auto* m = RegisteredPyObjectMap();
43   auto it = m->find(name);
44   if (it == m->end()) {
45     PyErr_SetString(PyExc_TypeError,
46                     tensorflow::strings::StrCat("No object with name ", name,
47                                                 " has been registered.")
48                         .c_str());
49     return nullptr;
50   }
51   return it->second;
52 }
53 
RegisterType(PyObject * type_name,PyObject * type)54 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55   if (!PyType_Check(type)) {
56     PyErr_SetString(PyExc_TypeError,
57                     tensorflow::strings::StrCat("Expecting a type, got ",
58                                                 Py_TYPE(type)->tp_name)
59                         .c_str());
60     return nullptr;
61   }
62   return RegisterPyObject(type_name, type);
63 }
64 
RegisterPyObject(PyObject * name,PyObject * value)65 PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66   string key;
67   if (PyBytes_Check(name)) {
68     key = PyBytes_AsString(name);
69 #if PY_MAJOR_VERSION >= 3
70   } else if (PyUnicode_Check(name)) {
71     key = PyUnicode_AsUTF8(name);
72 #endif
73   } else {
74     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75                                          "Expected name to be a str, got",
76                                          PyObjectToString(name))
77                                          .c_str());
78     return nullptr;
79   }
80 
81   auto* m = RegisteredPyObjectMap();
82   if (m->find(key) != m->end()) {
83     PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84                                          "Value already registered for ", key)
85                                          .c_str());
86     return nullptr;
87   }
88 
89   Py_INCREF(value);
90   m->emplace(key, value);
91 
92   Py_RETURN_NONE;
93 }
94 
95 namespace {
// Upper bound on the number of entries a CachedTypeCheck keeps in its cache.
constexpr int kMaxItemsInCache = 1024;

// Becomes true once the "sets are not sequences" warning has been logged, so
// that the warning is emitted at most once.
bool WarnedThatSetIsNotSequence = false;
99 
IsString(PyObject * o)100 bool IsString(PyObject* o) {
101   return PyBytes_Check(o) ||
102 #if PY_MAJOR_VERSION < 3
103          PyString_Check(o) ||
104 #endif
105          PyUnicode_Check(o);
106 }
107 
108 // Equivalent to Python's 'o.__class__.__name__'
109 // Note that '__class__' attribute is set only in new-style classes.
110 // A lot of tensorflow code uses __class__ without checks, so it seems like
111 // we only support new-style classes.
GetClassName(PyObject * o)112 StringPiece GetClassName(PyObject* o) {
113   // __class__ is equivalent to type() for new style classes.
114   // type() is equivalent to PyObject_Type()
115   // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
116   // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
117   // we don't need here.
118   PyTypeObject* type = o->ob_type;
119 
120   // __name__ is the value of `tp_name` after the last '.'
121   // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
122   StringPiece name(type->tp_name);
123   size_t pos = name.rfind('.');
124   if (pos != StringPiece::npos) {
125     name.remove_prefix(pos + 1);
126   }
127   return name;
128 }
129 
PyObjectToString(PyObject * o)130 string PyObjectToString(PyObject* o) {
131   if (o == nullptr) {
132     return "<null object>";
133   }
134   PyObject* str = PyObject_Str(o);
135   if (str) {
136 #if PY_MAJOR_VERSION < 3
137     string s(PyString_AS_STRING(str));
138 #else
139     string s(PyUnicode_AsUTF8(str));
140 #endif
141     Py_DECREF(str);
142     return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
143   } else {
144     return "<failed to execute str() on object>";
145   }
146 }
147 
148 class CachedTypeCheck {
149  public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)150   explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
151       : ternary_predicate_(std::move(ternary_predicate)) {}
152 
~CachedTypeCheck()153   ~CachedTypeCheck() {
154     mutex_lock l(type_to_sequence_map_mu_);
155     for (const auto& pair : type_to_sequence_map_) {
156       Py_DECREF(pair.first);
157     }
158   }
159 
160   // Caches successful executions of the one-argument (PyObject*) callable
161   // "ternary_predicate" based on the type of "o". -1 from the callable
162   // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
163   // does not match the predicate, and 1 indicates that it does. Used to avoid
164   // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)165   int CachedLookup(PyObject* o) {
166     // Try not to return to Python - see if the type has already been seen
167     // before.
168 
169     auto* type = Py_TYPE(o);
170 
171     {
172       tf_shared_lock l(type_to_sequence_map_mu_);
173       auto it = type_to_sequence_map_.find(type);
174       if (it != type_to_sequence_map_.end()) {
175         return it->second;
176       }
177     }
178 
179     int check_result = ternary_predicate_(o);
180 
181     if (check_result == -1) {
182       return -1;  // Type check error, not cached.
183     }
184 
185     // NOTE: This is never decref'd as long as the object lives, which is likely
186     // forever, but we don't want the type to get deleted as long as it is in
187     // the map. This should not be too much of a leak, as there should only be a
188     // relatively small number of types in the map, and an even smaller number
189     // that are eligible for decref. As a precaution, we limit the size of the
190     // map to 1024.
191     {
192       mutex_lock l(type_to_sequence_map_mu_);
193       if (type_to_sequence_map_.size() < kMaxItemsInCache) {
194         Py_INCREF(type);
195         auto insert_result = type_to_sequence_map_.insert({type, check_result});
196         if (!insert_result.second) {
197           // The type was added to the cache by a concurrent thread after we
198           // looked it up above.
199           Py_DECREF(type);
200         }
201       }
202     }
203 
204     return check_result;
205   }
206 
207  private:
208   std::function<int(PyObject*)> ternary_predicate_;
209   mutex type_to_sequence_map_mu_;
210   std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
211       TF_GUARDED_BY(type_to_sequence_map_mu_);
212 };
213 
214 // Returns 1 if 'obj' is an instance of 'type_name'
215 // Returns 0 otherwise.
216 // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
IsInstanceOfRegisteredType(PyObject * obj,const char * type_name)217 int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
218   PyObject* type_obj = GetRegisteredPyObject(type_name);
219   if (TF_PREDICT_FALSE(type_obj == nullptr)) {
220     PyErr_SetString(PyExc_RuntimeError,
221                     tensorflow::strings::StrCat(
222                         type_name,
223                         " type has not been set. "
224                         "Please register the type with the identifier \"",
225                         type_name, "\" using RegisterType.")
226                         .c_str());
227     return -1;
228   }
229   return PyObject_IsInstance(obj, type_obj);
230 }
231 
232 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
233 // Returns 0 otherwise.
234 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)235 int IsMappingHelper(PyObject* o) {
236   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
237     return IsInstanceOfRegisteredType(to_check, "Mapping");
238   });
239   if (PyDict_Check(o)) return true;
240   return check_cache->CachedLookup(o);
241 }
242 
243 // Returns 1 if `o` is considered a mutable mapping for the purposes of
244 // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
IsMutableMappingHelper(PyObject * o)245 int IsMutableMappingHelper(PyObject* o) {
246   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
247     return IsInstanceOfRegisteredType(to_check, "MutableMapping");
248   });
249   if (PyDict_Check(o)) return true;
250   return check_cache->CachedLookup(o);
251 }
252 
253 // Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
254 // Returns 0 otherwise.
255 // Returns -1 if an error occurred.
IsMappingViewHelper(PyObject * o)256 int IsMappingViewHelper(PyObject* o) {
257   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
258     return IsInstanceOfRegisteredType(to_check, "MappingView");
259   });
260   return check_cache->CachedLookup(o);
261 }
262 
263 // Returns 1 if `o` is considered an object proxy
264 // Returns 0 otherwise.
265 // Returns -1 if an error occurred.
IsObjectProxy(PyObject * o)266 int IsObjectProxy(PyObject* o) {
267   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
268     return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
269   });
270   return check_cache->CachedLookup(o);
271 }
272 
273 // Returns 1 if `o` is an instance of attrs-decorated class.
274 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)275 int IsAttrsHelper(PyObject* o) {
276   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
277     Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
278     if (cls) {
279       return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
280     }
281 
282     // PyObject_GetAttrString returns null on error
283     PyErr_Clear();
284     return 0;
285   });
286   return check_cache->CachedLookup(o);
287 }
288 
289 // Returns 1 if `o` is an object of type IndexedSlices.
290 // Returns 0 otherwise.
291 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)292 int IsIndexedSlicesHelper(PyObject* o) {
293   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
294     return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
295   });
296   return check_cache->CachedLookup(o);
297 }
298 
299 // Returns 1 if `o` is a Tensor.
300 // Returns 0 otherwise.
301 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)302 int IsTensorHelper(PyObject* o) {
303   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
304     return IsInstanceOfRegisteredType(to_check, "Tensor");
305   });
306   return check_cache->CachedLookup(o);
307 }
308 
309 // Returns 1 if `o` is an EagerTensor.
310 // Returns 0 otherwise.
311 // Returns -1 if an error occurred.
IsEagerTensorHelper(PyObject * o)312 int IsEagerTensorHelper(PyObject* o) {
313   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
314     return IsInstanceOfRegisteredType(to_check, "EagerTensor");
315   });
316   return check_cache->CachedLookup(o);
317 }
318 
319 // Returns 1 if `o` is a ResourceVariable.
320 // Returns 0 otherwise.
321 // Returns -1 if an error occurred.
IsResourceVariableHelper(PyObject * o)322 int IsResourceVariableHelper(PyObject* o) {
323   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
324     return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
325   });
326   return check_cache->CachedLookup(o);
327 }
328 
329 // Returns 1 if `o` is a OwnedIterator.
330 // Returns 0 otherwise.
331 // Returns -1 if an error occurred.
IsOwnedIteratorHelper(PyObject * o)332 int IsOwnedIteratorHelper(PyObject* o) {
333   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
334     return IsInstanceOfRegisteredType(to_check, "OwnedIterator");
335   });
336   return check_cache->CachedLookup(o);
337 }
338 
339 // Returns 1 if `o` is a ResourceVariable.
340 // Returns 0 otherwise.
341 // Returns -1 if an error occurred.
IsVariableHelper(PyObject * o)342 int IsVariableHelper(PyObject* o) {
343   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
344     return IsInstanceOfRegisteredType(to_check, "Variable");
345   });
346   return check_cache->CachedLookup(o);
347 }
348 
349 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
350 // Returns 0 otherwise.
351 // Returns -1 if an error occurred.
IsSequenceHelper(PyObject * o)352 int IsSequenceHelper(PyObject* o) {
353   // We treat dicts and other mappings as special cases of sequences.
354   if (IsMappingHelper(o)) return true;
355   if (IsMappingViewHelper(o)) return true;
356   if (IsAttrsHelper(o)) return true;
357   if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
358     LOG(WARNING) << "Sets are not currently considered sequences, "
359                     "but this may change in the future, "
360                     "so consider avoiding using them.";
361     WarnedThatSetIsNotSequence = true;
362   }
363   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
364     int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
365 
366     // Don't cache a failed is_instance check.
367     if (is_instance == -1) return -1;
368 
369     return static_cast<int>(is_instance != 0 && !IsString(to_check));
370   });
371   return check_cache->CachedLookup(o);
372 }
373 
374 // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
375 // Returns 0 otherwise.
IsDispatchableHelper(PyObject * o)376 int IsDispatchableHelper(PyObject* o) {
377   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
378     return PyObject_HasAttrString(
379         reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
380   });
381   return check_cache->CachedLookup(o);
382 }
383 
384 // ValueIterator interface
385 class ValueIterator {
386  public:
~ValueIterator()387   virtual ~ValueIterator() {}
388   virtual Safe_PyObjectPtr next() = 0;
389 
valid() const390   bool valid() const { return is_valid_; }
391 
392  protected:
invalidate()393   void invalidate() { is_valid_ = false; }
394 
395  private:
396   bool is_valid_ = true;
397 };
398 
399 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
400 
401 // Iterate through dictionaries in a deterministic order by sorting the
402 // keys. Notice this means that we ignore the original order of
403 // `OrderedDict` instances. This is intentional, to avoid potential
404 // bugs caused by mixing ordered and plain dicts (e.g., flattening
405 // a dict but using a corresponding `OrderedDict` to pack it back).
406 class DictValueIterator : public ValueIterator {
407  public:
DictValueIterator(PyObject * dict)408   explicit DictValueIterator(PyObject* dict)
409       : dict_(dict), keys_(PyDict_Keys(dict)) {
410     if (PyList_Sort(keys_.get()) == -1) {
411       invalidate();
412     } else {
413       iter_.reset(PyObject_GetIter(keys_.get()));
414     }
415   }
416 
next()417   Safe_PyObjectPtr next() override {
418     Safe_PyObjectPtr result;
419     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
420     if (key) {
421       // PyDict_GetItem returns a borrowed reference.
422       PyObject* elem = PyDict_GetItem(dict_, key.get());
423       if (elem) {
424         Py_INCREF(elem);
425         result.reset(elem);
426       } else {
427         PyErr_SetString(PyExc_RuntimeError,
428                         "Dictionary was modified during iteration over it");
429       }
430     }
431     return result;
432   }
433 
434  private:
435   PyObject* dict_;
436   Safe_PyObjectPtr keys_;
437   Safe_PyObjectPtr iter_;
438 };
439 
440 // Iterate over mapping objects by sorting the keys first
441 class MappingValueIterator : public ValueIterator {
442  public:
MappingValueIterator(PyObject * mapping)443   explicit MappingValueIterator(PyObject* mapping)
444       : mapping_(mapping), keys_(MappingKeys(mapping)) {
445     if (!keys_ || PyList_Sort(keys_.get()) == -1) {
446       invalidate();
447     } else {
448       iter_.reset(PyObject_GetIter(keys_.get()));
449     }
450   }
451 
next()452   Safe_PyObjectPtr next() override {
453     Safe_PyObjectPtr result;
454     Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
455     if (key) {
456       // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
457       PyObject* elem = PyObject_GetItem(mapping_, key.get());
458       if (elem) {
459         result.reset(elem);
460       } else {
461         PyErr_SetString(PyExc_RuntimeError,
462                         "Mapping was modified during iteration over it");
463       }
464     }
465     return result;
466   }
467 
468  private:
469   PyObject* mapping_;
470   Safe_PyObjectPtr keys_;
471   Safe_PyObjectPtr iter_;
472 };
473 
474 // Iterate over a sequence, by index.
475 class SequenceValueIterator : public ValueIterator {
476  public:
SequenceValueIterator(PyObject * iterable)477   explicit SequenceValueIterator(PyObject* iterable)
478       : seq_(PySequence_Fast(iterable, "")),
479         size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
480         index_(0) {}
481 
next()482   Safe_PyObjectPtr next() override {
483     Safe_PyObjectPtr result;
484     if (index_ < size_) {
485       // PySequence_Fast_GET_ITEM returns a borrowed reference.
486       PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
487       ++index_;
488       if (elem) {
489         Py_INCREF(elem);
490         result.reset(elem);
491       }
492     }
493 
494     return result;
495   }
496 
497  private:
498   Safe_PyObjectPtr seq_;
499   const Py_ssize_t size_;
500   Py_ssize_t index_;
501 };
502 
503 // Iterator that just returns a single python object.
504 class SingleValueIterator : public ValueIterator {
505  public:
SingleValueIterator(PyObject * x)506   explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
507 
next()508   Safe_PyObjectPtr next() override { return std::move(x_); }
509 
510  private:
511   Safe_PyObjectPtr x_;
512 };
513 
514 // Returns nullptr (to raise an exception) when next() is called.  Caller
515 // should have already called PyErr_SetString.
516 class ErrorValueIterator : public ValueIterator {
517  public:
ErrorValueIterator()518   ErrorValueIterator() {}
next()519   Safe_PyObjectPtr next() override { return nullptr; }
520 };
521 
522 class AttrsValueIterator : public ValueIterator {
523  public:
AttrsValueIterator(PyObject * nested)524   explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
525     Py_INCREF(nested);
526     cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
527     if (cls_) {
528       attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
529       if (attrs_) {
530         iter_.reset(PyObject_GetIter(attrs_.get()));
531       }
532     }
533     if (!iter_ || PyErr_Occurred()) invalidate();
534   }
535 
next()536   Safe_PyObjectPtr next() override {
537     Safe_PyObjectPtr result;
538     Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
539     if (item) {
540       Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
541       result.reset(PyObject_GetAttr(nested_.get(), name.get()));
542     }
543 
544     return result;
545   }
546 
547  private:
548   Safe_PyObjectPtr nested_;
549   Safe_PyObjectPtr cls_;
550   Safe_PyObjectPtr attrs_;
551   Safe_PyObjectPtr iter_;
552 };
553 
IsSparseTensorValueType(PyObject * o)554 bool IsSparseTensorValueType(PyObject* o) {
555   PyObject* sparse_tensor_value_type =
556       GetRegisteredPyObject("SparseTensorValue");
557   if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
558     return false;
559   }
560 
561   return PyObject_TypeCheck(
562              o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
563 }
564 
565 // Returns 1 if `o` is an instance of CompositeTensor.
566 // Returns 0 otherwise.
567 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)568 bool IsCompositeTensorHelper(PyObject* o) {
569   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
570     return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
571   });
572   return check_cache->CachedLookup(o);
573 }
574 
575 // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
576 // VariableSpec.
577 // Returns 0 otherwise.
578 // Returns -1 if an error occurred.
IsTypeSpecHelper(PyObject * o)579 bool IsTypeSpecHelper(PyObject* o) {
580   static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
581     int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
582     int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
583                          IsInstanceOfRegisteredType(to_check, "VariableSpec"));
584     if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
585     return static_cast<int>(is_type_spec && !is_dense_spec);
586   });
587   return check_cache->CachedLookup(o);
588 }
589 
590 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
591 // (non-TensorSpec and non-VariableSpec) TypeSpec.
592 // Returns 0 otherwise.
593 // Returns -1 if an error occurred.
IsSequenceOrCompositeHelper(PyObject * o)594 int IsSequenceOrCompositeHelper(PyObject* o) {
595   int is_sequence = IsSequenceHelper(o);
596   int is_composite = IsCompositeTensorHelper(o);
597   int is_type_spec = IsTypeSpecHelper(o);
598   if ((is_sequence == -1) || (is_composite == -1) || (is_type_spec == -1)) {
599     return -1;
600   }
601   return is_sequence || is_composite || is_type_spec;
602 }
603 
IsSequenceForDataHelper(PyObject * o)604 int IsSequenceForDataHelper(PyObject* o) {
605   return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
606          !IsSparseTensorValueType(o);
607 }
608 
GetValueIterator(PyObject * nested)609 ValueIteratorPtr GetValueIterator(PyObject* nested) {
610   if (PyDict_Check(nested)) {
611     return absl::make_unique<DictValueIterator>(nested);
612   } else if (IsMappingHelper(nested)) {
613     return absl::make_unique<MappingValueIterator>(nested);
614   } else if (IsAttrsHelper(nested)) {
615     return absl::make_unique<AttrsValueIterator>(nested);
616   } else {
617     return absl::make_unique<SequenceValueIterator>(nested);
618   }
619 }
620 
621 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)622 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
623   if (PyDict_Check(nested)) {
624     return absl::make_unique<DictValueIterator>(nested);
625   } else if (IsMappingHelper(nested)) {
626     return absl::make_unique<MappingValueIterator>(nested);
627   } else if (IsAttrsHelper(nested)) {
628     return absl::make_unique<AttrsValueIterator>(nested);
629   } else if (IsSparseTensorValueType(nested)) {
630     return absl::make_unique<SingleValueIterator>(nested);
631   } else {
632     return absl::make_unique<SequenceValueIterator>(nested);
633   }
634 }
635 
636 // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
GetValueIteratorForComposite(PyObject * nested)637 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
638   if (IsCompositeTensor(nested)) {
639     Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
640     if (PyErr_Occurred() || !spec) {
641       return absl::make_unique<ErrorValueIterator>();
642     }
643 
644     static char to_components[] = "_to_components";
645     static char argspec[] = "(O)";
646     Safe_PyObjectPtr components(
647         PyObject_CallMethod(spec.get(), to_components, argspec, nested));
648     if (PyErr_Occurred() || components == nullptr) {
649       return absl::make_unique<ErrorValueIterator>();
650     }
651     return absl::make_unique<SingleValueIterator>(components.get());
652   }
653 
654   if (IsTypeSpec(nested)) {
655     Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
656     if (PyErr_Occurred() || specs == nullptr) {
657       return absl::make_unique<ErrorValueIterator>();
658     }
659     return absl::make_unique<SingleValueIterator>(specs.get());
660   }
661 
662   return GetValueIterator(nested);
663 }
664 
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)665 bool FlattenHelper(
666     PyObject* nested, PyObject* list,
667     const std::function<int(PyObject*)>& is_sequence_helper,
668     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
669   // if nested is not a sequence, append itself and exit
670   int is_seq = is_sequence_helper(nested);
671   if (is_seq == -1) return false;
672   if (!is_seq) {
673     return PyList_Append(list, nested) != -1;
674   }
675 
676   ValueIteratorPtr iter = value_iterator_getter(nested);
677   if (!iter->valid()) return false;
678 
679   for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
680     if (Py_EnterRecursiveCall(" in flatten")) {
681       return false;
682     }
683     const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
684                                        value_iterator_getter);
685     Py_LeaveRecursiveCall();
686     if (!success) {
687       return false;
688     }
689   }
690   return true;
691 }
692 
693 // Sets error using keys of 'dict1' and 'dict2'.
694 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)695 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
696                            bool* is_type_error) {
697   Safe_PyObjectPtr k1(MappingKeys(dict1));
698   if (PyErr_Occurred() || k1.get() == nullptr) {
699     *error_msg =
700         ("The two dictionaries don't have the same set of keys. Failed to "
701          "fetch keys.");
702     return;
703   }
704   Safe_PyObjectPtr k2(MappingKeys(dict2));
705   if (PyErr_Occurred() || k2.get() == nullptr) {
706     *error_msg =
707         ("The two dictionaries don't have the same set of keys. Failed to "
708          "fetch keys.");
709     return;
710   }
711   *is_type_error = false;
712   *error_msg = tensorflow::strings::StrCat(
713       "The two dictionaries don't have the same set of keys. "
714       "First structure has keys ",
715       PyObjectToString(k1.get()), ", while second structure has keys ",
716       PyObjectToString(k2.get()));
717 }
718 
719 // Returns true iff there were no "internal" errors. In other words,
720 // errors that have nothing to do with structure checking.
721 // If an "internal" error occurred, the appropriate Python error will be
722 // set and the caller can propagate it directly to the user.
723 //
724 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
725 // be empty.
726 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
727 // with appropriate error and sets `is_type_error` to true iff
728 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter,bool check_composite_tensor_type_spec)729 bool AssertSameStructureHelper(
730     PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
731     bool* is_type_error,
732     const std::function<int(PyObject*)>& is_sequence_helper,
733     const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
734     bool check_composite_tensor_type_spec) {
735   DCHECK(error_msg);
736   DCHECK(is_type_error);
737   const bool is_seq1 = is_sequence_helper(o1);
738   const bool is_seq2 = is_sequence_helper(o2);
739   if (PyErr_Occurred()) return false;
740   if (is_seq1 != is_seq2) {
741     string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
742     string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
743     *is_type_error = false;
744     *error_msg = tensorflow::strings::StrCat(
745         "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
746         non_seq_str, "\" is not");
747     return true;
748   }
749 
750   // Got to objects that are considered non-sequences. Note that in tf.data
751   // use case lists and sparse_tensors are not considered sequences. So finished
752   // checking, structures are the same.
753   if (!is_seq1) return true;
754 
755   if (check_types) {
756     // Treat wrapped tuples as tuples.
757     tensorflow::Safe_PyObjectPtr o1_wrapped;
758     if (IsObjectProxy(o1)) {
759       o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
760       o1 = o1_wrapped.get();
761     }
762     tensorflow::Safe_PyObjectPtr o2_wrapped;
763     if (IsObjectProxy(o2)) {
764       o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
765       o2 = o2_wrapped.get();
766     }
767 
768     const PyTypeObject* type1 = o1->ob_type;
769     const PyTypeObject* type2 = o2->ob_type;
770 
771     // We treat two different namedtuples with identical name and fields
772     // as having the same type.
773     const PyObject* o1_tuple = IsNamedtuple(o1, false);
774     if (o1_tuple == nullptr) return false;
775     const PyObject* o2_tuple = IsNamedtuple(o2, false);
776     if (o2_tuple == nullptr) {
777       Py_DECREF(o1_tuple);
778       return false;
779     }
780     bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
781     Py_DECREF(o1_tuple);
782     Py_DECREF(o2_tuple);
783 
784     if (both_tuples) {
785       const PyObject* same_tuples = SameNamedtuples(o1, o2);
786       if (same_tuples == nullptr) return false;
787       bool not_same_tuples = same_tuples != Py_True;
788       Py_DECREF(same_tuples);
789       if (not_same_tuples) {
790         *is_type_error = true;
791         *error_msg = tensorflow::strings::StrCat(
792             "The two namedtuples don't have the same sequence type. "
793             "First structure ",
794             PyObjectToString(o1), " has type ", type1->tp_name,
795             ", while second structure ", PyObjectToString(o2), " has type ",
796             type2->tp_name);
797         return true;
798       }
799     } else if (type1 != type2
800                /* If both sequences are list types, don't complain. This allows
801                   one to be a list subclass (e.g. _ListWrapper used for
802                   automatic dependency tracking.) */
803                && !(PyList_Check(o1) && PyList_Check(o2))
804                /* Two mapping types will also compare equal, making _DictWrapper
805                   and dict compare equal. */
806                && !(IsMappingHelper(o1) && IsMappingHelper(o2))
807                /* For CompositeTensor & TypeSpec, we check below. */
808                && !(check_composite_tensor_type_spec &&
809                     (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
810                     (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
811       *is_type_error = true;
812       *error_msg = tensorflow::strings::StrCat(
813           "The two namedtuples don't have the same sequence type. "
814           "First structure ",
815           PyObjectToString(o1), " has type ", type1->tp_name,
816           ", while second structure ", PyObjectToString(o2), " has type ",
817           type2->tp_name);
818       return true;
819     }
820 
821     if (PyDict_Check(o1) && PyDict_Check(o2)) {
822       if (PyDict_Size(o1) != PyDict_Size(o2)) {
823         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
824         return true;
825       }
826 
827       PyObject* key;
828       Py_ssize_t pos = 0;
829       while (PyDict_Next(o1, &pos, &key, nullptr)) {
830         if (PyDict_GetItem(o2, key) == nullptr) {
831           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
832           return true;
833         }
834       }
835     } else if (IsMappingHelper(o1)) {
836       // Fallback for custom mapping types. Instead of using PyDict methods
837       // which stay in C, we call iter(o1).
838       if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
839         SetDifferentKeysError(o1, o2, error_msg, is_type_error);
840         return true;
841       }
842 
843       Safe_PyObjectPtr iter(PyObject_GetIter(o1));
844       PyObject* key;
845       while ((key = PyIter_Next(iter.get())) != nullptr) {
846         if (!PyMapping_HasKey(o2, key)) {
847           SetDifferentKeysError(o1, o2, error_msg, is_type_error);
848           Py_DECREF(key);
849           return true;
850         }
851         Py_DECREF(key);
852       }
853     }
854   }
855 
856   if (check_composite_tensor_type_spec &&
857       (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
858     Safe_PyObjectPtr owned_type_spec_1;
859     PyObject* type_spec_1 = o1;
860     if (IsCompositeTensor(o1)) {
861       owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
862       type_spec_1 = owned_type_spec_1.get();
863     }
864 
865     Safe_PyObjectPtr owned_type_spec_2;
866     PyObject* type_spec_2 = o2;
867     if (IsCompositeTensor(o2)) {
868       owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
869       type_spec_2 = owned_type_spec_2.get();
870     }
871 
872     // Two composite tensors are considered to have the same structure if
873     // there is some type spec that is compatible with both of them.  Thus,
874     // we use most_specific_compatible_type(), and check if it raises an
875     // exception.  We do *not* use is_compatible_with, since that would
876     // prevent us from e.g. using a cond statement where the two sides have
877     // different shapes.
878     static char compatible_type[] = "most_specific_compatible_type";
879     static char argspec[] = "(O)";
880     Safe_PyObjectPtr struct_compatible(PyObject_CallMethod(
881         type_spec_1, compatible_type, argspec, type_spec_2));
882     if (PyErr_Occurred() || struct_compatible == nullptr) {
883       PyErr_Clear();
884       *is_type_error = false;
885       *error_msg = tensorflow::strings::StrCat(
886           "Incompatible CompositeTensor TypeSpecs: ",
887           PyObjectToString(type_spec_1), " vs. ",
888           PyObjectToString(type_spec_2));
889       return true;
890     }
891   }
892 
893   ValueIteratorPtr iter1 = value_iterator_getter(o1);
894   ValueIteratorPtr iter2 = value_iterator_getter(o2);
895 
896   if (!iter1->valid() || !iter2->valid()) return false;
897 
898   while (true) {
899     Safe_PyObjectPtr v1 = iter1->next();
900     Safe_PyObjectPtr v2 = iter2->next();
901     if (v1 && v2) {
902       if (Py_EnterRecursiveCall(" in assert_same_structure")) {
903         return false;
904       }
905       bool no_internal_errors = AssertSameStructureHelper(
906           v1.get(), v2.get(), check_types, error_msg, is_type_error,
907           is_sequence_helper, value_iterator_getter,
908           check_composite_tensor_type_spec);
909       Py_LeaveRecursiveCall();
910       if (!no_internal_errors) return false;
911       if (!error_msg->empty()) return true;
912     } else if (!v1 && !v2) {
913       // Done with all recursive calls. Structure matched.
914       return true;
915     } else {
916       *is_type_error = false;
917       *error_msg = tensorflow::strings::StrCat(
918           "The two structures don't have the same number of elements. ",
919           "First structure: ", PyObjectToString(o1),
920           ". Second structure: ", PyObjectToString(o2));
921       return true;
922     }
923   }
924 }
925 
926 }  // namespace
927 
IsSequence(PyObject * o)928 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
IsMapping(PyObject * o)929 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsMutableMapping(PyObject * o)930 bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
IsMappingView(PyObject * o)931 bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
IsAttrs(PyObject * o)932 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)933 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsEagerTensorSlow(PyObject * o)934 bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
IsResourceVariable(PyObject * o)935 bool IsResourceVariable(PyObject* o) {
936   return IsResourceVariableHelper(o) == 1;
937 }
IsOwnedIterator(PyObject * o)938 bool IsOwnedIterator(PyObject* o) { return IsOwnedIteratorHelper(o) == 1; }
IsVariable(PyObject * o)939 bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
IsIndexedSlices(PyObject * o)940 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
IsDispatchable(PyObject * o)941 bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
942 
IsTuple(PyObject * o)943 bool IsTuple(PyObject* o) {
944   tensorflow::Safe_PyObjectPtr wrapped;
945   if (IsObjectProxy(o)) {
946     wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
947     o = wrapped.get();
948   }
949   return PyTuple_Check(o);
950 }
951 
952 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
953 // and while we're at it give them consistent behavior by making sure the
954 // returned value is a list.
955 //
956 // As with PyMapping_Keys, returns a new reference.
957 //
958 // On failure, returns nullptr.
MappingKeys(PyObject * o)959 PyObject* MappingKeys(PyObject* o) {
960 #if PY_MAJOR_VERSION >= 3
961   return PyMapping_Keys(o);
962 #else
963   static char key_method_name[] = "keys";
964   Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
965   if (PyErr_Occurred() || raw_result.get() == nullptr) {
966     return nullptr;
967   }
968   return PySequence_Fast(
969       raw_result.get(),
970       "The '.keys()' method of a custom mapping returned a non-sequence.");
971 #endif
972 }
973 
Flatten(PyObject * nested,bool expand_composites)974 PyObject* Flatten(PyObject* nested, bool expand_composites) {
975   PyObject* list = PyList_New(0);
976   const std::function<int(PyObject*)>& is_sequence_helper =
977       expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
978   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
979       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
980   if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
981     return list;
982   } else {
983     Py_DECREF(list);
984     return nullptr;
985   }
986 }
987 
IsSequenceOrComposite(PyObject * o)988 bool IsSequenceOrComposite(PyObject* o) {
989   return IsSequenceOrCompositeHelper(o) == 1;
990 }
991 
IsCompositeTensor(PyObject * o)992 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
993 
IsTypeSpec(PyObject * o)994 bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
995 
IsSequenceForData(PyObject * o)996 bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
997 
FlattenForData(PyObject * nested)998 PyObject* FlattenForData(PyObject* nested) {
999   PyObject* list = PyList_New(0);
1000   if (FlattenHelper(nested, list, IsSequenceForDataHelper,
1001                     GetValueIteratorForData)) {
1002     return list;
1003   } else {
1004     Py_DECREF(list);
1005     return nullptr;
1006   }
1007 }
1008 
IsNamedtuple(PyObject * o,bool strict)1009 PyObject* IsNamedtuple(PyObject* o, bool strict) {
1010   // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1011   // require some unwrapping if we want to treat them like the objects they're
1012   // wrapping.
1013   tensorflow::Safe_PyObjectPtr o_wrapped;
1014   if (IsObjectProxy(o)) {
1015     o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1016     o = o_wrapped.get();
1017   }
1018 
1019   // Must be subclass of tuple
1020   if (!PyTuple_Check(o)) {
1021     Py_RETURN_FALSE;
1022   }
1023 
1024   // If strict, o.__class__.__base__ must be tuple
1025   if (strict) {
1026     PyObject* klass = PyObject_GetAttrString(o, "__class__");
1027     if (klass == nullptr) return nullptr;
1028     PyObject* base = PyObject_GetAttrString(klass, "__base__");
1029     Py_DECREF(klass);
1030     if (base == nullptr) return nullptr;
1031 
1032     const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1033     // built-in object types are singletons
1034     bool tuple_base = base_type == &PyTuple_Type;
1035     Py_DECREF(base);
1036     if (!tuple_base) {
1037       Py_RETURN_FALSE;
1038     }
1039   }
1040 
1041   // o must have attribute '_fields' and every element in
1042   // '_fields' must be a string.
1043   int has_fields = PyObject_HasAttrString(o, "_fields");
1044   if (!has_fields) {
1045     Py_RETURN_FALSE;
1046   }
1047 
1048   Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1049   int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1050   if (is_instance == 0) {
1051     Py_RETURN_FALSE;
1052   } else if (is_instance == -1) {
1053     return nullptr;
1054   }
1055 
1056   Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1057   const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1058   for (Py_ssize_t i = 0; i < s; ++i) {
1059     // PySequence_Fast_GET_ITEM returns borrowed ref
1060     PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1061     if (!IsString(elem)) {
1062       Py_RETURN_FALSE;
1063     }
1064   }
1065 
1066   Py_RETURN_TRUE;
1067 }
1068 
SameNamedtuples(PyObject * o1,PyObject * o2)1069 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1070   Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1071   Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1072   if (f1 == nullptr || f2 == nullptr) {
1073     PyErr_SetString(
1074         PyExc_RuntimeError,
1075         "Expected namedtuple-like objects (that have _fields attr)");
1076     return nullptr;
1077   }
1078 
1079   if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1080     Py_RETURN_FALSE;
1081   }
1082 
1083   if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1084     Py_RETURN_TRUE;
1085   } else {
1086     Py_RETURN_FALSE;
1087   }
1088 }
1089 
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)1090 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1091                               bool expand_composites) {
1092   const std::function<int(PyObject*)>& is_sequence_helper =
1093       expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
1094   const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1095       expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1096   const bool check_composite_tensor_type_spec = expand_composites;
1097   string error_msg;
1098   bool is_type_error = false;
1099   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1100                             is_sequence_helper, get_value_iterator,
1101                             check_composite_tensor_type_spec);
1102   if (PyErr_Occurred()) {
1103     // Don't hide Python exceptions while checking (e.g. errors fetching keys
1104     // from custom mappings).
1105     return nullptr;
1106   }
1107   if (!error_msg.empty()) {
1108     PyErr_SetString(
1109         is_type_error ? PyExc_TypeError : PyExc_ValueError,
1110         tensorflow::strings::StrCat(
1111             "The two structures don't have the same nested structure.\n\n",
1112             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1113             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1114             .c_str());
1115     return nullptr;
1116   }
1117   Py_RETURN_NONE;
1118 }
1119 
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)1120 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1121                                      bool check_types) {
1122   string error_msg;
1123   bool is_type_error = false;
1124   AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1125                             IsSequenceForDataHelper, GetValueIterator, false);
1126   if (PyErr_Occurred()) {
1127     // Don't hide Python exceptions while checking (e.g. errors fetching keys
1128     // from custom mappings).
1129     return nullptr;
1130   }
1131   if (!error_msg.empty()) {
1132     PyErr_SetString(
1133         is_type_error ? PyExc_TypeError : PyExc_ValueError,
1134         tensorflow::strings::StrCat(
1135             "The two structures don't have the same nested structure.\n\n",
1136             "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1137             PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1138             .c_str());
1139     return nullptr;
1140   }
1141   Py_RETURN_NONE;
1142 }
1143 
1144 }  // namespace swig
1145 }  // namespace tensorflow
1146