1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/nest.h"
16
17 #include <utility>
18
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/stringpiece.h"
22 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23 #include "tensorflow/python/util/util.h"
24
25 namespace tensorflow {
26
27 namespace {
28
29 // Gets a string representation of the input object.
30 //
31 // Args:
32 // o: a python object.
33 // length: If set to negative, the whole string is returned. Otherwise, the
34 // string gets clipped to 'length' in size.
35 //
36 // Returns:
37 // A string representation.
PyObject_ToString(PyObject * o,int length=-1)38 std::string PyObject_ToString(PyObject* o, int length = -1) {
39 auto str_o = make_safe(PyObject_Str(o));
40 std::string str = PyUnicode_AsUTF8(str_o.get());
41 if (length < 0 || str.size() <= length) {
42 return str;
43 }
44 tensorflow::StringPiece str_piece(str);
45 return tensorflow::strings::StrCat(str_piece.substr(length), "...");
46 }
47
48 // Gets a list of keys from a dict or mapping type object.
49 //
50 // Args:
51 // o: a dictionary or mapping type object.
52 //
53 // Returns:
54 // A new reference to a list.
55 //
56 // Raises:
57 // TypeError: if `o` is not a dict or mapping type object.
GetKeysFromDictOrMapping(PyObject * o)58 PyObject* GetKeysFromDictOrMapping(PyObject* o) {
59 if (PyDict_Check(o)) {
60 return PyDict_Keys(o);
61 } else if (PyMapping_Check(o)) {
62 return PyMapping_Keys(o);
63 } else {
64 auto* o_type = Py_TYPE(o);
65 PyErr_SetString(
66 PyExc_TypeError,
67 tensorflow::strings::StrCat(
68 "Expecting a type compatible with dict or mapping, got '",
69 o_type->tp_name, "'")
70 .c_str());
71 return nullptr;
72 }
73 }
74
75 } // namespace
76
FlattenDictItems(PyObject * dict)77 PyObject* FlattenDictItems(PyObject* dict) {
78 if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
79 PyErr_SetString(PyExc_TypeError,
80 tensorflow::strings::StrCat(
81 "FlattenDictItems: 'dict' must be a dictionary or ",
82 "collection.Mapping type object, instead of '",
83 Py_TYPE(dict)->tp_name, "'.")
84 .c_str());
85 return nullptr;
86 }
87 PyObject* flat_dictionary = PyDict_New();
88 auto keys = make_safe(GetKeysFromDictOrMapping(dict));
89 for (size_t i = 0; i < PyList_Size(keys.get()); ++i) {
90 auto* key = PyList_GetItem(keys.get(), i);
91 // We use a general approach in case 'dict' is a PyMapping type,
92 // but not a PyDict type.
93 auto* value = PyObject_GetItem(dict, key);
94 if (swig::IsSequence(key)) {
95 // The dict might contain list - list pairs.
96 auto flat_keys = make_safe(swig::Flatten(key, false));
97 auto flat_values = make_safe(swig::Flatten(value, false));
98 size_t flat_keys_sz = PyList_Size(flat_keys.get());
99 size_t flat_values_sz = PyList_Size(flat_values.get());
100 if (flat_keys_sz != flat_values_sz) {
101 PyErr_SetString(
102 PyExc_ValueError,
103 tensorflow::strings::StrCat(
104 "Could not flatten dictionary. Key had ", flat_keys_sz,
105 " elements, but value had ", flat_values_sz,
106 " elements. Key: ", PyObject_ToString(flat_keys.get()),
107 ", value: ", PyObject_ToString(flat_values.get()), ".")
108 .c_str());
109 Py_DecRef(flat_dictionary);
110 return nullptr;
111 }
112 for (size_t i = 0; i < flat_keys_sz; ++i) {
113 auto* flat_key = PyList_GetItem(flat_keys.get(), i);
114 auto* flat_value = PyList_GetItem(flat_values.get(), i);
115 if (PyDict_GetItem(flat_dictionary, flat_key) != nullptr) {
116 PyErr_SetString(
117 PyExc_ValueError,
118 tensorflow::strings::StrCat(
119 "Cannot flatten dict because this key is not unique: ",
120 PyObject_ToString(flat_key))
121 .c_str());
122 Py_DecRef(flat_dictionary);
123 return nullptr;
124 }
125 PyDict_SetItem(flat_dictionary, flat_key, flat_value);
126 }
127 } else {
128 if (PyDict_GetItem(flat_dictionary, key) != nullptr) {
129 PyErr_SetString(
130 PyExc_ValueError,
131 tensorflow::strings::StrCat(
132 "Cannot flatten dict because this key is not unique: ",
133 PyObject_ToString(key))
134 .c_str());
135 Py_DecRef(flat_dictionary);
136 return nullptr;
137 }
138 PyDict_SetItem(flat_dictionary, key, value);
139 }
140 // Manually decrease because PyObject_GetItem() returns a new reference.
141 Py_DECREF(value);
142 }
143 return flat_dictionary;
144 }
145
146 } // namespace tensorflow
147