1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28
29 namespace tensorflow {
30 namespace swig {
31
32 namespace {
33 string PyObjectToString(PyObject* o);
34 } // namespace
35
RegisteredPyObjectMap()36 std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37 static auto* m = new std::unordered_map<string, PyObject*>();
38 return m;
39 }
40
GetRegisteredPyObject(const string & name)41 PyObject* GetRegisteredPyObject(const string& name) {
42 const auto* m = RegisteredPyObjectMap();
43 auto it = m->find(name);
44 if (it == m->end()) {
45 PyErr_SetString(PyExc_TypeError,
46 tensorflow::strings::StrCat("No object with name ", name,
47 " has been registered.")
48 .c_str());
49 return nullptr;
50 }
51 return it->second;
52 }
53
RegisterType(PyObject * type_name,PyObject * type)54 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55 if (!PyType_Check(type)) {
56 PyErr_SetString(PyExc_TypeError,
57 tensorflow::strings::StrCat("Expecting a type, got ",
58 Py_TYPE(type)->tp_name)
59 .c_str());
60 return nullptr;
61 }
62 return RegisterPyObject(type_name, type);
63 }
64
RegisterPyObject(PyObject * name,PyObject * value)65 PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66 string key;
67 if (PyBytes_Check(name)) {
68 key = PyBytes_AsString(name);
69 #if PY_MAJOR_VERSION >= 3
70 } else if (PyUnicode_Check(name)) {
71 key = PyUnicode_AsUTF8(name);
72 #endif
73 } else {
74 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75 "Expected name to be a str, got",
76 PyObjectToString(name))
77 .c_str());
78 return nullptr;
79 }
80
81 auto* m = RegisteredPyObjectMap();
82 if (m->find(key) != m->end()) {
83 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84 "Value already registered for ", key)
85 .c_str());
86 return nullptr;
87 }
88
89 Py_INCREF(value);
90 m->emplace(key, value);
91
92 Py_RETURN_NONE;
93 }
94
95 namespace {
// Upper bound on the number of types each CachedTypeCheck memoizes.
constexpr int kMaxItemsInCache = 1024;

// Set after the "sets are not sequences" warning has been logged, so the
// warning is emitted only once.
bool WarnedThatSetIsNotSequence = false;
99
IsString(PyObject * o)100 bool IsString(PyObject* o) {
101 return PyBytes_Check(o) ||
102 #if PY_MAJOR_VERSION < 3
103 PyString_Check(o) ||
104 #endif
105 PyUnicode_Check(o);
106 }
107
108 // Equivalent to Python's 'o.__class__.__name__'
109 // Note that '__class__' attribute is set only in new-style classes.
110 // A lot of tensorflow code uses __class__ without checks, so it seems like
111 // we only support new-style classes.
GetClassName(PyObject * o)112 StringPiece GetClassName(PyObject* o) {
113 // __class__ is equivalent to type() for new style classes.
114 // type() is equivalent to PyObject_Type()
115 // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
116 // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
117 // we don't need here.
118 PyTypeObject* type = o->ob_type;
119
120 // __name__ is the value of `tp_name` after the last '.'
121 // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
122 StringPiece name(type->tp_name);
123 size_t pos = name.rfind('.');
124 if (pos != StringPiece::npos) {
125 name.remove_prefix(pos + 1);
126 }
127 return name;
128 }
129
PyObjectToString(PyObject * o)130 string PyObjectToString(PyObject* o) {
131 if (o == nullptr) {
132 return "<null object>";
133 }
134 PyObject* str = PyObject_Str(o);
135 if (str) {
136 #if PY_MAJOR_VERSION < 3
137 string s(PyString_AS_STRING(str));
138 #else
139 string s(PyUnicode_AsUTF8(str));
140 #endif
141 Py_DECREF(str);
142 return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
143 } else {
144 return "<failed to execute str() on object>";
145 }
146 }
147
148 class CachedTypeCheck {
149 public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)150 explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
151 : ternary_predicate_(std::move(ternary_predicate)) {}
152
~CachedTypeCheck()153 ~CachedTypeCheck() {
154 mutex_lock l(type_to_sequence_map_mu_);
155 for (const auto& pair : type_to_sequence_map_) {
156 Py_DECREF(pair.first);
157 }
158 }
159
160 // Caches successful executions of the one-argument (PyObject*) callable
161 // "ternary_predicate" based on the type of "o". -1 from the callable
162 // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
163 // does not match the predicate, and 1 indicates that it does. Used to avoid
164 // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)165 int CachedLookup(PyObject* o) {
166 // Try not to return to Python - see if the type has already been seen
167 // before.
168
169 auto* type = Py_TYPE(o);
170
171 {
172 tf_shared_lock l(type_to_sequence_map_mu_);
173 auto it = type_to_sequence_map_.find(type);
174 if (it != type_to_sequence_map_.end()) {
175 return it->second;
176 }
177 }
178
179 int check_result = ternary_predicate_(o);
180
181 if (check_result == -1) {
182 return -1; // Type check error, not cached.
183 }
184
185 // NOTE: This is never decref'd as long as the object lives, which is likely
186 // forever, but we don't want the type to get deleted as long as it is in
187 // the map. This should not be too much of a leak, as there should only be a
188 // relatively small number of types in the map, and an even smaller number
189 // that are eligible for decref. As a precaution, we limit the size of the
190 // map to 1024.
191 {
192 mutex_lock l(type_to_sequence_map_mu_);
193 if (type_to_sequence_map_.size() < kMaxItemsInCache) {
194 Py_INCREF(type);
195 auto insert_result = type_to_sequence_map_.insert({type, check_result});
196 if (!insert_result.second) {
197 // The type was added to the cache by a concurrent thread after we
198 // looked it up above.
199 Py_DECREF(type);
200 }
201 }
202 }
203
204 return check_result;
205 }
206
207 private:
208 std::function<int(PyObject*)> ternary_predicate_;
209 mutex type_to_sequence_map_mu_;
210 std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
211 TF_GUARDED_BY(type_to_sequence_map_mu_);
212 };
213
214 // Returns 1 if 'obj' is an instance of 'type_name'
215 // Returns 0 otherwise.
216 // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
IsInstanceOfRegisteredType(PyObject * obj,const char * type_name)217 int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
218 PyObject* type_obj = GetRegisteredPyObject(type_name);
219 if (TF_PREDICT_FALSE(type_obj == nullptr)) {
220 PyErr_SetString(PyExc_RuntimeError,
221 tensorflow::strings::StrCat(
222 type_name,
223 " type has not been set. "
224 "Please register the type with the identifier \"",
225 type_name, "\" using RegisterType.")
226 .c_str());
227 return -1;
228 }
229 return PyObject_IsInstance(obj, type_obj);
230 }
231
232 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
233 // Returns 0 otherwise.
234 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)235 int IsMappingHelper(PyObject* o) {
236 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
237 return IsInstanceOfRegisteredType(to_check, "Mapping");
238 });
239 if (PyDict_Check(o)) return true;
240 return check_cache->CachedLookup(o);
241 }
242
243 // Returns 1 if `o` is considered a mutable mapping for the purposes of
244 // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
IsMutableMappingHelper(PyObject * o)245 int IsMutableMappingHelper(PyObject* o) {
246 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
247 return IsInstanceOfRegisteredType(to_check, "MutableMapping");
248 });
249 if (PyDict_Check(o)) return true;
250 return check_cache->CachedLookup(o);
251 }
252
253 // Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
254 // Returns 0 otherwise.
255 // Returns -1 if an error occurred.
IsMappingViewHelper(PyObject * o)256 int IsMappingViewHelper(PyObject* o) {
257 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
258 return IsInstanceOfRegisteredType(to_check, "MappingView");
259 });
260 return check_cache->CachedLookup(o);
261 }
262
263 // Returns 1 if `o` is considered an object proxy
264 // Returns 0 otherwise.
265 // Returns -1 if an error occurred.
IsObjectProxy(PyObject * o)266 int IsObjectProxy(PyObject* o) {
267 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
268 return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
269 });
270 return check_cache->CachedLookup(o);
271 }
272
273 // Returns 1 if `o` is an instance of attrs-decorated class.
274 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)275 int IsAttrsHelper(PyObject* o) {
276 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
277 Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
278 if (cls) {
279 return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
280 }
281
282 // PyObject_GetAttrString returns null on error
283 PyErr_Clear();
284 return 0;
285 });
286 return check_cache->CachedLookup(o);
287 }
288
289 // Returns 1 if `o` is an object of type IndexedSlices.
290 // Returns 0 otherwise.
291 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)292 int IsIndexedSlicesHelper(PyObject* o) {
293 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
294 return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
295 });
296 return check_cache->CachedLookup(o);
297 }
298
299 // Returns 1 if `o` is a Tensor.
300 // Returns 0 otherwise.
301 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)302 int IsTensorHelper(PyObject* o) {
303 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
304 return IsInstanceOfRegisteredType(to_check, "Tensor");
305 });
306 return check_cache->CachedLookup(o);
307 }
308
309 // Returns 1 if `o` is an EagerTensor.
310 // Returns 0 otherwise.
311 // Returns -1 if an error occurred.
IsEagerTensorHelper(PyObject * o)312 int IsEagerTensorHelper(PyObject* o) {
313 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
314 return IsInstanceOfRegisteredType(to_check, "EagerTensor");
315 });
316 return check_cache->CachedLookup(o);
317 }
318
319 // Returns 1 if `o` is a ResourceVariable.
320 // Returns 0 otherwise.
321 // Returns -1 if an error occurred.
IsResourceVariableHelper(PyObject * o)322 int IsResourceVariableHelper(PyObject* o) {
323 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
324 return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
325 });
326 return check_cache->CachedLookup(o);
327 }
328
329 // Returns 1 if `o` is a OwnedIterator.
330 // Returns 0 otherwise.
331 // Returns -1 if an error occurred.
IsOwnedIteratorHelper(PyObject * o)332 int IsOwnedIteratorHelper(PyObject* o) {
333 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
334 return IsInstanceOfRegisteredType(to_check, "OwnedIterator");
335 });
336 return check_cache->CachedLookup(o);
337 }
338
339 // Returns 1 if `o` is a ResourceVariable.
340 // Returns 0 otherwise.
341 // Returns -1 if an error occurred.
IsVariableHelper(PyObject * o)342 int IsVariableHelper(PyObject* o) {
343 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
344 return IsInstanceOfRegisteredType(to_check, "Variable");
345 });
346 return check_cache->CachedLookup(o);
347 }
348
349 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
350 // Returns 0 otherwise.
351 // Returns -1 if an error occurred.
IsSequenceHelper(PyObject * o)352 int IsSequenceHelper(PyObject* o) {
353 // We treat dicts and other mappings as special cases of sequences.
354 if (IsMappingHelper(o)) return true;
355 if (IsMappingViewHelper(o)) return true;
356 if (IsAttrsHelper(o)) return true;
357 if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
358 LOG(WARNING) << "Sets are not currently considered sequences, "
359 "but this may change in the future, "
360 "so consider avoiding using them.";
361 WarnedThatSetIsNotSequence = true;
362 }
363 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
364 int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
365
366 // Don't cache a failed is_instance check.
367 if (is_instance == -1) return -1;
368
369 return static_cast<int>(is_instance != 0 && !IsString(to_check));
370 });
371 return check_cache->CachedLookup(o);
372 }
373
374 // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
375 // Returns 0 otherwise.
IsDispatchableHelper(PyObject * o)376 int IsDispatchableHelper(PyObject* o) {
377 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
378 return PyObject_HasAttrString(
379 reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
380 });
381 return check_cache->CachedLookup(o);
382 }
383
384 // ValueIterator interface
385 class ValueIterator {
386 public:
~ValueIterator()387 virtual ~ValueIterator() {}
388 virtual Safe_PyObjectPtr next() = 0;
389
valid() const390 bool valid() const { return is_valid_; }
391
392 protected:
invalidate()393 void invalidate() { is_valid_ = false; }
394
395 private:
396 bool is_valid_ = true;
397 };
398
399 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
400
401 // Iterate through dictionaries in a deterministic order by sorting the
402 // keys. Notice this means that we ignore the original order of
403 // `OrderedDict` instances. This is intentional, to avoid potential
404 // bugs caused by mixing ordered and plain dicts (e.g., flattening
405 // a dict but using a corresponding `OrderedDict` to pack it back).
406 class DictValueIterator : public ValueIterator {
407 public:
DictValueIterator(PyObject * dict)408 explicit DictValueIterator(PyObject* dict)
409 : dict_(dict), keys_(PyDict_Keys(dict)) {
410 if (PyList_Sort(keys_.get()) == -1) {
411 invalidate();
412 } else {
413 iter_.reset(PyObject_GetIter(keys_.get()));
414 }
415 }
416
next()417 Safe_PyObjectPtr next() override {
418 Safe_PyObjectPtr result;
419 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
420 if (key) {
421 // PyDict_GetItem returns a borrowed reference.
422 PyObject* elem = PyDict_GetItem(dict_, key.get());
423 if (elem) {
424 Py_INCREF(elem);
425 result.reset(elem);
426 } else {
427 PyErr_SetString(PyExc_RuntimeError,
428 "Dictionary was modified during iteration over it");
429 }
430 }
431 return result;
432 }
433
434 private:
435 PyObject* dict_;
436 Safe_PyObjectPtr keys_;
437 Safe_PyObjectPtr iter_;
438 };
439
440 // Iterate over mapping objects by sorting the keys first
441 class MappingValueIterator : public ValueIterator {
442 public:
MappingValueIterator(PyObject * mapping)443 explicit MappingValueIterator(PyObject* mapping)
444 : mapping_(mapping), keys_(MappingKeys(mapping)) {
445 if (!keys_ || PyList_Sort(keys_.get()) == -1) {
446 invalidate();
447 } else {
448 iter_.reset(PyObject_GetIter(keys_.get()));
449 }
450 }
451
next()452 Safe_PyObjectPtr next() override {
453 Safe_PyObjectPtr result;
454 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
455 if (key) {
456 // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
457 PyObject* elem = PyObject_GetItem(mapping_, key.get());
458 if (elem) {
459 result.reset(elem);
460 } else {
461 PyErr_SetString(PyExc_RuntimeError,
462 "Mapping was modified during iteration over it");
463 }
464 }
465 return result;
466 }
467
468 private:
469 PyObject* mapping_;
470 Safe_PyObjectPtr keys_;
471 Safe_PyObjectPtr iter_;
472 };
473
474 // Iterate over a sequence, by index.
475 class SequenceValueIterator : public ValueIterator {
476 public:
SequenceValueIterator(PyObject * iterable)477 explicit SequenceValueIterator(PyObject* iterable)
478 : seq_(PySequence_Fast(iterable, "")),
479 size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
480 index_(0) {}
481
next()482 Safe_PyObjectPtr next() override {
483 Safe_PyObjectPtr result;
484 if (index_ < size_) {
485 // PySequence_Fast_GET_ITEM returns a borrowed reference.
486 PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
487 ++index_;
488 if (elem) {
489 Py_INCREF(elem);
490 result.reset(elem);
491 }
492 }
493
494 return result;
495 }
496
497 private:
498 Safe_PyObjectPtr seq_;
499 const Py_ssize_t size_;
500 Py_ssize_t index_;
501 };
502
503 // Iterator that just returns a single python object.
504 class SingleValueIterator : public ValueIterator {
505 public:
SingleValueIterator(PyObject * x)506 explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
507
next()508 Safe_PyObjectPtr next() override { return std::move(x_); }
509
510 private:
511 Safe_PyObjectPtr x_;
512 };
513
514 // Returns nullptr (to raise an exception) when next() is called. Caller
515 // should have already called PyErr_SetString.
516 class ErrorValueIterator : public ValueIterator {
517 public:
ErrorValueIterator()518 ErrorValueIterator() {}
next()519 Safe_PyObjectPtr next() override { return nullptr; }
520 };
521
522 class AttrsValueIterator : public ValueIterator {
523 public:
AttrsValueIterator(PyObject * nested)524 explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
525 Py_INCREF(nested);
526 cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
527 if (cls_) {
528 attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
529 if (attrs_) {
530 iter_.reset(PyObject_GetIter(attrs_.get()));
531 }
532 }
533 if (!iter_ || PyErr_Occurred()) invalidate();
534 }
535
next()536 Safe_PyObjectPtr next() override {
537 Safe_PyObjectPtr result;
538 Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
539 if (item) {
540 Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
541 result.reset(PyObject_GetAttr(nested_.get(), name.get()));
542 }
543
544 return result;
545 }
546
547 private:
548 Safe_PyObjectPtr nested_;
549 Safe_PyObjectPtr cls_;
550 Safe_PyObjectPtr attrs_;
551 Safe_PyObjectPtr iter_;
552 };
553
IsSparseTensorValueType(PyObject * o)554 bool IsSparseTensorValueType(PyObject* o) {
555 PyObject* sparse_tensor_value_type =
556 GetRegisteredPyObject("SparseTensorValue");
557 if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
558 return false;
559 }
560
561 return PyObject_TypeCheck(
562 o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
563 }
564
565 // Returns 1 if `o` is an instance of CompositeTensor.
566 // Returns 0 otherwise.
567 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)568 bool IsCompositeTensorHelper(PyObject* o) {
569 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
570 return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
571 });
572 return check_cache->CachedLookup(o);
573 }
574
575 // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
576 // VariableSpec.
577 // Returns 0 otherwise.
578 // Returns -1 if an error occurred.
IsTypeSpecHelper(PyObject * o)579 bool IsTypeSpecHelper(PyObject* o) {
580 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
581 int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
582 int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
583 IsInstanceOfRegisteredType(to_check, "VariableSpec"));
584 if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
585 return static_cast<int>(is_type_spec && !is_dense_spec);
586 });
587 return check_cache->CachedLookup(o);
588 }
589
590 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
591 // (non-TensorSpec and non-VariableSpec) TypeSpec.
592 // Returns 0 otherwise.
593 // Returns -1 if an error occurred.
IsSequenceOrCompositeHelper(PyObject * o)594 int IsSequenceOrCompositeHelper(PyObject* o) {
595 int is_sequence = IsSequenceHelper(o);
596 int is_composite = IsCompositeTensorHelper(o);
597 int is_type_spec = IsTypeSpecHelper(o);
598 if ((is_sequence == -1) || (is_composite == -1) || (is_type_spec == -1)) {
599 return -1;
600 }
601 return is_sequence || is_composite || is_type_spec;
602 }
603
IsSequenceForDataHelper(PyObject * o)604 int IsSequenceForDataHelper(PyObject* o) {
605 return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
606 !IsSparseTensorValueType(o);
607 }
608
GetValueIterator(PyObject * nested)609 ValueIteratorPtr GetValueIterator(PyObject* nested) {
610 if (PyDict_Check(nested)) {
611 return absl::make_unique<DictValueIterator>(nested);
612 } else if (IsMappingHelper(nested)) {
613 return absl::make_unique<MappingValueIterator>(nested);
614 } else if (IsAttrsHelper(nested)) {
615 return absl::make_unique<AttrsValueIterator>(nested);
616 } else {
617 return absl::make_unique<SequenceValueIterator>(nested);
618 }
619 }
620
621 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)622 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
623 if (PyDict_Check(nested)) {
624 return absl::make_unique<DictValueIterator>(nested);
625 } else if (IsMappingHelper(nested)) {
626 return absl::make_unique<MappingValueIterator>(nested);
627 } else if (IsAttrsHelper(nested)) {
628 return absl::make_unique<AttrsValueIterator>(nested);
629 } else if (IsSparseTensorValueType(nested)) {
630 return absl::make_unique<SingleValueIterator>(nested);
631 } else {
632 return absl::make_unique<SequenceValueIterator>(nested);
633 }
634 }
635
636 // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
GetValueIteratorForComposite(PyObject * nested)637 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
638 if (IsCompositeTensor(nested)) {
639 Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
640 if (PyErr_Occurred() || !spec) {
641 return absl::make_unique<ErrorValueIterator>();
642 }
643
644 static char to_components[] = "_to_components";
645 static char argspec[] = "(O)";
646 Safe_PyObjectPtr components(
647 PyObject_CallMethod(spec.get(), to_components, argspec, nested));
648 if (PyErr_Occurred() || components == nullptr) {
649 return absl::make_unique<ErrorValueIterator>();
650 }
651 return absl::make_unique<SingleValueIterator>(components.get());
652 }
653
654 if (IsTypeSpec(nested)) {
655 Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
656 if (PyErr_Occurred() || specs == nullptr) {
657 return absl::make_unique<ErrorValueIterator>();
658 }
659 return absl::make_unique<SingleValueIterator>(specs.get());
660 }
661
662 return GetValueIterator(nested);
663 }
664
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)665 bool FlattenHelper(
666 PyObject* nested, PyObject* list,
667 const std::function<int(PyObject*)>& is_sequence_helper,
668 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
669 // if nested is not a sequence, append itself and exit
670 int is_seq = is_sequence_helper(nested);
671 if (is_seq == -1) return false;
672 if (!is_seq) {
673 return PyList_Append(list, nested) != -1;
674 }
675
676 ValueIteratorPtr iter = value_iterator_getter(nested);
677 if (!iter->valid()) return false;
678
679 for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
680 if (Py_EnterRecursiveCall(" in flatten")) {
681 return false;
682 }
683 const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
684 value_iterator_getter);
685 Py_LeaveRecursiveCall();
686 if (!success) {
687 return false;
688 }
689 }
690 return true;
691 }
692
693 // Sets error using keys of 'dict1' and 'dict2'.
694 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)695 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
696 bool* is_type_error) {
697 Safe_PyObjectPtr k1(MappingKeys(dict1));
698 if (PyErr_Occurred() || k1.get() == nullptr) {
699 *error_msg =
700 ("The two dictionaries don't have the same set of keys. Failed to "
701 "fetch keys.");
702 return;
703 }
704 Safe_PyObjectPtr k2(MappingKeys(dict2));
705 if (PyErr_Occurred() || k2.get() == nullptr) {
706 *error_msg =
707 ("The two dictionaries don't have the same set of keys. Failed to "
708 "fetch keys.");
709 return;
710 }
711 *is_type_error = false;
712 *error_msg = tensorflow::strings::StrCat(
713 "The two dictionaries don't have the same set of keys. "
714 "First structure has keys ",
715 PyObjectToString(k1.get()), ", while second structure has keys ",
716 PyObjectToString(k2.get()));
717 }
718
719 // Returns true iff there were no "internal" errors. In other words,
720 // errors that has nothing to do with structure checking.
721 // If an "internal" error occurred, the appropriate Python error will be
722 // set and the caller can propage it directly to the user.
723 //
724 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
725 // be empty.
726 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
727 // with appropriate error and sets `is_type_error` to true iff
728 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_sequence_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter,bool check_composite_tensor_type_spec)729 bool AssertSameStructureHelper(
730 PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
731 bool* is_type_error,
732 const std::function<int(PyObject*)>& is_sequence_helper,
733 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
734 bool check_composite_tensor_type_spec) {
735 DCHECK(error_msg);
736 DCHECK(is_type_error);
737 const bool is_seq1 = is_sequence_helper(o1);
738 const bool is_seq2 = is_sequence_helper(o2);
739 if (PyErr_Occurred()) return false;
740 if (is_seq1 != is_seq2) {
741 string seq_str = is_seq1 ? PyObjectToString(o1) : PyObjectToString(o2);
742 string non_seq_str = is_seq1 ? PyObjectToString(o2) : PyObjectToString(o1);
743 *is_type_error = false;
744 *error_msg = tensorflow::strings::StrCat(
745 "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
746 non_seq_str, "\" is not");
747 return true;
748 }
749
750 // Got to objects that are considered non-sequences. Note that in tf.data
751 // use case lists and sparse_tensors are not considered sequences. So finished
752 // checking, structures are the same.
753 if (!is_seq1) return true;
754
755 if (check_types) {
756 // Treat wrapped tuples as tuples.
757 tensorflow::Safe_PyObjectPtr o1_wrapped;
758 if (IsObjectProxy(o1)) {
759 o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
760 o1 = o1_wrapped.get();
761 }
762 tensorflow::Safe_PyObjectPtr o2_wrapped;
763 if (IsObjectProxy(o2)) {
764 o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
765 o2 = o2_wrapped.get();
766 }
767
768 const PyTypeObject* type1 = o1->ob_type;
769 const PyTypeObject* type2 = o2->ob_type;
770
771 // We treat two different namedtuples with identical name and fields
772 // as having the same type.
773 const PyObject* o1_tuple = IsNamedtuple(o1, false);
774 if (o1_tuple == nullptr) return false;
775 const PyObject* o2_tuple = IsNamedtuple(o2, false);
776 if (o2_tuple == nullptr) {
777 Py_DECREF(o1_tuple);
778 return false;
779 }
780 bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
781 Py_DECREF(o1_tuple);
782 Py_DECREF(o2_tuple);
783
784 if (both_tuples) {
785 const PyObject* same_tuples = SameNamedtuples(o1, o2);
786 if (same_tuples == nullptr) return false;
787 bool not_same_tuples = same_tuples != Py_True;
788 Py_DECREF(same_tuples);
789 if (not_same_tuples) {
790 *is_type_error = true;
791 *error_msg = tensorflow::strings::StrCat(
792 "The two namedtuples don't have the same sequence type. "
793 "First structure ",
794 PyObjectToString(o1), " has type ", type1->tp_name,
795 ", while second structure ", PyObjectToString(o2), " has type ",
796 type2->tp_name);
797 return true;
798 }
799 } else if (type1 != type2
800 /* If both sequences are list types, don't complain. This allows
801 one to be a list subclass (e.g. _ListWrapper used for
802 automatic dependency tracking.) */
803 && !(PyList_Check(o1) && PyList_Check(o2))
804 /* Two mapping types will also compare equal, making _DictWrapper
805 and dict compare equal. */
806 && !(IsMappingHelper(o1) && IsMappingHelper(o2))
807 /* For CompositeTensor & TypeSpec, we check below. */
808 && !(check_composite_tensor_type_spec &&
809 (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
810 (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
811 *is_type_error = true;
812 *error_msg = tensorflow::strings::StrCat(
813 "The two namedtuples don't have the same sequence type. "
814 "First structure ",
815 PyObjectToString(o1), " has type ", type1->tp_name,
816 ", while second structure ", PyObjectToString(o2), " has type ",
817 type2->tp_name);
818 return true;
819 }
820
821 if (PyDict_Check(o1) && PyDict_Check(o2)) {
822 if (PyDict_Size(o1) != PyDict_Size(o2)) {
823 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
824 return true;
825 }
826
827 PyObject* key;
828 Py_ssize_t pos = 0;
829 while (PyDict_Next(o1, &pos, &key, nullptr)) {
830 if (PyDict_GetItem(o2, key) == nullptr) {
831 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
832 return true;
833 }
834 }
835 } else if (IsMappingHelper(o1)) {
836 // Fallback for custom mapping types. Instead of using PyDict methods
837 // which stay in C, we call iter(o1).
838 if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
839 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
840 return true;
841 }
842
843 Safe_PyObjectPtr iter(PyObject_GetIter(o1));
844 PyObject* key;
845 while ((key = PyIter_Next(iter.get())) != nullptr) {
846 if (!PyMapping_HasKey(o2, key)) {
847 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
848 Py_DECREF(key);
849 return true;
850 }
851 Py_DECREF(key);
852 }
853 }
854 }
855
856 if (check_composite_tensor_type_spec &&
857 (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
858 Safe_PyObjectPtr owned_type_spec_1;
859 PyObject* type_spec_1 = o1;
860 if (IsCompositeTensor(o1)) {
861 owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
862 type_spec_1 = owned_type_spec_1.get();
863 }
864
865 Safe_PyObjectPtr owned_type_spec_2;
866 PyObject* type_spec_2 = o2;
867 if (IsCompositeTensor(o2)) {
868 owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
869 type_spec_2 = owned_type_spec_2.get();
870 }
871
872 // Two composite tensors are considered to have the same structure if
873 // there is some type spec that is compatible with both of them. Thus,
874 // we use most_specific_compatible_type(), and check if it raises an
875 // exception. We do *not* use is_compatible_with, since that would
876 // prevent us from e.g. using a cond statement where the two sides have
877 // different shapes.
878 static char compatible_type[] = "most_specific_compatible_type";
879 static char argspec[] = "(O)";
880 Safe_PyObjectPtr struct_compatible(PyObject_CallMethod(
881 type_spec_1, compatible_type, argspec, type_spec_2));
882 if (PyErr_Occurred() || struct_compatible == nullptr) {
883 PyErr_Clear();
884 *is_type_error = false;
885 *error_msg = tensorflow::strings::StrCat(
886 "Incompatible CompositeTensor TypeSpecs: ",
887 PyObjectToString(type_spec_1), " vs. ",
888 PyObjectToString(type_spec_2));
889 return true;
890 }
891 }
892
893 ValueIteratorPtr iter1 = value_iterator_getter(o1);
894 ValueIteratorPtr iter2 = value_iterator_getter(o2);
895
896 if (!iter1->valid() || !iter2->valid()) return false;
897
898 while (true) {
899 Safe_PyObjectPtr v1 = iter1->next();
900 Safe_PyObjectPtr v2 = iter2->next();
901 if (v1 && v2) {
902 if (Py_EnterRecursiveCall(" in assert_same_structure")) {
903 return false;
904 }
905 bool no_internal_errors = AssertSameStructureHelper(
906 v1.get(), v2.get(), check_types, error_msg, is_type_error,
907 is_sequence_helper, value_iterator_getter,
908 check_composite_tensor_type_spec);
909 Py_LeaveRecursiveCall();
910 if (!no_internal_errors) return false;
911 if (!error_msg->empty()) return true;
912 } else if (!v1 && !v2) {
913 // Done with all recursive calls. Structure matched.
914 return true;
915 } else {
916 *is_type_error = false;
917 *error_msg = tensorflow::strings::StrCat(
918 "The two structures don't have the same number of elements. ",
919 "First structure: ", PyObjectToString(o1),
920 ". Second structure: ", PyObjectToString(o2));
921 return true;
922 }
923 }
924 }
925
926 } // namespace
927
IsSequence(PyObject * o)928 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
IsMapping(PyObject * o)929 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsMutableMapping(PyObject * o)930 bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
IsMappingView(PyObject * o)931 bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
IsAttrs(PyObject * o)932 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)933 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsEagerTensorSlow(PyObject * o)934 bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
IsResourceVariable(PyObject * o)935 bool IsResourceVariable(PyObject* o) {
936 return IsResourceVariableHelper(o) == 1;
937 }
IsOwnedIterator(PyObject * o)938 bool IsOwnedIterator(PyObject* o) { return IsOwnedIteratorHelper(o) == 1; }
IsVariable(PyObject * o)939 bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
IsIndexedSlices(PyObject * o)940 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
IsDispatchable(PyObject * o)941 bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
942
IsTuple(PyObject * o)943 bool IsTuple(PyObject* o) {
944 tensorflow::Safe_PyObjectPtr wrapped;
945 if (IsObjectProxy(o)) {
946 wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
947 o = wrapped.get();
948 }
949 return PyTuple_Check(o);
950 }
951
952 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
953 // and while we're at it give them consistent behavior by making sure the
954 // returned value is a list.
955 //
956 // As with PyMapping_Keys, returns a new reference.
957 //
958 // On failure, returns nullptr.
MappingKeys(PyObject * o)959 PyObject* MappingKeys(PyObject* o) {
960 #if PY_MAJOR_VERSION >= 3
961 return PyMapping_Keys(o);
962 #else
963 static char key_method_name[] = "keys";
964 Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
965 if (PyErr_Occurred() || raw_result.get() == nullptr) {
966 return nullptr;
967 }
968 return PySequence_Fast(
969 raw_result.get(),
970 "The '.keys()' method of a custom mapping returned a non-sequence.");
971 #endif
972 }
973
Flatten(PyObject * nested,bool expand_composites)974 PyObject* Flatten(PyObject* nested, bool expand_composites) {
975 PyObject* list = PyList_New(0);
976 const std::function<int(PyObject*)>& is_sequence_helper =
977 expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
978 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
979 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
980 if (FlattenHelper(nested, list, is_sequence_helper, get_value_iterator)) {
981 return list;
982 } else {
983 Py_DECREF(list);
984 return nullptr;
985 }
986 }
987
IsSequenceOrComposite(PyObject * o)988 bool IsSequenceOrComposite(PyObject* o) {
989 return IsSequenceOrCompositeHelper(o) == 1;
990 }
991
IsCompositeTensor(PyObject * o)992 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
993
IsTypeSpec(PyObject * o)994 bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
995
IsSequenceForData(PyObject * o)996 bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
997
FlattenForData(PyObject * nested)998 PyObject* FlattenForData(PyObject* nested) {
999 PyObject* list = PyList_New(0);
1000 if (FlattenHelper(nested, list, IsSequenceForDataHelper,
1001 GetValueIteratorForData)) {
1002 return list;
1003 } else {
1004 Py_DECREF(list);
1005 return nullptr;
1006 }
1007 }
1008
IsNamedtuple(PyObject * o,bool strict)1009 PyObject* IsNamedtuple(PyObject* o, bool strict) {
1010 // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1011 // require some unwrapping if we want to treat them like the objects they're
1012 // wrapping.
1013 tensorflow::Safe_PyObjectPtr o_wrapped;
1014 if (IsObjectProxy(o)) {
1015 o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1016 o = o_wrapped.get();
1017 }
1018
1019 // Must be subclass of tuple
1020 if (!PyTuple_Check(o)) {
1021 Py_RETURN_FALSE;
1022 }
1023
1024 // If strict, o.__class__.__base__ must be tuple
1025 if (strict) {
1026 PyObject* klass = PyObject_GetAttrString(o, "__class__");
1027 if (klass == nullptr) return nullptr;
1028 PyObject* base = PyObject_GetAttrString(klass, "__base__");
1029 Py_DECREF(klass);
1030 if (base == nullptr) return nullptr;
1031
1032 const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1033 // built-in object types are singletons
1034 bool tuple_base = base_type == &PyTuple_Type;
1035 Py_DECREF(base);
1036 if (!tuple_base) {
1037 Py_RETURN_FALSE;
1038 }
1039 }
1040
1041 // o must have attribute '_fields' and every element in
1042 // '_fields' must be a string.
1043 int has_fields = PyObject_HasAttrString(o, "_fields");
1044 if (!has_fields) {
1045 Py_RETURN_FALSE;
1046 }
1047
1048 Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1049 int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1050 if (is_instance == 0) {
1051 Py_RETURN_FALSE;
1052 } else if (is_instance == -1) {
1053 return nullptr;
1054 }
1055
1056 Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1057 const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1058 for (Py_ssize_t i = 0; i < s; ++i) {
1059 // PySequence_Fast_GET_ITEM returns borrowed ref
1060 PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1061 if (!IsString(elem)) {
1062 Py_RETURN_FALSE;
1063 }
1064 }
1065
1066 Py_RETURN_TRUE;
1067 }
1068
SameNamedtuples(PyObject * o1,PyObject * o2)1069 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1070 Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1071 Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1072 if (f1 == nullptr || f2 == nullptr) {
1073 PyErr_SetString(
1074 PyExc_RuntimeError,
1075 "Expected namedtuple-like objects (that have _fields attr)");
1076 return nullptr;
1077 }
1078
1079 if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1080 Py_RETURN_FALSE;
1081 }
1082
1083 if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1084 Py_RETURN_TRUE;
1085 } else {
1086 Py_RETURN_FALSE;
1087 }
1088 }
1089
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)1090 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1091 bool expand_composites) {
1092 const std::function<int(PyObject*)>& is_sequence_helper =
1093 expand_composites ? IsSequenceOrCompositeHelper : IsSequenceHelper;
1094 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1095 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1096 const bool check_composite_tensor_type_spec = expand_composites;
1097 string error_msg;
1098 bool is_type_error = false;
1099 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1100 is_sequence_helper, get_value_iterator,
1101 check_composite_tensor_type_spec);
1102 if (PyErr_Occurred()) {
1103 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1104 // from custom mappings).
1105 return nullptr;
1106 }
1107 if (!error_msg.empty()) {
1108 PyErr_SetString(
1109 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1110 tensorflow::strings::StrCat(
1111 "The two structures don't have the same nested structure.\n\n",
1112 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1113 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1114 .c_str());
1115 return nullptr;
1116 }
1117 Py_RETURN_NONE;
1118 }
1119
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)1120 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1121 bool check_types) {
1122 string error_msg;
1123 bool is_type_error = false;
1124 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1125 IsSequenceForDataHelper, GetValueIterator, false);
1126 if (PyErr_Occurred()) {
1127 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1128 // from custom mappings).
1129 return nullptr;
1130 }
1131 if (!error_msg.empty()) {
1132 PyErr_SetString(
1133 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1134 tensorflow::strings::StrCat(
1135 "The two structures don't have the same nested structure.\n\n",
1136 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1137 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1138 .c_str());
1139 return nullptr;
1140 }
1141 Py_RETURN_NONE;
1142 }
1143
1144 } // namespace swig
1145 } // namespace tensorflow
1146