• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Caution: this code uses exceptions. The exception use is local to the
17 // binding code and the idiomatic way to emit Python exceptions.
18 
19 #include "tensorflow/compiler/xla/python/pytree.h"
20 
21 #include <memory>
22 #include <stdexcept>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/hash/hash.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "pybind11/pybind11.h"
33 #include "pybind11/pytypes.h"
34 #include "pybind11/stl.h"
35 #include "tensorflow/compiler/xla/python/absl_casters.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace xla {
39 
40 namespace py = pybind11;
41 
Singleton()42 /*static*/ PyTreeTypeRegistry* PyTreeTypeRegistry::Singleton() {
43   static auto* registry = []() -> PyTreeTypeRegistry* {
44     auto* registry = new PyTreeTypeRegistry;
45 
46     auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) {
47       py::object type = py::reinterpret_borrow<py::object>(
48           reinterpret_cast<PyObject*>(type_obj));
49       auto registration = absl::make_unique<Registration>();
50       registration->kind = kind;
51       registration->type = type;
52       CHECK(registry->registrations_.emplace(type, std::move(registration))
53                 .second);
54     };
55     add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone);
56     add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple);
57     add_builtin_type(&PyList_Type, PyTreeKind::kList);
58     add_builtin_type(&PyDict_Type, PyTreeKind::kDict);
59     return registry;
60   }();
61   return registry;
62 }
63 
Register(py::object type,py::function to_iterable,py::function from_iterable)64 /*static*/ void PyTreeTypeRegistry::Register(py::object type,
65                                              py::function to_iterable,
66                                              py::function from_iterable) {
67   PyTreeTypeRegistry* registry = Singleton();
68   auto registration = absl::make_unique<Registration>();
69   registration->kind = PyTreeKind::kCustom;
70   registration->type = type;
71   registration->to_iterable = std::move(to_iterable);
72   registration->from_iterable = std::move(from_iterable);
73   auto it = registry->registrations_.emplace(type, std::move(registration));
74   if (!it.second) {
75     throw std::invalid_argument(
76         absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.",
77                         py::repr(type)));
78   }
79 }
80 
Lookup(py::handle type)81 /*static*/ const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup(
82     py::handle type) {
83   PyTreeTypeRegistry* registry = Singleton();
84   auto it = registry->registrations_.find(type);
85   return it == registry->registrations_.end() ? nullptr : it->second.get();
86 }
87 
operator ==(const PyTreeDef & other) const88 bool PyTreeDef::operator==(const PyTreeDef& other) const {
89   if (traversal_.size() != other.traversal_.size()) {
90     return false;
91   }
92   for (size_t i = 0; i < traversal_.size(); ++i) {
93     const Node& a = traversal_[i];
94     const Node& b = other.traversal_[i];
95     if (a.kind != b.kind || a.arity != b.arity ||
96         (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) ||
97         a.custom != b.custom) {
98       return false;
99     }
100     if (a.node_data && a.node_data.not_equal(b.node_data)) {
101       return false;
102     }
103     // We don't need to test equality of num_leaves and num_nodes since they
104     // are derivable from the other node data.
105   }
106   return true;
107 }
108 
GetKind(const py::handle & obj,PyTreeTypeRegistry::Registration const ** custom)109 /*static*/ PyTreeKind PyTreeDef::GetKind(
110     const py::handle& obj, PyTreeTypeRegistry::Registration const** custom) {
111   const PyTreeTypeRegistry::Registration* registration =
112       PyTreeTypeRegistry::Lookup(obj.get_type());
113   if (registration) {
114     *custom = registration;
115     return registration->kind;
116   } else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
117     // We can only identify namedtuples heuristically, here by the presence of
118     // a _fields attribute.
119     return PyTreeKind::kNamedTuple;
120   } else {
121     return PyTreeKind::kLeaf;
122   }
123 }
124 
125 template <typename T>
FlattenIntoImpl(py::handle handle,T & leaves,const absl::optional<py::function> & leaf_predicate)126 void PyTreeDef::FlattenIntoImpl(
127     py::handle handle, T& leaves,
128     const absl::optional<py::function>& leaf_predicate) {
129   Node node;
130   int start_num_nodes = traversal_.size();
131   int start_num_leaves = leaves.size();
132   if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
133     leaves.push_back(py::reinterpret_borrow<py::object>(handle));
134   } else {
135     node.kind = GetKind(handle, &node.custom);
136     auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
137       FlattenInto(child, leaves, leaf_predicate);
138     };
139     switch (node.kind) {
140       case PyTreeKind::kNone:
141         // Nothing to do.
142         break;
143       case PyTreeKind::kTuple: {
144         node.arity = PyTuple_GET_SIZE(handle.ptr());
145         for (int i = 0; i < node.arity; ++i) {
146           recurse(PyTuple_GET_ITEM(handle.ptr(), i));
147         }
148         break;
149       }
150       case PyTreeKind::kList: {
151         node.arity = PyList_GET_SIZE(handle.ptr());
152         for (int i = 0; i < node.arity; ++i) {
153           recurse(PyList_GET_ITEM(handle.ptr(), i));
154         }
155         break;
156       }
157       case PyTreeKind::kDict: {
158         py::dict dict = py::reinterpret_borrow<py::dict>(handle);
159         py::list keys =
160             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
161         if (PyList_Sort(keys.ptr())) {
162           throw std::runtime_error("Dictionary key sort failed.");
163         }
164         for (py::handle key : keys) {
165           recurse(dict[key]);
166         }
167         node.arity = dict.size();
168         node.node_data = std::move(keys);
169         break;
170       }
171       case PyTreeKind::kCustom: {
172         py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
173         if (out.size() != 2) {
174           throw std::runtime_error(
175               "PyTree custom to_iterable function should return a pair");
176         }
177         node.node_data = out[1];
178         node.arity = 0;
179         for (py::handle entry : py::cast<py::iterable>(out[0])) {
180           ++node.arity;
181           recurse(entry);
182         }
183         break;
184       }
185       case PyTreeKind::kNamedTuple: {
186         py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
187         node.arity = tuple.size();
188         node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
189         for (py::handle entry : tuple) {
190           recurse(entry);
191         }
192         break;
193       }
194       default:
195         DCHECK(node.kind == PyTreeKind::kLeaf);
196         leaves.push_back(py::reinterpret_borrow<py::object>(handle));
197     }
198   }
199   node.num_nodes = traversal_.size() - start_num_nodes + 1;
200   node.num_leaves = leaves.size() - start_num_leaves;
201   traversal_.push_back(std::move(node));
202 }
203 
FlattenInto(py::handle handle,absl::InlinedVector<py::object,2> & leaves,absl::optional<py::function> leaf_predicate)204 void PyTreeDef::FlattenInto(py::handle handle,
205                             absl::InlinedVector<py::object, 2>& leaves,
206                             absl::optional<py::function> leaf_predicate) {
207   FlattenIntoImpl(handle, leaves, leaf_predicate);
208 }
209 
FlattenInto(py::handle handle,std::vector<py::object> & leaves,absl::optional<py::function> leaf_predicate)210 void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
211                             absl::optional<py::function> leaf_predicate) {
212   FlattenIntoImpl(handle, leaves, leaf_predicate);
213 }
214 
215 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
Flatten(py::handle x,absl::optional<py::function> leaf_predicate)216 PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
217   std::vector<py::object> leaves;
218   auto tree = absl::make_unique<PyTreeDef>();
219   tree->FlattenInto(x, leaves, leaf_predicate);
220   return std::make_pair(std::move(leaves), std::move(tree));
221 }
222 
AllLeaves(const py::iterable & x)223 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
224   const PyTreeTypeRegistry::Registration* custom;
225   for (const py::handle& h : x) {
226     if (GetKind(h, &custom) != PyTreeKind::kLeaf) return false;
227   }
228   return true;
229 }
230 
231 template <typename T>
UnflattenImpl(T leaves) const232 py::object PyTreeDef::UnflattenImpl(T leaves) const {
233   absl::InlinedVector<py::object, 4> agenda;
234   auto it = leaves.begin();
235   int leaf_count = 0;
236   for (const Node& node : traversal_) {
237     if (agenda.size() < node.arity) {
238       throw std::logic_error("Too few elements for TreeDef node.");
239     }
240     switch (node.kind) {
241       case PyTreeKind::kLeaf:
242         if (it == leaves.end()) {
243           throw std::invalid_argument(absl::StrFormat(
244               "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(),
245               leaf_count));
246         }
247         agenda.push_back(py::reinterpret_borrow<py::object>(*it));
248         ++it;
249         ++leaf_count;
250         break;
251 
252       case PyTreeKind::kNone:
253       case PyTreeKind::kTuple:
254       case PyTreeKind::kNamedTuple:
255       case PyTreeKind::kList:
256       case PyTreeKind::kDict:
257       case PyTreeKind::kCustom: {
258         const int size = agenda.size();
259         absl::Span<py::object> span;
260         if (node.arity > 0) {
261           span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
262         }
263         py::object o = MakeNode(node, span);
264         agenda.resize(size - node.arity);
265         agenda.push_back(o);
266         break;
267       }
268     }
269   }
270   if (it != leaves.end()) {
271     throw std::invalid_argument(absl::StrFormat(
272         "Too many leaves for PyTreeDef; expected %d.", num_leaves()));
273   }
274   if (agenda.size() != 1) {
275     throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
276   }
277   return std::move(agenda.back());
278 }
279 
Unflatten(py::iterable leaves) const280 py::object PyTreeDef::Unflatten(py::iterable leaves) const {
281   return UnflattenImpl(leaves);
282 }
283 
Unflatten(absl::Span<const py::object> leaves) const284 py::object PyTreeDef::Unflatten(absl::Span<const py::object> leaves) const {
285   return UnflattenImpl(leaves);
286 }
287 
MakeNode(const PyTreeDef::Node & node,absl::Span<py::object> children)288 /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node,
289                                           absl::Span<py::object> children) {
290   if (children.size() != node.arity) {
291     throw std::logic_error("Node arity mismatch.");
292   }
293   switch (node.kind) {
294     case PyTreeKind::kLeaf:
295       throw std::logic_error("MakeNode not implemented for leaves.");
296 
297     case PyTreeKind::kNone:
298       return py::none();
299 
300     case PyTreeKind::kTuple:
301     case PyTreeKind::kNamedTuple: {
302       py::tuple tuple(node.arity);
303       for (int i = 0; i < node.arity; ++i) {
304         tuple[i] = std::move(children[i]);
305       }
306       if (node.kind == PyTreeKind::kNamedTuple) {
307         return node.node_data(*tuple);
308       } else {
309         return std::move(tuple);
310       }
311     }
312 
313     case PyTreeKind::kList: {
314       py::list list(node.arity);
315       for (int i = 0; i < node.arity; ++i) {
316         list[i] = std::move(children[i]);
317       }
318       return std::move(list);
319     }
320 
321     case PyTreeKind::kDict: {
322       py::dict dict;
323       py::list keys = py::reinterpret_borrow<py::list>(node.node_data);
324       for (int i = 0; i < node.arity; ++i) {
325         dict[keys[i]] = std::move(children[i]);
326       }
327       return std::move(dict);
328       break;
329     }
330     case PyTreeKind::kCustom: {
331       py::tuple tuple(node.arity);
332       for (int i = 0; i < node.arity; ++i) {
333         tuple[i] = std::move(children[i]);
334       }
335       return node.custom->from_iterable(node.node_data, tuple);
336     }
337   }
338   throw std::logic_error("Unreachable code.");
339 }
340 
FlattenUpTo(py::handle xs) const341 py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
342   py::list leaves(num_leaves());
343   std::vector<py::object> agenda;
344   agenda.push_back(py::reinterpret_borrow<py::object>(xs));
345   auto it = traversal_.rbegin();
346   int leaf = num_leaves() - 1;
347   while (!agenda.empty()) {
348     if (it == traversal_.rend()) {
349       throw std::invalid_argument(absl::StrFormat(
350           "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
351     }
352     const Node& node = *it;
353     py::object object = agenda.back();
354     agenda.pop_back();
355     ++it;
356 
357     switch (node.kind) {
358       case PyTreeKind::kLeaf:
359         if (leaf < 0) {
360           throw std::logic_error("Leaf count mismatch.");
361         }
362         leaves[leaf] = py::reinterpret_borrow<py::object>(object);
363         --leaf;
364         break;
365 
366       case PyTreeKind::kNone:
367         break;
368 
369       case PyTreeKind::kTuple: {
370         if (!PyTuple_CheckExact(object.ptr())) {
371           throw std::invalid_argument(
372               absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
373         }
374         py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
375         if (tuple.size() != node.arity) {
376           throw std::invalid_argument(
377               absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.",
378                               tuple.size(), node.arity, py::repr(object)));
379         }
380         for (py::handle entry : tuple) {
381           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
382         }
383         break;
384       }
385 
386       case PyTreeKind::kList: {
387         if (!PyList_CheckExact(object.ptr())) {
388           throw std::invalid_argument(
389               absl::StrFormat("Expected list, got %s.", py::repr(object)));
390         }
391         py::list list = py::reinterpret_borrow<py::list>(object);
392         if (list.size() != node.arity) {
393           throw std::invalid_argument(
394               absl::StrFormat("List arity mismatch: %d != %d; list: %s.",
395                               list.size(), node.arity, py::repr(object)));
396         }
397         for (py::handle entry : list) {
398           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
399         }
400         break;
401       }
402 
403       case PyTreeKind::kDict: {
404         if (!PyDict_CheckExact(object.ptr())) {
405           throw std::invalid_argument(
406               absl::StrFormat("Expected dict, got %s.", py::repr(object)));
407         }
408         py::dict dict = py::reinterpret_borrow<py::dict>(object);
409         py::list keys =
410             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
411         if (PyList_Sort(keys.ptr())) {
412           throw std::runtime_error("Dictionary key sort failed.");
413         }
414         if (keys.not_equal(node.node_data)) {
415           throw std::invalid_argument(
416               absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.",
417                               py::repr(node.node_data), py::repr(object)));
418         }
419         for (py::handle key : keys) {
420           agenda.push_back(dict[key]);
421         }
422         break;
423       }
424 
425       case PyTreeKind::kNamedTuple: {
426         if (!py::isinstance<py::tuple>(object) ||
427             !py::hasattr(object, "_fields")) {
428           throw std::invalid_argument(absl::StrFormat(
429               "Expected named tuple, got %s.", py::repr(object)));
430         }
431         py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
432         if (tuple.size() != node.arity) {
433           throw std::invalid_argument(absl::StrFormat(
434               "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(),
435               node.arity, py::repr(object)));
436         }
437         if (tuple.get_type().not_equal(node.node_data)) {
438           throw std::invalid_argument(absl::StrFormat(
439               "Named tuple type mismatch: expected type: %s, tuple: %s.",
440               py::repr(node.node_data), py::repr(object)));
441         }
442         for (py::handle entry : tuple) {
443           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
444         }
445         break;
446       }
447 
448       case PyTreeKind::kCustom: {
449         auto* registration = PyTreeTypeRegistry::Lookup(object.get_type());
450         if (registration != node.custom) {
451           throw std::invalid_argument(absl::StrFormat(
452               "Custom node type mismatch: expected type: %s, value: %s.",
453               py::repr(node.custom->type), py::repr(object)));
454         }
455         py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
456         if (out.size() != 2) {
457           throw std::runtime_error(
458               "PyTree custom to_iterable function should return a pair");
459         }
460         if (node.node_data.not_equal(out[1])) {
461           throw std::invalid_argument(absl::StrFormat(
462               "Mismatch custom node data: %s != %s; value: %s.",
463               py::repr(node.node_data), py::repr(out[1]), py::repr(object)));
464         }
465         int arity = 0;
466         for (py::handle entry : py::cast<py::iterable>(out[0])) {
467           ++arity;
468           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
469         }
470         if (arity != node.arity) {
471           throw std::invalid_argument(absl::StrFormat(
472               "Custom type arity mismatch: %d != %d; value: %s.", arity,
473               node.arity, py::repr(object)));
474         }
475         break;
476       }
477     }
478   }
479   if (it != traversal_.rend() || leaf != -1) {
480     throw std::invalid_argument(absl::StrFormat(
481         "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
482   }
483   return leaves;
484 }
485 
// Folds the tree bottom-up.
//
// Each leaf from `leaves` is mapped through `f_leaf` (or passed through
// unchanged when `f_leaf` is None); each interior node calls `f_node` with a
// tuple of its children's results. Returns the result for the root.
//
// Throws std::invalid_argument when the number of leaves does not match and
// std::logic_error when the traversal is inconsistent.
py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
                           py::iterable leaves) const {
  std::vector<py::object> stack;
  auto next_leaf = leaves.begin();

  for (const Node& node : traversal_) {
    if (node.kind == PyTreeKind::kLeaf) {
      if (next_leaf == leaves.end()) {
        throw std::invalid_argument("Too few leaves for PyTreeDef");
      }
      py::object leaf = py::reinterpret_borrow<py::object>(*next_leaf);
      if (f_leaf.is_none()) {
        stack.push_back(std::move(leaf));
      } else {
        stack.push_back(f_leaf(std::move(leaf)));
      }
      ++next_leaf;
      continue;
    }

    // Interior node: gather the topmost `arity` results, preserving order.
    if (stack.size() < node.arity) {
      throw std::logic_error("Too few elements for custom type.");
    }
    const size_t base = stack.size() - node.arity;
    py::tuple child_results(node.arity);
    for (int idx = 0; idx < node.arity; ++idx) {
      child_results[idx] = stack[base + idx];
    }
    stack.resize(base);
    stack.push_back(f_node(child_results));
  }

  if (next_leaf != leaves.end()) {
    throw std::invalid_argument("Too many leaves for PyTreeDef");
  }
  if (stack.size() != 1) {
    throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
  }
  return std::move(stack.back());
}
530 
FromIterableTreeHelper(py::handle xs,absl::InlinedVector<PyTreeDef::Node,1>::const_reverse_iterator * it) const531 py::object PyTreeDef::FromIterableTreeHelper(
532     py::handle xs,
533     absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it) const {
534   if (*it == traversal_.rend()) {
535     throw std::invalid_argument("Tree structures did not match.");
536   }
537   const Node& node = **it;
538   ++*it;
539   if (node.kind == PyTreeKind::kLeaf) {
540     return py::reinterpret_borrow<py::object>(xs);
541   }
542   py::iterable iterable = py::reinterpret_borrow<py::iterable>(xs);
543   std::vector<py::object> ys;
544   ys.reserve(node.arity);
545   for (py::handle x : iterable) {
546     ys.push_back(py::reinterpret_borrow<py::object>(x));
547   }
548   if (ys.size() != node.arity) {
549     throw std::invalid_argument("Arity mismatch between trees");
550   }
551   for (int j = node.arity - 1; j >= 0; --j) {
552     ys[j] = FromIterableTreeHelper(ys[j], it);
553   }
554 
555   return MakeNode(node, absl::MakeSpan(ys));
556 }
557 
FromIterableTree(py::handle xs) const558 py::object PyTreeDef::FromIterableTree(py::handle xs) const {
559   auto it = traversal_.rbegin();
560   py::object out = FromIterableTreeHelper(xs, &it);
561   if (it != traversal_.rend()) {
562     throw std::invalid_argument("Tree structures did not match.");
563   }
564   return out;
565 }
566 
Compose(const PyTreeDef & inner) const567 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
568   auto out = absl::make_unique<PyTreeDef>();
569   for (const Node& n : traversal_) {
570     if (n.kind == PyTreeKind::kLeaf) {
571       absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
572     } else {
573       out->traversal_.push_back(n);
574     }
575   }
576   const auto& root = traversal_.back();
577   const auto& inner_root = inner.traversal_.back();
578   // TODO(tomhennigan): This should update all nodes in the traversal.
579   auto& out_root = out->traversal_.back();
580   out_root.num_nodes = (root.num_nodes - root.num_leaves) +
581                        (inner_root.num_nodes * root.num_leaves);
582   out_root.num_leaves *= inner_root.num_leaves;
583   return out;
584 }
585 
Tuple(const std::vector<PyTreeDef> & defs)586 /*static*/ std::unique_ptr<PyTreeDef> PyTreeDef::Tuple(
587     const std::vector<PyTreeDef>& defs) {
588   auto out = absl::make_unique<PyTreeDef>();
589   int num_leaves = 0;
590   for (const PyTreeDef& def : defs) {
591     absl::c_copy(def.traversal_, std::back_inserter(out->traversal_));
592     num_leaves += def.num_leaves();
593   }
594   Node node;
595   node.kind = PyTreeKind::kTuple;
596   node.arity = defs.size();
597   node.num_leaves = num_leaves;
598   node.num_nodes = out->traversal_.size() + 1;
599   out->traversal_.push_back(node);
600   return out;
601 }
602 
Children() const603 std::vector<std::unique_ptr<PyTreeDef>> PyTreeDef::Children() const {
604   std::vector<std::unique_ptr<PyTreeDef>> children;
605   if (traversal_.empty()) {
606     return children;
607   }
608   Node const& root = traversal_.back();
609   children.resize(root.arity);
610   int pos = traversal_.size() - 1;
611   for (int i = root.arity - 1; i >= 0; --i) {
612     children[i] = absl::make_unique<PyTreeDef>();
613     const Node& node = traversal_.at(pos - 1);
614     if (pos < node.num_nodes) {
615       throw std::logic_error("children() walked off start of array");
616     }
617     std::copy(traversal_.begin() + pos - node.num_nodes,
618               traversal_.begin() + pos,
619               std::back_inserter(children[i]->traversal_));
620     pos -= node.num_nodes;
621   }
622   if (pos != 0) {
623     throw std::logic_error("pos != 0 at end of PyTreeDef::Children");
624   }
625   return children;
626 }
627 
ToString() const628 std::string PyTreeDef::ToString() const {
629   std::vector<std::string> agenda;
630   for (const Node& node : traversal_) {
631     if (agenda.size() < node.arity) {
632       throw std::logic_error("Too few elements for container.");
633     }
634 
635     std::string children =
636         absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", ");
637     std::string representation;
638     switch (node.kind) {
639       case PyTreeKind::kLeaf:
640         agenda.push_back("*");
641         continue;
642       case PyTreeKind::kNone:
643         representation = "None";
644         break;
645       case PyTreeKind::kTuple:
646         // Tuples with only one element must have a trailing comma.
647         if (node.arity == 1) children += ",";
648         representation = absl::StrCat("(", children, ")");
649         break;
650       case PyTreeKind::kList:
651         representation = absl::StrCat("[", children, "]");
652         break;
653       case PyTreeKind::kDict: {
654         if (py::len(node.node_data) != node.arity) {
655           throw std::logic_error("Number of keys and entries does not match.");
656         }
657         std::string separator = "{";
658         auto child_iter = agenda.end() - node.arity;
659         for (const py::handle& key : node.node_data) {
660           absl::StrAppendFormat(&representation, "%s%s: %s", separator,
661                                 py::repr(key), *child_iter);
662           child_iter++;
663           separator = ", ";
664         }
665         representation += "}";
666         break;
667       }
668 
669       case PyTreeKind::kNamedTuple:
670       case PyTreeKind::kCustom: {
671         std::string kind;
672         if (node.kind == PyTreeKind::kNamedTuple) {
673           kind = "namedtuple";
674         } else {
675           kind = static_cast<std::string>(py::str(node.custom->type));
676         }
677 
678         std::string data;
679         if (node.node_data) {
680           data = absl::StrFormat("[%s]", py::str(node.node_data));
681         }
682         representation =
683             absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children);
684         break;
685       }
686     }
687     agenda.erase(agenda.end() - node.arity, agenda.end());
688     agenda.push_back(std::move(representation));
689   }
690   if (agenda.size() != 1) {
691     throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
692   }
693   return absl::StrCat("PyTreeDef(", agenda.back(), ")");
694 }
695 
BuildPytreeSubmodule(py::module & m)696 void BuildPytreeSubmodule(py::module& m) {
697   py::module pytree = m.def_submodule("pytree", "Python tree library");
698   pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
699              py::arg("leaf_predicate") = absl::nullopt);
700   pytree.def("tuple", &PyTreeDef::Tuple);
701   pytree.def("all_leaves", &PyTreeDef::AllLeaves);
702 
703   py::class_<PyTreeDef>(m, "PyTreeDef")
704       .def("unflatten",
705            static_cast<pybind11::object (PyTreeDef::*)(
706                pybind11::iterable leaves) const>(&PyTreeDef::Unflatten))
707       .def("flatten_up_to", &PyTreeDef::FlattenUpTo)
708       .def("compose", &PyTreeDef::Compose)
709       .def("walk", &PyTreeDef::Walk)
710       .def("from_iterable_tree", &PyTreeDef::FromIterableTree)
711       .def("children", &PyTreeDef::Children)
712       .def_property_readonly("num_leaves", &PyTreeDef::num_leaves)
713       .def_property_readonly("num_nodes", &PyTreeDef::num_nodes)
714       .def("__repr__", &PyTreeDef::ToString)
715       .def("__eq__",
716            [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; })
717       .def("__ne__",
718            [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; })
719       .def("__hash__",
720            [](const PyTreeDef& t) { return absl::Hash<PyTreeDef>()(t); });
721 
722   pytree.def("register_node", [](py::object type, py::function to_iterable,
723                                  py::function from_iterable) {
724     return PyTreeTypeRegistry::Register(type, to_iterable, from_iterable);
725   });
726 }
727 
728 }  // namespace xla
729