• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/nest.h"
16 
17 #include <utility>
18 
19 #include "tensorflow/core/lib/strings/strcat.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/stringpiece.h"
22 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
23 #include "tensorflow/python/util/util.h"
24 
25 namespace tensorflow {
26 
27 namespace {
28 
29 // Gets a string representation of the input object.
30 //
31 // Args:
32 //   o: a python object.
33 //   length: If set to negative, the whole string is returned. Otherwise, the
34 //       string gets clipped to 'length' in size.
35 //
36 // Returns:
37 //   A string representation.
PyObject_ToString(PyObject * o,int length=-1)38 std::string PyObject_ToString(PyObject* o, int length = -1) {
39   auto str_o = make_safe(PyObject_Str(o));
40   std::string str = PyUnicode_AsUTF8(str_o.get());
41   if (length < 0 || str.size() <= length) {
42     return str;
43   }
44   tensorflow::StringPiece str_piece(str);
45   return tensorflow::strings::StrCat(str_piece.substr(length), "...");
46 }
47 
48 // Gets a list of keys from a dict or mapping type object.
49 //
50 // Args:
51 //   o: a dictionary or mapping type object.
52 //
53 // Returns:
54 //   A new reference to a list.
55 //
56 // Raises:
57 //   TypeError: if `o` is not a dict or mapping type object.
GetKeysFromDictOrMapping(PyObject * o)58 PyObject* GetKeysFromDictOrMapping(PyObject* o) {
59   if (PyDict_Check(o)) {
60     return PyDict_Keys(o);
61   } else if (PyMapping_Check(o)) {
62     return PyMapping_Keys(o);
63   } else {
64     auto* o_type = Py_TYPE(o);
65     PyErr_SetString(
66         PyExc_TypeError,
67         tensorflow::strings::StrCat(
68             "Expecting a type compatible with dict or mapping, got '",
69             o_type->tp_name, "'")
70             .c_str());
71     return nullptr;
72   }
73 }
74 
75 }  // namespace
76 
FlattenDictItems(PyObject * dict)77 PyObject* FlattenDictItems(PyObject* dict) {
78   if (!PyDict_Check(dict) && !swig::IsMapping(dict)) {
79     PyErr_SetString(PyExc_TypeError,
80                     tensorflow::strings::StrCat(
81                         "FlattenDictItems: 'dict' must be a dictionary or ",
82                         "collection.Mapping type object, instead of '",
83                         Py_TYPE(dict)->tp_name, "'.")
84                         .c_str());
85     return nullptr;
86   }
87   PyObject* flat_dictionary = PyDict_New();
88   auto keys = make_safe(GetKeysFromDictOrMapping(dict));
89   for (size_t i = 0; i < PyList_Size(keys.get()); ++i) {
90     auto* key = PyList_GetItem(keys.get(), i);
91     // We use a general approach in case 'dict' is a PyMapping type,
92     // but not a PyDict type.
93     auto* value = PyObject_GetItem(dict, key);
94     if (swig::IsSequence(key)) {
95       // The dict might contain list - list pairs.
96       auto flat_keys = make_safe(swig::Flatten(key, false));
97       auto flat_values = make_safe(swig::Flatten(value, false));
98       size_t flat_keys_sz = PyList_Size(flat_keys.get());
99       size_t flat_values_sz = PyList_Size(flat_values.get());
100       if (flat_keys_sz != flat_values_sz) {
101         PyErr_SetString(
102             PyExc_ValueError,
103             tensorflow::strings::StrCat(
104                 "Could not flatten dictionary. Key had ", flat_keys_sz,
105                 " elements, but value had ", flat_values_sz,
106                 " elements. Key: ", PyObject_ToString(flat_keys.get()),
107                 ", value: ", PyObject_ToString(flat_values.get()), ".")
108                 .c_str());
109         Py_DecRef(flat_dictionary);
110         return nullptr;
111       }
112       for (size_t i = 0; i < flat_keys_sz; ++i) {
113         auto* flat_key = PyList_GetItem(flat_keys.get(), i);
114         auto* flat_value = PyList_GetItem(flat_values.get(), i);
115         if (PyDict_GetItem(flat_dictionary, flat_key) != nullptr) {
116           PyErr_SetString(
117               PyExc_ValueError,
118               tensorflow::strings::StrCat(
119                   "Cannot flatten dict because this key is not unique: ",
120                   PyObject_ToString(flat_key))
121                   .c_str());
122           Py_DecRef(flat_dictionary);
123           return nullptr;
124         }
125         PyDict_SetItem(flat_dictionary, flat_key, flat_value);
126       }
127     } else {
128       if (PyDict_GetItem(flat_dictionary, key) != nullptr) {
129         PyErr_SetString(
130             PyExc_ValueError,
131             tensorflow::strings::StrCat(
132                 "Cannot flatten dict because this key is not unique: ",
133                 PyObject_ToString(key))
134                 .c_str());
135         Py_DecRef(flat_dictionary);
136         return nullptr;
137       }
138       PyDict_SetItem(flat_dictionary, key, value);
139     }
140     // Manually decrease because PyObject_GetItem() returns a new reference.
141     Py_DECREF(value);
142   }
143   return flat_dictionary;
144 }
145 
146 }  // namespace tensorflow
147