1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Caution: this code uses exceptions. The exception use is local to the
17 // binding code and the idiomatic way to emit Python exceptions.
18
19 #include "tensorflow/compiler/xla/python/pytree.h"
20
21 #include <memory>
22 #include <stdexcept>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/hash/hash.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "pybind11/pybind11.h"
33 #include "pybind11/pytypes.h"
34 #include "pybind11/stl.h"
35 #include "tensorflow/compiler/xla/python/absl_casters.h"
36
37 namespace xla {
38
39 namespace py = pybind11;
40
Singleton()41 /*static*/ CustomNodeRegistry* CustomNodeRegistry::Singleton() {
42 static auto* registry = new CustomNodeRegistry;
43 return registry;
44 }
45
Register(py::object type,py::function to_iterable,py::function from_iterable)46 /*static*/ void CustomNodeRegistry::Register(py::object type,
47 py::function to_iterable,
48 py::function from_iterable) {
49 CustomNodeRegistry* registry = Singleton();
50 auto registration = absl::make_unique<Registration>();
51 registration->type = type;
52 registration->to_iterable = std::move(to_iterable);
53 registration->from_iterable = std::move(from_iterable);
54 auto it = registry->registrations_.emplace(type, std::move(registration));
55 if (!it.second) {
56 throw std::invalid_argument(
57 absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.",
58 py::repr(type)));
59 }
60 }
61
Lookup(py::handle type)62 /*static*/ const CustomNodeRegistry::Registration* CustomNodeRegistry::Lookup(
63 py::handle type) {
64 CustomNodeRegistry* registry = Singleton();
65 auto it =
66 registry->registrations_.find(py::reinterpret_borrow<py::object>(type));
67 return it == registry->registrations_.end() ? nullptr : it->second.get();
68 }
69
operator ==(const PyTreeDef & other) const70 bool PyTreeDef::operator==(const PyTreeDef& other) const {
71 if (traversal_.size() != other.traversal_.size()) {
72 return false;
73 }
74 for (size_t i = 0; i < traversal_.size(); ++i) {
75 const Node& a = traversal_[i];
76 const Node& b = other.traversal_[i];
77 if (a.kind != b.kind || a.arity != b.arity ||
78 (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) ||
79 a.custom != b.custom) {
80 return false;
81 }
82 if (a.node_data && a.node_data.not_equal(b.node_data)) {
83 return false;
84 }
85 // We don't need to test equality of num_leaves and num_nodes since they
86 // are derivable from the other node data.
87 }
88 return true;
89 }
90
GetKind(const py::handle & obj,CustomNodeRegistry::Registration const ** custom)91 /*static*/ PyTreeDef::Kind PyTreeDef::GetKind(
92 const py::handle& obj, CustomNodeRegistry::Registration const** custom) {
93 const PyObject* ptr = obj.ptr();
94 if (PyTuple_CheckExact(ptr)) return Kind::kTuple;
95 if (PyList_CheckExact(ptr)) return Kind::kList;
96 if (PyDict_CheckExact(ptr)) return Kind::kDict;
97 if ((*custom = CustomNodeRegistry::Lookup(obj.get_type()))) {
98 return Kind::kCustom;
99 } else if (py::isinstance<py::none>(obj)) {
100 return Kind::kNone;
101 } else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
102 // We can only identify namedtuples heuristically, here by the presence of
103 // a _fields attribute.
104 return Kind::kNamedTuple;
105 } else {
106 return Kind::kLeaf;
107 }
108 }
109
FlattenInto(py::handle handle,std::vector<py::object> & leaves,absl::optional<py::function> leaf_predicate)110 void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
111 absl::optional<py::function> leaf_predicate) {
112 Node node;
113 int start_num_nodes = traversal_.size();
114 int start_num_leaves = leaves.size();
115 if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
116 leaves.push_back(py::reinterpret_borrow<py::object>(handle));
117 } else {
118 node.kind = GetKind(handle, &node.custom);
119 auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
120 FlattenInto(child, leaves, leaf_predicate);
121 };
122 if (node.kind == Kind::kNone) {
123 // Nothing to do.
124 } else if (node.kind == Kind::kTuple) {
125 py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
126 node.arity = tuple.size();
127 for (py::handle entry : tuple) {
128 recurse(entry);
129 }
130 } else if (node.kind == Kind::kList) {
131 py::list list = py::reinterpret_borrow<py::list>(handle);
132 node.arity = list.size();
133 for (py::handle entry : list) {
134 recurse(entry);
135 }
136 } else if (node.kind == Kind::kDict) {
137 py::dict dict = py::reinterpret_borrow<py::dict>(handle);
138 py::list keys = py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
139 if (PyList_Sort(keys.ptr())) {
140 throw std::runtime_error("Dictionary key sort failed.");
141 }
142 for (py::handle key : keys) {
143 recurse(dict[key]);
144 }
145 node.arity = dict.size();
146 node.node_data = std::move(keys);
147 } else if (node.kind == Kind::kCustom) {
148 py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
149 if (out.size() != 2) {
150 throw std::runtime_error(
151 "PyTree custom to_iterable function should return a pair");
152 }
153 node.node_data = out[1];
154 node.arity = 0;
155 for (py::handle entry : py::cast<py::iterable>(out[0])) {
156 ++node.arity;
157 recurse(entry);
158 }
159 } else if (node.kind == Kind::kNamedTuple) {
160 py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
161 node.arity = tuple.size();
162 node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
163 for (py::handle entry : tuple) {
164 recurse(entry);
165 }
166 } else {
167 assert(node.kind == Kind::kLeaf);
168 leaves.push_back(py::reinterpret_borrow<py::object>(handle));
169 }
170 }
171 node.num_nodes = traversal_.size() - start_num_nodes + 1;
172 node.num_leaves = leaves.size() - start_num_leaves;
173 traversal_.push_back(std::move(node));
174 }
175
176 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
Flatten(py::handle x,absl::optional<py::function> leaf_predicate)177 PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
178 std::vector<py::object> leaves;
179 auto tree = absl::make_unique<PyTreeDef>();
180 tree->FlattenInto(x, leaves, leaf_predicate);
181 return std::make_pair(std::move(leaves), std::move(tree));
182 }
183
AllLeaves(const py::iterable & x)184 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
185 const CustomNodeRegistry::Registration* custom;
186 for (const py::handle& h : x) {
187 if (GetKind(h, &custom) != Kind::kLeaf) return false;
188 }
189 return true;
190 }
191
Unflatten(py::iterable leaves) const192 py::object PyTreeDef::Unflatten(py::iterable leaves) const {
193 std::vector<py::object> agenda;
194 auto it = leaves.begin();
195 int leaf_count = 0;
196 for (const Node& node : traversal_) {
197 if (agenda.size() < node.arity) {
198 throw std::logic_error("Too few elements for TreeDef node.");
199 }
200 switch (node.kind) {
201 case Kind::kLeaf:
202 if (it == leaves.end()) {
203 throw std::invalid_argument(absl::StrFormat(
204 "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(),
205 leaf_count));
206 }
207 agenda.push_back(py::reinterpret_borrow<py::object>(*it));
208 ++it;
209 ++leaf_count;
210 break;
211
212 case Kind::kNone:
213 case Kind::kTuple:
214 case Kind::kNamedTuple:
215 case Kind::kList:
216 case Kind::kDict:
217 case Kind::kCustom: {
218 const int size = agenda.size();
219 absl::Span<py::object> span;
220 if (node.arity > 0) {
221 span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
222 }
223 py::object o = MakeNode(node, span);
224 agenda.resize(size - node.arity);
225 agenda.push_back(o);
226 break;
227 }
228 }
229 }
230 if (it != leaves.end()) {
231 throw std::invalid_argument(absl::StrFormat(
232 "Too many leaves for PyTreeDef; expected %d.", num_leaves()));
233 }
234 if (agenda.size() != 1) {
235 throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
236 }
237 return std::move(agenda.back());
238 }
239
MakeNode(const PyTreeDef::Node & node,absl::Span<py::object> children)240 /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node,
241 absl::Span<py::object> children) {
242 if (children.size() != node.arity) {
243 throw std::logic_error("Node arity mismatch.");
244 }
245 switch (node.kind) {
246 case Kind::kLeaf:
247 throw std::logic_error("MakeNode not implemented for leaves.");
248
249 case Kind::kNone:
250 return py::none();
251
252 case Kind::kTuple:
253 case Kind::kNamedTuple: {
254 py::tuple tuple(node.arity);
255 for (int i = 0; i < node.arity; ++i) {
256 tuple[i] = std::move(children[i]);
257 }
258 if (node.kind == Kind::kNamedTuple) {
259 return node.node_data(*tuple);
260 } else {
261 return std::move(tuple);
262 }
263 }
264
265 case Kind::kList: {
266 py::list list(node.arity);
267 for (int i = 0; i < node.arity; ++i) {
268 list[i] = std::move(children[i]);
269 }
270 return std::move(list);
271 }
272
273 case Kind::kDict: {
274 py::dict dict;
275 py::list keys = py::reinterpret_borrow<py::list>(node.node_data);
276 for (int i = 0; i < node.arity; ++i) {
277 dict[keys[i]] = std::move(children[i]);
278 }
279 return std::move(dict);
280 break;
281 }
282 case Kind::kCustom: {
283 py::tuple tuple(node.arity);
284 for (int i = 0; i < node.arity; ++i) {
285 tuple[i] = std::move(children[i]);
286 }
287 return node.custom->from_iterable(node.node_data, tuple);
288 }
289 }
290 throw std::logic_error("Unreachable code.");
291 }
292
FlattenUpTo(py::handle xs) const293 py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
294 py::list leaves(num_leaves());
295 std::vector<py::object> agenda;
296 agenda.push_back(py::reinterpret_borrow<py::object>(xs));
297 auto it = traversal_.rbegin();
298 int leaf = num_leaves() - 1;
299 while (!agenda.empty()) {
300 if (it == traversal_.rend()) {
301 throw std::invalid_argument(absl::StrFormat(
302 "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
303 }
304 const Node& node = *it;
305 py::object object = agenda.back();
306 agenda.pop_back();
307 ++it;
308
309 switch (node.kind) {
310 case Kind::kLeaf:
311 if (leaf < 0) {
312 throw std::logic_error("Leaf count mismatch.");
313 }
314 leaves[leaf] = py::reinterpret_borrow<py::object>(object);
315 --leaf;
316 break;
317
318 case Kind::kNone:
319 break;
320
321 case Kind::kTuple: {
322 if (!PyTuple_CheckExact(object.ptr())) {
323 throw std::invalid_argument(
324 absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
325 }
326 py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
327 if (tuple.size() != node.arity) {
328 throw std::invalid_argument(
329 absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.",
330 tuple.size(), node.arity, py::repr(object)));
331 }
332 for (py::handle entry : tuple) {
333 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
334 }
335 break;
336 }
337
338 case Kind::kList: {
339 if (!PyList_CheckExact(object.ptr())) {
340 throw std::invalid_argument(
341 absl::StrFormat("Expected list, got %s.", py::repr(object)));
342 }
343 py::list list = py::reinterpret_borrow<py::list>(object);
344 if (list.size() != node.arity) {
345 throw std::invalid_argument(
346 absl::StrFormat("List arity mismatch: %d != %d; list: %s.",
347 list.size(), node.arity, py::repr(object)));
348 }
349 for (py::handle entry : list) {
350 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
351 }
352 break;
353 }
354
355 case Kind::kDict: {
356 if (!PyDict_CheckExact(object.ptr())) {
357 throw std::invalid_argument(
358 absl::StrFormat("Expected dict, got %s.", py::repr(object)));
359 }
360 py::dict dict = py::reinterpret_borrow<py::dict>(object);
361 py::list keys =
362 py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
363 if (PyList_Sort(keys.ptr())) {
364 throw std::runtime_error("Dictionary key sort failed.");
365 }
366 if (keys.not_equal(node.node_data)) {
367 throw std::invalid_argument(
368 absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.",
369 py::repr(node.node_data), py::repr(object)));
370 }
371 for (py::handle key : keys) {
372 agenda.push_back(dict[key]);
373 }
374 break;
375 }
376
377 case Kind::kNamedTuple: {
378 if (!py::isinstance<py::tuple>(object) ||
379 !py::hasattr(object, "_fields")) {
380 throw std::invalid_argument(absl::StrFormat(
381 "Expected named tuple, got %s.", py::repr(object)));
382 }
383 py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
384 if (tuple.size() != node.arity) {
385 throw std::invalid_argument(absl::StrFormat(
386 "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(),
387 node.arity, py::repr(object)));
388 }
389 if (tuple.get_type().not_equal(node.node_data)) {
390 throw std::invalid_argument(absl::StrFormat(
391 "Named tuple type mismatch: expected type: %s, tuple: %s.",
392 py::repr(node.node_data), py::repr(object)));
393 }
394 for (py::handle entry : tuple) {
395 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
396 }
397 break;
398 }
399
400 case Kind::kCustom: {
401 auto* registration = CustomNodeRegistry::Lookup(object.get_type());
402 if (registration != node.custom) {
403 throw std::invalid_argument(absl::StrFormat(
404 "Custom node type mismatch: expected type: %s, value: %s.",
405 py::repr(node.custom->type), py::repr(object)));
406 }
407 py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
408 if (out.size() != 2) {
409 throw std::runtime_error(
410 "PyTree custom to_iterable function should return a pair");
411 }
412 if (node.node_data.not_equal(out[1])) {
413 throw std::invalid_argument(absl::StrFormat(
414 "Mismatch custom node data: %s != %s; value: %s.",
415 py::repr(node.node_data), py::repr(out[1]), py::repr(object)));
416 }
417 int arity = 0;
418 for (py::handle entry : py::cast<py::iterable>(out[0])) {
419 ++arity;
420 agenda.push_back(py::reinterpret_borrow<py::object>(entry));
421 }
422 if (arity != node.arity) {
423 throw std::invalid_argument(absl::StrFormat(
424 "Custom type arity mismatch: %d != %d; value: %s.", arity,
425 node.arity, py::repr(object)));
426 }
427 break;
428 }
429 }
430 }
431 if (it != traversal_.rend() || leaf != -1) {
432 throw std::invalid_argument(absl::StrFormat(
433 "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
434 }
435 return leaves;
436 }
437
// Folds the tree bottom-up.
//
// For each leaf node the next element of `leaves` is taken; it is passed
// through `f_leaf` unless `f_leaf` is None, in which case it is used as-is.
// For every other node, `f_node` is called with a tuple holding the results
// already computed for its children. Returns the result for the root.
//
// Throws std::invalid_argument if `leaves` has too few or too many elements,
// and std::logic_error if the traversal itself is inconsistent.
py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
                           py::iterable leaves) const {
  std::vector<py::object> results;
  auto leaf_it = leaves.begin();

  for (const Node& node : traversal_) {
    if (node.kind == Kind::kLeaf) {
      if (leaf_it == leaves.end()) {
        throw std::invalid_argument("Too few leaves for PyTreeDef");
      }
      py::object leaf = py::reinterpret_borrow<py::object>(*leaf_it);
      if (f_leaf.is_none()) {
        results.push_back(std::move(leaf));
      } else {
        results.push_back(f_leaf(std::move(leaf)));
      }
      ++leaf_it;
      continue;
    }

    // None, tuple, namedtuple, list, dict and custom nodes all combine the
    // results of their children the same way.
    if (results.size() < node.arity) {
      throw std::logic_error("Too few elements for custom type.");
    }
    const int first_child = static_cast<int>(results.size()) - node.arity;
    py::tuple child_results(node.arity);
    for (int offset = 0; offset < node.arity; ++offset) {
      child_results[offset] = std::move(results[first_child + offset]);
    }
    results.resize(first_child);
    results.push_back(f_node(child_results));
  }

  if (leaf_it != leaves.end()) {
    throw std::invalid_argument("Too many leaves for PyTreeDef");
  }
  if (results.size() != 1) {
    throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
  }
  return std::move(results.back());
}
482
FromIterableTreeHelper(py::handle xs,std::vector<PyTreeDef::Node>::const_reverse_iterator * it) const483 py::object PyTreeDef::FromIterableTreeHelper(
484 py::handle xs,
485 std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const {
486 if (*it == traversal_.rend()) {
487 throw std::invalid_argument("Tree structures did not match.");
488 }
489 const Node& node = **it;
490 ++*it;
491 if (node.kind == Kind::kLeaf) {
492 return py::reinterpret_borrow<py::object>(xs);
493 }
494 py::iterable iterable = py::reinterpret_borrow<py::iterable>(xs);
495 std::vector<py::object> ys;
496 ys.reserve(node.arity);
497 for (py::handle x : iterable) {
498 ys.push_back(py::reinterpret_borrow<py::object>(x));
499 }
500 if (ys.size() != node.arity) {
501 throw std::invalid_argument("Arity mismatch between trees");
502 }
503 for (int j = node.arity - 1; j >= 0; --j) {
504 ys[j] = FromIterableTreeHelper(ys[j], it);
505 }
506
507 return MakeNode(node, absl::MakeSpan(ys));
508 }
509
FromIterableTree(py::handle xs) const510 py::object PyTreeDef::FromIterableTree(py::handle xs) const {
511 auto it = traversal_.rbegin();
512 py::object out = FromIterableTreeHelper(xs, &it);
513 if (it != traversal_.rend()) {
514 throw std::invalid_argument("Tree structures did not match.");
515 }
516 return out;
517 }
518
Compose(const PyTreeDef & inner) const519 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
520 auto out = absl::make_unique<PyTreeDef>();
521 for (const Node& n : traversal_) {
522 if (n.kind == Kind::kLeaf) {
523 absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
524 } else {
525 out->traversal_.push_back(n);
526 }
527 }
528 const auto& root = traversal_.back();
529 const auto& inner_root = inner.traversal_.back();
530 // TODO(tomhennigan): This should update all nodes in the traversal.
531 auto& out_root = out->traversal_.back();
532 out_root.num_nodes = (root.num_nodes - root.num_leaves) +
533 (inner_root.num_nodes * root.num_leaves);
534 out_root.num_leaves *= inner_root.num_leaves;
535 return out;
536 }
537
Tuple(const std::vector<PyTreeDef> & defs)538 /*static*/ std::unique_ptr<PyTreeDef> PyTreeDef::Tuple(
539 const std::vector<PyTreeDef>& defs) {
540 auto out = absl::make_unique<PyTreeDef>();
541 for (const PyTreeDef& def : defs) {
542 absl::c_copy(def.traversal_, std::back_inserter(out->traversal_));
543 }
544 Node node;
545 node.kind = Kind::kTuple;
546 node.arity = defs.size();
547 out->traversal_.push_back(node);
548 return out;
549 }
550
Children() const551 std::vector<std::unique_ptr<PyTreeDef>> PyTreeDef::Children() const {
552 std::vector<std::unique_ptr<PyTreeDef>> children;
553 if (traversal_.empty()) {
554 return children;
555 }
556 Node const& root = traversal_.back();
557 children.resize(root.arity);
558 int pos = traversal_.size() - 1;
559 for (int i = root.arity - 1; i >= 0; --i) {
560 children[i] = absl::make_unique<PyTreeDef>();
561 const Node& node = traversal_.at(pos - 1);
562 if (pos < node.num_nodes) {
563 throw std::logic_error("children() walked off start of array");
564 }
565 std::copy(traversal_.begin() + pos - node.num_nodes,
566 traversal_.begin() + pos,
567 std::back_inserter(children[i]->traversal_));
568 pos -= node.num_nodes;
569 }
570 if (pos != 0) {
571 throw std::logic_error("pos != 0 at end of PyTreeDef::Children");
572 }
573 return children;
574 }
575
ToString() const576 std::string PyTreeDef::ToString() const {
577 std::vector<std::string> agenda;
578 for (const Node& node : traversal_) {
579 if (agenda.size() < node.arity) {
580 throw std::logic_error("Too few elements for container.");
581 }
582
583 std::string kind;
584 switch (node.kind) {
585 case Kind::kLeaf:
586 agenda.push_back("*");
587 continue;
588 case Kind::kNone:
589 kind = "None";
590 break;
591 case Kind::kNamedTuple:
592 kind = "namedtuple";
593 break;
594 case Kind::kTuple:
595 kind = "tuple";
596 break;
597 case Kind::kList:
598 kind = "list";
599 break;
600 case Kind::kDict:
601 kind = "dict";
602 break;
603 case Kind::kCustom:
604 kind = static_cast<std::string>(py::str(node.custom->type));
605 break;
606 }
607
608 std::string children =
609 absl::StrJoin(agenda.end() - node.arity, agenda.end(), ",");
610 agenda.erase(agenda.end() - node.arity, agenda.end());
611
612 std::string data;
613 if (node.node_data) {
614 data = absl::StrFormat("[%s]", py::str(node.node_data));
615 }
616
617 agenda.push_back(
618 absl::StrFormat("PyTreeDef(%s%s, [%s])", kind, data, children));
619 }
620
621 if (agenda.size() != 1) {
622 throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
623 }
624 return std::move(agenda.back());
625 }
626
BuildPytreeSubmodule(py::module & m)627 void BuildPytreeSubmodule(py::module& m) {
628 py::module pytree = m.def_submodule("pytree", "Python tree library");
629 pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
630 py::arg("leaf_predicate") = absl::nullopt);
631 pytree.def("tuple", &PyTreeDef::Tuple);
632 pytree.def("all_leaves", &PyTreeDef::AllLeaves);
633
634 py::class_<PyTreeDef>(m, "PyTreeDef")
635 .def("unflatten", &PyTreeDef::Unflatten)
636 .def("flatten_up_to", &PyTreeDef::FlattenUpTo)
637 .def("compose", &PyTreeDef::Compose)
638 .def("walk", &PyTreeDef::Walk)
639 .def("from_iterable_tree", &PyTreeDef::FromIterableTree)
640 .def("children", &PyTreeDef::Children)
641 .def_property_readonly("num_leaves", &PyTreeDef::num_leaves)
642 .def_property_readonly("num_nodes", &PyTreeDef::num_nodes)
643 .def("__repr__", &PyTreeDef::ToString)
644 .def("__eq__",
645 [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; })
646 .def("__ne__",
647 [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; })
648 .def("__hash__",
649 [](const PyTreeDef& t) { return absl::Hash<PyTreeDef>()(t); });
650
651 pytree.def("register_node", [](py::object type, py::function to_iterable,
652 py::function from_iterable) {
653 return CustomNodeRegistry::Register(type, to_iterable, from_iterable);
654 });
655 }
656
657 } // namespace xla
658