• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Caution: this code uses exceptions. The exception use is local to the
17 // binding code and the idiomatic way to emit Python exceptions.
18 
19 #include "tensorflow/compiler/xla/python/pytree.h"
20 
21 #include <memory>
22 #include <stdexcept>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/hash/hash.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "pybind11/pybind11.h"
33 #include "pybind11/pytypes.h"
34 #include "pybind11/stl.h"
35 #include "tensorflow/compiler/xla/python/absl_casters.h"
36 
37 namespace xla {
38 
39 namespace py = pybind11;
40 
Singleton()41 /*static*/ CustomNodeRegistry* CustomNodeRegistry::Singleton() {
42   static auto* registry = new CustomNodeRegistry;
43   return registry;
44 }
45 
Register(py::object type,py::function to_iterable,py::function from_iterable)46 /*static*/ void CustomNodeRegistry::Register(py::object type,
47                                              py::function to_iterable,
48                                              py::function from_iterable) {
49   CustomNodeRegistry* registry = Singleton();
50   auto registration = absl::make_unique<Registration>();
51   registration->type = type;
52   registration->to_iterable = std::move(to_iterable);
53   registration->from_iterable = std::move(from_iterable);
54   auto it = registry->registrations_.emplace(type, std::move(registration));
55   if (!it.second) {
56     throw std::invalid_argument(
57         absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.",
58                         py::repr(type)));
59   }
60 }
61 
Lookup(py::handle type)62 /*static*/ const CustomNodeRegistry::Registration* CustomNodeRegistry::Lookup(
63     py::handle type) {
64   CustomNodeRegistry* registry = Singleton();
65   auto it =
66       registry->registrations_.find(py::reinterpret_borrow<py::object>(type));
67   return it == registry->registrations_.end() ? nullptr : it->second.get();
68 }
69 
operator ==(const PyTreeDef & other) const70 bool PyTreeDef::operator==(const PyTreeDef& other) const {
71   if (traversal_.size() != other.traversal_.size()) {
72     return false;
73   }
74   for (size_t i = 0; i < traversal_.size(); ++i) {
75     const Node& a = traversal_[i];
76     const Node& b = other.traversal_[i];
77     if (a.kind != b.kind || a.arity != b.arity ||
78         (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) ||
79         a.custom != b.custom) {
80       return false;
81     }
82     if (a.node_data && a.node_data.not_equal(b.node_data)) {
83       return false;
84     }
85     // We don't need to test equality of num_leaves and num_nodes since they
86     // are derivable from the other node data.
87   }
88   return true;
89 }
90 
GetKind(const py::handle & obj,CustomNodeRegistry::Registration const ** custom)91 /*static*/ PyTreeDef::Kind PyTreeDef::GetKind(
92     const py::handle& obj, CustomNodeRegistry::Registration const** custom) {
93   const PyObject* ptr = obj.ptr();
94   if (PyTuple_CheckExact(ptr)) return Kind::kTuple;
95   if (PyList_CheckExact(ptr)) return Kind::kList;
96   if (PyDict_CheckExact(ptr)) return Kind::kDict;
97   if ((*custom = CustomNodeRegistry::Lookup(obj.get_type()))) {
98     return Kind::kCustom;
99   } else if (py::isinstance<py::none>(obj)) {
100     return Kind::kNone;
101   } else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
102     // We can only identify namedtuples heuristically, here by the presence of
103     // a _fields attribute.
104     return Kind::kNamedTuple;
105   } else {
106     return Kind::kLeaf;
107   }
108 }
109 
FlattenInto(py::handle handle,std::vector<py::object> & leaves,absl::optional<py::function> leaf_predicate)110 void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
111                             absl::optional<py::function> leaf_predicate) {
112   Node node;
113   int start_num_nodes = traversal_.size();
114   int start_num_leaves = leaves.size();
115   if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
116     leaves.push_back(py::reinterpret_borrow<py::object>(handle));
117   } else {
118     node.kind = GetKind(handle, &node.custom);
119     auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
120       FlattenInto(child, leaves, leaf_predicate);
121     };
122     if (node.kind == Kind::kNone) {
123       // Nothing to do.
124     } else if (node.kind == Kind::kTuple) {
125       py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
126       node.arity = tuple.size();
127       for (py::handle entry : tuple) {
128         recurse(entry);
129       }
130     } else if (node.kind == Kind::kList) {
131       py::list list = py::reinterpret_borrow<py::list>(handle);
132       node.arity = list.size();
133       for (py::handle entry : list) {
134         recurse(entry);
135       }
136     } else if (node.kind == Kind::kDict) {
137       py::dict dict = py::reinterpret_borrow<py::dict>(handle);
138       py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
139       if (PyList_Sort(keys.ptr())) {
140         throw std::runtime_error("Dictionary key sort failed.");
141       }
142       for (py::handle key : keys) {
143         recurse(dict[key]);
144       }
145       node.arity = dict.size();
146       node.node_data = std::move(keys);
147     } else if (node.kind == Kind::kCustom) {
148       py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
149       if (out.size() != 2) {
150         throw std::runtime_error(
151             "PyTree custom to_iterable function should return a pair");
152       }
153       node.node_data = out[1];
154       node.arity = 0;
155       for (py::handle entry : py::cast<py::iterable>(out[0])) {
156         ++node.arity;
157         recurse(entry);
158       }
159     } else if (node.kind == Kind::kNamedTuple) {
160       py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
161       node.arity = tuple.size();
162       node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
163       for (py::handle entry : tuple) {
164         recurse(entry);
165       }
166     } else {
167       assert(node.kind == Kind::kLeaf);
168       leaves.push_back(py::reinterpret_borrow<py::object>(handle));
169     }
170   }
171   node.num_nodes = traversal_.size() - start_num_nodes + 1;
172   node.num_leaves = leaves.size() - start_num_leaves;
173   traversal_.push_back(std::move(node));
174 }
175 
176 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
Flatten(py::handle x,absl::optional<py::function> leaf_predicate)177 PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
178   std::vector<py::object> leaves;
179   auto tree = absl::make_unique<PyTreeDef>();
180   tree->FlattenInto(x, leaves, leaf_predicate);
181   return std::make_pair(std::move(leaves), std::move(tree));
182 }
183 
AllLeaves(const py::iterable & x)184 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
185   const CustomNodeRegistry::Registration* custom;
186   for (const py::handle& h : x) {
187     if (GetKind(h, &custom) != Kind::kLeaf) return false;
188   }
189   return true;
190 }
191 
Unflatten(py::iterable leaves) const192 py::object PyTreeDef::Unflatten(py::iterable leaves) const {
193   std::vector<py::object> agenda;
194   auto it = leaves.begin();
195   int leaf_count = 0;
196   for (const Node& node : traversal_) {
197     if (agenda.size() < node.arity) {
198       throw std::logic_error("Too few elements for TreeDef node.");
199     }
200     switch (node.kind) {
201       case Kind::kLeaf:
202         if (it == leaves.end()) {
203           throw std::invalid_argument(absl::StrFormat(
204               "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(),
205               leaf_count));
206         }
207         agenda.push_back(py::reinterpret_borrow<py::object>(*it));
208         ++it;
209         ++leaf_count;
210         break;
211 
212       case Kind::kNone:
213       case Kind::kTuple:
214       case Kind::kNamedTuple:
215       case Kind::kList:
216       case Kind::kDict:
217       case Kind::kCustom: {
218         const int size = agenda.size();
219         absl::Span<py::object> span;
220         if (node.arity > 0) {
221           span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
222         }
223         py::object o = MakeNode(node, span);
224         agenda.resize(size - node.arity);
225         agenda.push_back(o);
226         break;
227       }
228     }
229   }
230   if (it != leaves.end()) {
231     throw std::invalid_argument(absl::StrFormat(
232         "Too many leaves for PyTreeDef; expected %d.", num_leaves()));
233   }
234   if (agenda.size() != 1) {
235     throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
236   }
237   return std::move(agenda.back());
238 }
239 
MakeNode(const PyTreeDef::Node & node,absl::Span<py::object> children)240 /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node,
241                                           absl::Span<py::object> children) {
242   if (children.size() != node.arity) {
243     throw std::logic_error("Node arity mismatch.");
244   }
245   switch (node.kind) {
246     case Kind::kLeaf:
247       throw std::logic_error("MakeNode not implemented for leaves.");
248 
249     case Kind::kNone:
250       return py::none();
251 
252     case Kind::kTuple:
253     case Kind::kNamedTuple: {
254       py::tuple tuple(node.arity);
255       for (int i = 0; i < node.arity; ++i) {
256         tuple[i] = std::move(children[i]);
257       }
258       if (node.kind == Kind::kNamedTuple) {
259         return node.node_data(*tuple);
260       } else {
261         return std::move(tuple);
262       }
263     }
264 
265     case Kind::kList: {
266       py::list list(node.arity);
267       for (int i = 0; i < node.arity; ++i) {
268         list[i] = std::move(children[i]);
269       }
270       return std::move(list);
271     }
272 
273     case Kind::kDict: {
274       py::dict dict;
275       py::list keys = py::reinterpret_borrow<py::list>(node.node_data);
276       for (int i = 0; i < node.arity; ++i) {
277         dict[keys[i]] = std::move(children[i]);
278       }
279       return std::move(dict);
280       break;
281     }
282     case Kind::kCustom: {
283       py::tuple tuple(node.arity);
284       for (int i = 0; i < node.arity; ++i) {
285         tuple[i] = std::move(children[i]);
286       }
287       return node.custom->from_iterable(node.node_data, tuple);
288     }
289   }
290   throw std::logic_error("Unreachable code.");
291 }
292 
FlattenUpTo(py::handle xs) const293 py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
294   py::list leaves(num_leaves());
295   std::vector<py::object> agenda;
296   agenda.push_back(py::reinterpret_borrow<py::object>(xs));
297   auto it = traversal_.rbegin();
298   int leaf = num_leaves() - 1;
299   while (!agenda.empty()) {
300     if (it == traversal_.rend()) {
301       throw std::invalid_argument(absl::StrFormat(
302           "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
303     }
304     const Node& node = *it;
305     py::object object = agenda.back();
306     agenda.pop_back();
307     ++it;
308 
309     switch (node.kind) {
310       case Kind::kLeaf:
311         if (leaf < 0) {
312           throw std::logic_error("Leaf count mismatch.");
313         }
314         leaves[leaf] = py::reinterpret_borrow<py::object>(object);
315         --leaf;
316         break;
317 
318       case Kind::kNone:
319         break;
320 
321       case Kind::kTuple: {
322         if (!PyTuple_CheckExact(object.ptr())) {
323           throw std::invalid_argument(
324               absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
325         }
326         py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
327         if (tuple.size() != node.arity) {
328           throw std::invalid_argument(
329               absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.",
330                               tuple.size(), node.arity, py::repr(object)));
331         }
332         for (py::handle entry : tuple) {
333           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
334         }
335         break;
336       }
337 
338       case Kind::kList: {
339         if (!PyList_CheckExact(object.ptr())) {
340           throw std::invalid_argument(
341               absl::StrFormat("Expected list, got %s.", py::repr(object)));
342         }
343         py::list list = py::reinterpret_borrow<py::list>(object);
344         if (list.size() != node.arity) {
345           throw std::invalid_argument(
346               absl::StrFormat("List arity mismatch: %d != %d; list: %s.",
347                               list.size(), node.arity, py::repr(object)));
348         }
349         for (py::handle entry : list) {
350           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
351         }
352         break;
353       }
354 
355       case Kind::kDict: {
356         if (!PyDict_CheckExact(object.ptr())) {
357           throw std::invalid_argument(
358               absl::StrFormat("Expected dict, got %s.", py::repr(object)));
359         }
360         py::dict dict = py::reinterpret_borrow<py::dict>(object);
361         py::list keys =
362             py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
363         if (PyList_Sort(keys.ptr())) {
364           throw std::runtime_error("Dictionary key sort failed.");
365         }
366         if (keys.not_equal(node.node_data)) {
367           throw std::invalid_argument(
368               absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.",
369                               py::repr(node.node_data), py::repr(object)));
370         }
371         for (py::handle key : keys) {
372           agenda.push_back(dict[key]);
373         }
374         break;
375       }
376 
377       case Kind::kNamedTuple: {
378         if (!py::isinstance<py::tuple>(object) ||
379             !py::hasattr(object, "_fields")) {
380           throw std::invalid_argument(absl::StrFormat(
381               "Expected named tuple, got %s.", py::repr(object)));
382         }
383         py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
384         if (tuple.size() != node.arity) {
385           throw std::invalid_argument(absl::StrFormat(
386               "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(),
387               node.arity, py::repr(object)));
388         }
389         if (tuple.get_type().not_equal(node.node_data)) {
390           throw std::invalid_argument(absl::StrFormat(
391               "Named tuple type mismatch: expected type: %s, tuple: %s.",
392               py::repr(node.node_data), py::repr(object)));
393         }
394         for (py::handle entry : tuple) {
395           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
396         }
397         break;
398       }
399 
400       case Kind::kCustom: {
401         auto* registration = CustomNodeRegistry::Lookup(object.get_type());
402         if (registration != node.custom) {
403           throw std::invalid_argument(absl::StrFormat(
404               "Custom node type mismatch: expected type: %s, value: %s.",
405               py::repr(node.custom->type), py::repr(object)));
406         }
407         py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
408         if (out.size() != 2) {
409           throw std::runtime_error(
410               "PyTree custom to_iterable function should return a pair");
411         }
412         if (node.node_data.not_equal(out[1])) {
413           throw std::invalid_argument(absl::StrFormat(
414               "Mismatch custom node data: %s != %s; value: %s.",
415               py::repr(node.node_data), py::repr(out[1]), py::repr(object)));
416         }
417         int arity = 0;
418         for (py::handle entry : py::cast<py::iterable>(out[0])) {
419           ++arity;
420           agenda.push_back(py::reinterpret_borrow<py::object>(entry));
421         }
422         if (arity != node.arity) {
423           throw std::invalid_argument(absl::StrFormat(
424               "Custom type arity mismatch: %d != %d; value: %s.", arity,
425               node.arity, py::repr(object)));
426         }
427         break;
428       }
429     }
430   }
431   if (it != traversal_.rend() || leaf != -1) {
432     throw std::invalid_argument(absl::StrFormat(
433         "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
434   }
435   return leaves;
436 }
437 
// Folds over the tree bottom-up. Each leaf position consumes the next element
// of `leaves`, passed through `f_leaf` unless `f_leaf` is None. Each interior
// node calls `f_node` with a tuple of the results already computed for its
// children. Returns the result computed for the root.
//
// Throws std::invalid_argument if `leaves` yields too few or too many
// elements, and std::logic_error if the traversal is inconsistent.
py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
                           py::iterable leaves) const {
  std::vector<py::object> stack;
  auto leaf_it = leaves.begin();

  for (const Node& node : traversal_) {
    switch (node.kind) {
      case Kind::kLeaf: {
        if (leaf_it == leaves.end()) {
          throw std::invalid_argument("Too few leaves for PyTreeDef");
        }

        py::object leaf = py::reinterpret_borrow<py::object>(*leaf_it);
        if (f_leaf.is_none()) {
          stack.push_back(std::move(leaf));
        } else {
          stack.push_back(f_leaf(std::move(leaf)));
        }
        ++leaf_it;
        break;
      }

      case Kind::kNone:
      case Kind::kTuple:
      case Kind::kNamedTuple:
      case Kind::kList:
      case Kind::kDict:
      case Kind::kCustom: {
        if (stack.size() < node.arity) {
          throw std::logic_error("Too few elements for custom type.");
        }
        // The top `arity` stack entries are this node's children, in order.
        const size_t first_child = stack.size() - node.arity;
        py::tuple args(node.arity);
        for (int i = 0; i < node.arity; ++i) {
          args[i] = stack[first_child + i];
        }
        stack.resize(first_child);
        stack.push_back(f_node(args));
        break;
      }
    }
  }

  if (leaf_it != leaves.end()) {
    throw std::invalid_argument("Too many leaves for PyTreeDef");
  }
  if (stack.size() != 1) {
    throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
  }
  return std::move(stack.back());
}
482 
FromIterableTreeHelper(py::handle xs,std::vector<PyTreeDef::Node>::const_reverse_iterator * it) const483 py::object PyTreeDef::FromIterableTreeHelper(
484     py::handle xs,
485     std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const {
486   if (*it == traversal_.rend()) {
487     throw std::invalid_argument("Tree structures did not match.");
488   }
489   const Node& node = **it;
490   ++*it;
491   if (node.kind == Kind::kLeaf) {
492     return py::reinterpret_borrow<py::object>(xs);
493   }
494   py::iterable iterable = py::reinterpret_borrow<py::iterable>(xs);
495   std::vector<py::object> ys;
496   ys.reserve(node.arity);
497   for (py::handle x : iterable) {
498     ys.push_back(py::reinterpret_borrow<py::object>(x));
499   }
500   if (ys.size() != node.arity) {
501     throw std::invalid_argument("Arity mismatch between trees");
502   }
503   for (int j = node.arity - 1; j >= 0; --j) {
504     ys[j] = FromIterableTreeHelper(ys[j], it);
505   }
506 
507   return MakeNode(node, absl::MakeSpan(ys));
508 }
509 
FromIterableTree(py::handle xs) const510 py::object PyTreeDef::FromIterableTree(py::handle xs) const {
511   auto it = traversal_.rbegin();
512   py::object out = FromIterableTreeHelper(xs, &it);
513   if (it != traversal_.rend()) {
514     throw std::invalid_argument("Tree structures did not match.");
515   }
516   return out;
517 }
518 
Compose(const PyTreeDef & inner) const519 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
520   auto out = absl::make_unique<PyTreeDef>();
521   for (const Node& n : traversal_) {
522     if (n.kind == Kind::kLeaf) {
523       absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
524     } else {
525       out->traversal_.push_back(n);
526     }
527   }
528   const auto& root = traversal_.back();
529   const auto& inner_root = inner.traversal_.back();
530   // TODO(tomhennigan): This should update all nodes in the traversal.
531   auto& out_root = out->traversal_.back();
532   out_root.num_nodes = (root.num_nodes - root.num_leaves) +
533                        (inner_root.num_nodes * root.num_leaves);
534   out_root.num_leaves *= inner_root.num_leaves;
535   return out;
536 }
537 
Tuple(const std::vector<PyTreeDef> & defs)538 /*static*/ std::unique_ptr<PyTreeDef> PyTreeDef::Tuple(
539     const std::vector<PyTreeDef>& defs) {
540   auto out = absl::make_unique<PyTreeDef>();
541   for (const PyTreeDef& def : defs) {
542     absl::c_copy(def.traversal_, std::back_inserter(out->traversal_));
543   }
544   Node node;
545   node.kind = Kind::kTuple;
546   node.arity = defs.size();
547   out->traversal_.push_back(node);
548   return out;
549 }
550 
Children() const551 std::vector<std::unique_ptr<PyTreeDef>> PyTreeDef::Children() const {
552   std::vector<std::unique_ptr<PyTreeDef>> children;
553   if (traversal_.empty()) {
554     return children;
555   }
556   Node const& root = traversal_.back();
557   children.resize(root.arity);
558   int pos = traversal_.size() - 1;
559   for (int i = root.arity - 1; i >= 0; --i) {
560     children[i] = absl::make_unique<PyTreeDef>();
561     const Node& node = traversal_.at(pos - 1);
562     if (pos < node.num_nodes) {
563       throw std::logic_error("children() walked off start of array");
564     }
565     std::copy(traversal_.begin() + pos - node.num_nodes,
566               traversal_.begin() + pos,
567               std::back_inserter(children[i]->traversal_));
568     pos -= node.num_nodes;
569   }
570   if (pos != 0) {
571     throw std::logic_error("pos != 0 at end of PyTreeDef::Children");
572   }
573   return children;
574 }
575 
ToString() const576 std::string PyTreeDef::ToString() const {
577   std::vector<std::string> agenda;
578   for (const Node& node : traversal_) {
579     if (agenda.size() < node.arity) {
580       throw std::logic_error("Too few elements for container.");
581     }
582 
583     std::string kind;
584     switch (node.kind) {
585       case Kind::kLeaf:
586         agenda.push_back("*");
587         continue;
588       case Kind::kNone:
589         kind = "None";
590         break;
591       case Kind::kNamedTuple:
592         kind = "namedtuple";
593         break;
594       case Kind::kTuple:
595         kind = "tuple";
596         break;
597       case Kind::kList:
598         kind = "list";
599         break;
600       case Kind::kDict:
601         kind = "dict";
602         break;
603       case Kind::kCustom:
604         kind = static_cast<std::string>(py::str(node.custom->type));
605         break;
606     }
607 
608     std::string children =
609         absl::StrJoin(agenda.end() - node.arity, agenda.end(), ",");
610     agenda.erase(agenda.end() - node.arity, agenda.end());
611 
612     std::string data;
613     if (node.node_data) {
614       data = absl::StrFormat("[%s]", py::str(node.node_data));
615     }
616 
617     agenda.push_back(
618         absl::StrFormat("PyTreeDef(%s%s, [%s])", kind, data, children));
619   }
620 
621   if (agenda.size() != 1) {
622     throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
623   }
624   return std::move(agenda.back());
625 }
626 
BuildPytreeSubmodule(py::module & m)627 void BuildPytreeSubmodule(py::module& m) {
628   py::module pytree = m.def_submodule("pytree", "Python tree library");
629   pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
630              py::arg("leaf_predicate") = absl::nullopt);
631   pytree.def("tuple", &PyTreeDef::Tuple);
632   pytree.def("all_leaves", &PyTreeDef::AllLeaves);
633 
634   py::class_<PyTreeDef>(m, "PyTreeDef")
635       .def("unflatten", &PyTreeDef::Unflatten)
636       .def("flatten_up_to", &PyTreeDef::FlattenUpTo)
637       .def("compose", &PyTreeDef::Compose)
638       .def("walk", &PyTreeDef::Walk)
639       .def("from_iterable_tree", &PyTreeDef::FromIterableTree)
640       .def("children", &PyTreeDef::Children)
641       .def_property_readonly("num_leaves", &PyTreeDef::num_leaves)
642       .def_property_readonly("num_nodes", &PyTreeDef::num_nodes)
643       .def("__repr__", &PyTreeDef::ToString)
644       .def("__eq__",
645            [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; })
646       .def("__ne__",
647            [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; })
648       .def("__hash__",
649            [](const PyTreeDef& t) { return absl::Hash<PyTreeDef>()(t); });
650 
651   pytree.def("register_node", [](py::object type, py::function to_iterable,
652                                  py::function from_iterable) {
653     return CustomNodeRegistry::Register(type, to_iterable, from_iterable);
654   });
655 }
656 
657 }  // namespace xla
658