1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Caution: this code uses exceptions. The exception use is local to the
17 // binding code and the idiomatic way to emit Python exceptions.
18
19 #include "tensorflow/compiler/xla/python/pytree.h"
20
21 #include <memory>
22 #include <stdexcept>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/flat_hash_map.h"
28 #include "absl/hash/hash.h"
29 #include "absl/memory/memory.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/strings/str_join.h"
32 #include "pybind11/pybind11.h"
33 #include "pybind11/pytypes.h"
34 #include "pybind11/stl.h"
35 #include "tensorflow/compiler/xla/python/absl_casters.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 namespace xla {
39
40 namespace py = pybind11;
41
Singleton()42 /*static*/ PyTreeTypeRegistry* PyTreeTypeRegistry::Singleton() {
43 static auto* registry = []() -> PyTreeTypeRegistry* {
44 auto* registry = new PyTreeTypeRegistry;
45
46 auto add_builtin_type = [&](PyTypeObject* type_obj, PyTreeKind kind) {
47 py::object type = py::reinterpret_borrow<py::object>(
48 reinterpret_cast<PyObject*>(type_obj));
49 auto registration = absl::make_unique<Registration>();
50 registration->kind = kind;
51 registration->type = type;
52 CHECK(registry->registrations_.emplace(type, std::move(registration))
53 .second);
54 };
55 add_builtin_type(Py_TYPE(Py_None), PyTreeKind::kNone);
56 add_builtin_type(&PyTuple_Type, PyTreeKind::kTuple);
57 add_builtin_type(&PyList_Type, PyTreeKind::kList);
58 add_builtin_type(&PyDict_Type, PyTreeKind::kDict);
59 return registry;
60 }();
61 return registry;
62 }
63
Register(py::object type,py::function to_iterable,py::function from_iterable)64 /*static*/ void PyTreeTypeRegistry::Register(py::object type,
65 py::function to_iterable,
66 py::function from_iterable) {
67 PyTreeTypeRegistry* registry = Singleton();
68 auto registration = absl::make_unique<Registration>();
69 registration->kind = PyTreeKind::kCustom;
70 registration->type = type;
71 registration->to_iterable = std::move(to_iterable);
72 registration->from_iterable = std::move(from_iterable);
73 auto it = registry->registrations_.emplace(type, std::move(registration));
74 if (!it.second) {
75 throw std::invalid_argument(
76 absl::StrFormat("Duplicate custom PyTreeDef type registration for %s.",
77 py::repr(type)));
78 }
79 }
80
Lookup(py::handle type)81 /*static*/ const PyTreeTypeRegistry::Registration* PyTreeTypeRegistry::Lookup(
82 py::handle type) {
83 PyTreeTypeRegistry* registry = Singleton();
84 auto it = registry->registrations_.find(type);
85 return it == registry->registrations_.end() ? nullptr : it->second.get();
86 }
87
operator ==(const PyTreeDef & other) const88 bool PyTreeDef::operator==(const PyTreeDef& other) const {
89 if (traversal_.size() != other.traversal_.size()) {
90 return false;
91 }
92 for (size_t i = 0; i < traversal_.size(); ++i) {
93 const Node& a = traversal_[i];
94 const Node& b = other.traversal_[i];
95 if (a.kind != b.kind || a.arity != b.arity ||
96 (a.node_data.ptr() == nullptr) != (b.node_data.ptr() == nullptr) ||
97 a.custom != b.custom) {
98 return false;
99 }
100 if (a.node_data && a.node_data.not_equal(b.node_data)) {
101 return false;
102 }
103 // We don't need to test equality of num_leaves and num_nodes since they
104 // are derivable from the other node data.
105 }
106 return true;
107 }
108
GetKind(const py::handle & obj,PyTreeTypeRegistry::Registration const ** custom)109 /*static*/ PyTreeKind PyTreeDef::GetKind(
110 const py::handle& obj, PyTreeTypeRegistry::Registration const** custom) {
111 const PyTreeTypeRegistry::Registration* registration =
112 PyTreeTypeRegistry::Lookup(obj.get_type());
113 if (registration) {
114 *custom = registration;
115 return registration->kind;
116 } else if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
117 // We can only identify namedtuples heuristically, here by the presence of
118 // a _fields attribute.
119 return PyTreeKind::kNamedTuple;
120 } else {
121 return PyTreeKind::kLeaf;
122 }
123 }
124
125 template <typename T>
FlattenIntoImpl(py::handle handle,T & leaves,const absl::optional<py::function> & leaf_predicate)126 void PyTreeDef::FlattenIntoImpl(
127 py::handle handle, T& leaves,
128 const absl::optional<py::function>& leaf_predicate) {
129 Node node;
130 int start_num_nodes = traversal_.size();
131 int start_num_leaves = leaves.size();
132 if (leaf_predicate && (*leaf_predicate)(handle).cast<bool>()) {
133 leaves.push_back(py::reinterpret_borrow<py::object>(handle));
134 } else {
135 node.kind = GetKind(handle, &node.custom);
136 auto recurse = [this, &leaf_predicate, &leaves](py::handle child) {
137 FlattenInto(child, leaves, leaf_predicate);
138 };
139 switch (node.kind) {
140 case PyTreeKind::kNone:
141 // Nothing to do.
142 break;
143 case PyTreeKind::kTuple: {
144 node.arity = PyTuple_GET_SIZE(handle.ptr());
145 for (int i = 0; i < node.arity; ++i) {
146 recurse(PyTuple_GET_ITEM(handle.ptr(), i));
147 }
148 break;
149 }
150 case PyTreeKind::kList: {
151 node.arity = PyList_GET_SIZE(handle.ptr());
152 for (int i = 0; i < node.arity; ++i) {
153 recurse(PyList_GET_ITEM(handle.ptr(), i));
154 }
155 break;
156 }
157 case PyTreeKind::kDict: {
158 py::dict dict = py::reinterpret_borrow<py::dict>(handle);
159 py::list keys =
160 py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
161 if (PyList_Sort(keys.ptr())) {
162 throw std::runtime_error("Dictionary key sort failed.");
163 }
164 for (py::handle key : keys) {
165 recurse(dict[key]);
166 }
167 node.arity = dict.size();
168 node.node_data = std::move(keys);
169 break;
170 }
171 case PyTreeKind::kCustom: {
172 py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(handle));
173 if (out.size() != 2) {
174 throw std::runtime_error(
175 "PyTree custom to_iterable function should return a pair");
176 }
177 node.node_data = out[1];
178 node.arity = 0;
179 for (py::handle entry : py::cast<py::iterable>(out[0])) {
180 ++node.arity;
181 recurse(entry);
182 }
183 break;
184 }
185 case PyTreeKind::kNamedTuple: {
186 py::tuple tuple = py::reinterpret_borrow<py::tuple>(handle);
187 node.arity = tuple.size();
188 node.node_data = py::reinterpret_borrow<py::object>(tuple.get_type());
189 for (py::handle entry : tuple) {
190 recurse(entry);
191 }
192 break;
193 }
194 default:
195 DCHECK(node.kind == PyTreeKind::kLeaf);
196 leaves.push_back(py::reinterpret_borrow<py::object>(handle));
197 }
198 }
199 node.num_nodes = traversal_.size() - start_num_nodes + 1;
200 node.num_leaves = leaves.size() - start_num_leaves;
201 traversal_.push_back(std::move(node));
202 }
203
FlattenInto(py::handle handle,absl::InlinedVector<py::object,2> & leaves,absl::optional<py::function> leaf_predicate)204 void PyTreeDef::FlattenInto(py::handle handle,
205 absl::InlinedVector<py::object, 2>& leaves,
206 absl::optional<py::function> leaf_predicate) {
207 FlattenIntoImpl(handle, leaves, leaf_predicate);
208 }
209
FlattenInto(py::handle handle,std::vector<py::object> & leaves,absl::optional<py::function> leaf_predicate)210 void PyTreeDef::FlattenInto(py::handle handle, std::vector<py::object>& leaves,
211 absl::optional<py::function> leaf_predicate) {
212 FlattenIntoImpl(handle, leaves, leaf_predicate);
213 }
214
215 /*static*/ std::pair<std::vector<py::object>, std::unique_ptr<PyTreeDef>>
Flatten(py::handle x,absl::optional<py::function> leaf_predicate)216 PyTreeDef::Flatten(py::handle x, absl::optional<py::function> leaf_predicate) {
217 std::vector<py::object> leaves;
218 auto tree = absl::make_unique<PyTreeDef>();
219 tree->FlattenInto(x, leaves, leaf_predicate);
220 return std::make_pair(std::move(leaves), std::move(tree));
221 }
222
AllLeaves(const py::iterable & x)223 /*static*/ bool PyTreeDef::AllLeaves(const py::iterable& x) {
224 const PyTreeTypeRegistry::Registration* custom;
225 for (const py::handle& h : x) {
226 if (GetKind(h, &custom) != PyTreeKind::kLeaf) return false;
227 }
228 return true;
229 }
230
231 template <typename T>
UnflattenImpl(T leaves) const232 py::object PyTreeDef::UnflattenImpl(T leaves) const {
233 absl::InlinedVector<py::object, 4> agenda;
234 auto it = leaves.begin();
235 int leaf_count = 0;
236 for (const Node& node : traversal_) {
237 if (agenda.size() < node.arity) {
238 throw std::logic_error("Too few elements for TreeDef node.");
239 }
240 switch (node.kind) {
241 case PyTreeKind::kLeaf:
242 if (it == leaves.end()) {
243 throw std::invalid_argument(absl::StrFormat(
244 "Too few leaves for PyTreeDef; expected %d, got %d", num_leaves(),
245 leaf_count));
246 }
247 agenda.push_back(py::reinterpret_borrow<py::object>(*it));
248 ++it;
249 ++leaf_count;
250 break;
251
252 case PyTreeKind::kNone:
253 case PyTreeKind::kTuple:
254 case PyTreeKind::kNamedTuple:
255 case PyTreeKind::kList:
256 case PyTreeKind::kDict:
257 case PyTreeKind::kCustom: {
258 const int size = agenda.size();
259 absl::Span<py::object> span;
260 if (node.arity > 0) {
261 span = absl::Span<py::object>(&agenda[size - node.arity], node.arity);
262 }
263 py::object o = MakeNode(node, span);
264 agenda.resize(size - node.arity);
265 agenda.push_back(o);
266 break;
267 }
268 }
269 }
270 if (it != leaves.end()) {
271 throw std::invalid_argument(absl::StrFormat(
272 "Too many leaves for PyTreeDef; expected %d.", num_leaves()));
273 }
274 if (agenda.size() != 1) {
275 throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
276 }
277 return std::move(agenda.back());
278 }
279
Unflatten(py::iterable leaves) const280 py::object PyTreeDef::Unflatten(py::iterable leaves) const {
281 return UnflattenImpl(leaves);
282 }
283
Unflatten(absl::Span<const py::object> leaves) const284 py::object PyTreeDef::Unflatten(absl::Span<const py::object> leaves) const {
285 return UnflattenImpl(leaves);
286 }
287
MakeNode(const PyTreeDef::Node & node,absl::Span<py::object> children)288 /*static*/ py::object PyTreeDef::MakeNode(const PyTreeDef::Node& node,
289 absl::Span<py::object> children) {
290 if (children.size() != node.arity) {
291 throw std::logic_error("Node arity mismatch.");
292 }
293 switch (node.kind) {
294 case PyTreeKind::kLeaf:
295 throw std::logic_error("MakeNode not implemented for leaves.");
296
297 case PyTreeKind::kNone:
298 return py::none();
299
300 case PyTreeKind::kTuple:
301 case PyTreeKind::kNamedTuple: {
302 py::tuple tuple(node.arity);
303 for (int i = 0; i < node.arity; ++i) {
304 tuple[i] = std::move(children[i]);
305 }
306 if (node.kind == PyTreeKind::kNamedTuple) {
307 return node.node_data(*tuple);
308 } else {
309 return std::move(tuple);
310 }
311 }
312
313 case PyTreeKind::kList: {
314 py::list list(node.arity);
315 for (int i = 0; i < node.arity; ++i) {
316 list[i] = std::move(children[i]);
317 }
318 return std::move(list);
319 }
320
321 case PyTreeKind::kDict: {
322 py::dict dict;
323 py::list keys = py::reinterpret_borrow<py::list>(node.node_data);
324 for (int i = 0; i < node.arity; ++i) {
325 dict[keys[i]] = std::move(children[i]);
326 }
327 return std::move(dict);
328 break;
329 }
330 case PyTreeKind::kCustom: {
331 py::tuple tuple(node.arity);
332 for (int i = 0; i < node.arity; ++i) {
333 tuple[i] = std::move(children[i]);
334 }
335 return node.custom->from_iterable(node.node_data, tuple);
336 }
337 }
338 throw std::logic_error("Unreachable code.");
339 }
340
// Flattens `xs` only down to the leaves of this treedef. For each leaf
// position of this treedef, the corresponding subtree of `xs` is returned
// unchanged, in traversal order.
// Every container node of this treedef must be matched in `xs` by a value of
// the same kind and arity. Where applicable, the dict keys, the namedtuple
// type or the custom node data must match too. Otherwise
// std::invalid_argument is thrown.
py::list PyTreeDef::FlattenUpTo(py::handle xs) const {
  py::list leaves(num_leaves());
  // Stack of subtrees of `xs` still waiting to be matched against a node.
  std::vector<py::object> agenda;
  agenda.push_back(py::reinterpret_borrow<py::object>(xs));
  // traversal_ is in post-order, so walking it in reverse visits each node
  // before its children, last child first. Children are pushed onto `agenda`
  // in forward order, so the last child is popped first; this keeps the two
  // walks in step.
  auto it = traversal_.rbegin();
  // Leaves are therefore met last-to-first, so `leaves` is filled from the
  // back.
  int leaf = num_leaves() - 1;
  while (!agenda.empty()) {
    // Defensive: values remain to be matched but the treedef is exhausted.
    if (it == traversal_.rend()) {
      throw std::invalid_argument(absl::StrFormat(
          "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
    }
    const Node& node = *it;
    py::object object = agenda.back();
    agenda.pop_back();
    ++it;

    switch (node.kind) {
      case PyTreeKind::kLeaf:
        if (leaf < 0) {
          throw std::logic_error("Leaf count mismatch.");
        }
        // Whatever sits at a leaf position is kept whole, not flattened further.
        leaves[leaf] = py::reinterpret_borrow<py::object>(object);
        --leaf;
        break;

      case PyTreeKind::kNone:
        // NOTE(review): `object` is not checked to be None here, unlike the
        // exact-type checks in the other cases; confirm this is intended.
        break;

      case PyTreeKind::kTuple: {
        // Exact type check: tuple subclasses (e.g. namedtuples) do not match.
        if (!PyTuple_CheckExact(object.ptr())) {
          throw std::invalid_argument(
              absl::StrFormat("Expected tuple, got %s.", py::repr(object)));
        }
        py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
        if (tuple.size() != node.arity) {
          throw std::invalid_argument(
              absl::StrFormat("Tuple arity mismatch: %d != %d; tuple: %s.",
                              tuple.size(), node.arity, py::repr(object)));
        }
        for (py::handle entry : tuple) {
          agenda.push_back(py::reinterpret_borrow<py::object>(entry));
        }
        break;
      }

      case PyTreeKind::kList: {
        if (!PyList_CheckExact(object.ptr())) {
          throw std::invalid_argument(
              absl::StrFormat("Expected list, got %s.", py::repr(object)));
        }
        py::list list = py::reinterpret_borrow<py::list>(object);
        if (list.size() != node.arity) {
          throw std::invalid_argument(
              absl::StrFormat("List arity mismatch: %d != %d; list: %s.",
                              list.size(), node.arity, py::repr(object)));
        }
        for (py::handle entry : list) {
          agenda.push_back(py::reinterpret_borrow<py::object>(entry));
        }
        break;
      }

      case PyTreeKind::kDict: {
        if (!PyDict_CheckExact(object.ptr())) {
          throw std::invalid_argument(
              absl::StrFormat("Expected dict, got %s.", py::repr(object)));
        }
        py::dict dict = py::reinterpret_borrow<py::dict>(object);
        // Sort the keys the same way FlattenIntoImpl does, so they can be
        // compared against the key list stored in node_data.
        py::list keys =
            py::reinterpret_steal<py::list>(PyDict_Keys(dict.ptr()));
        if (PyList_Sort(keys.ptr())) {
          throw std::runtime_error("Dictionary key sort failed.");
        }
        if (keys.not_equal(node.node_data)) {
          throw std::invalid_argument(
              absl::StrFormat("Dict key mismatch; expected keys: %s; dict: %s.",
                              py::repr(node.node_data), py::repr(object)));
        }
        for (py::handle key : keys) {
          agenda.push_back(dict[key]);
        }
        break;
      }

      case PyTreeKind::kNamedTuple: {
        // Same heuristic as GetKind: a tuple carrying a `_fields` attribute.
        if (!py::isinstance<py::tuple>(object) ||
            !py::hasattr(object, "_fields")) {
          throw std::invalid_argument(absl::StrFormat(
              "Expected named tuple, got %s.", py::repr(object)));
        }
        py::tuple tuple = py::reinterpret_borrow<py::tuple>(object);
        if (tuple.size() != node.arity) {
          throw std::invalid_argument(absl::StrFormat(
              "Named tuple arity mismatch: %d != %d; tuple: %s.", tuple.size(),
              node.arity, py::repr(object)));
        }
        // node_data holds the namedtuple class recorded at flatten time.
        if (tuple.get_type().not_equal(node.node_data)) {
          throw std::invalid_argument(absl::StrFormat(
              "Named tuple type mismatch: expected type: %s, tuple: %s.",
              py::repr(node.node_data), py::repr(object)));
        }
        for (py::handle entry : tuple) {
          agenda.push_back(py::reinterpret_borrow<py::object>(entry));
        }
        break;
      }

      case PyTreeKind::kCustom: {
        // The object's exact type must map to the very same registration.
        auto* registration = PyTreeTypeRegistry::Lookup(object.get_type());
        if (registration != node.custom) {
          throw std::invalid_argument(absl::StrFormat(
              "Custom node type mismatch: expected type: %s, value: %s.",
              py::repr(node.custom->type), py::repr(object)));
        }
        py::tuple out = py::cast<py::tuple>(node.custom->to_iterable(object));
        if (out.size() != 2) {
          throw std::runtime_error(
              "PyTree custom to_iterable function should return a pair");
        }
        if (node.node_data.not_equal(out[1])) {
          throw std::invalid_argument(absl::StrFormat(
              "Mismatch custom node data: %s != %s; value: %s.",
              py::repr(node.node_data), py::repr(out[1]), py::repr(object)));
        }
        int arity = 0;
        for (py::handle entry : py::cast<py::iterable>(out[0])) {
          ++arity;
          agenda.push_back(py::reinterpret_borrow<py::object>(entry));
        }
        // The arity is only known after iterating, so children may already
        // have been pushed when this throws.
        if (arity != node.arity) {
          throw std::invalid_argument(absl::StrFormat(
              "Custom type arity mismatch: %d != %d; value: %s.", arity,
              node.arity, py::repr(object)));
        }
        break;
      }
    }
  }
  // Every treedef node must have been consumed and every leaf slot filled.
  if (it != traversal_.rend() || leaf != -1) {
    throw std::invalid_argument(absl::StrFormat(
        "Tree structures did not match: %s vs %s", py::repr(xs), ToString()));
  }
  return leaves;
}
485
// Folds over the tree bottom-up, in traversal order.
// Each leaf position takes the next element of `leaves`. That element is
// mapped through `f_leaf`, or passed through unchanged when `f_leaf` is None.
// Each interior node is replaced by `f_node(tuple_of_processed_children)`.
// Throws std::invalid_argument if `leaves` has too few or too many elements.
py::object PyTreeDef::Walk(const py::function& f_node, py::handle f_leaf,
                           py::iterable leaves) const {
  std::vector<py::object> results;
  auto leaf_it = leaves.begin();
  for (const Node& node : traversal_) {
    switch (node.kind) {
      case PyTreeKind::kLeaf: {
        if (leaf_it == leaves.end()) {
          throw std::invalid_argument("Too few leaves for PyTreeDef");
        }
        py::object leaf = py::reinterpret_borrow<py::object>(*leaf_it);
        if (f_leaf.is_none()) {
          results.push_back(std::move(leaf));
        } else {
          results.push_back(f_leaf(std::move(leaf)));
        }
        ++leaf_it;
        break;
      }

      case PyTreeKind::kNone:
      case PyTreeKind::kTuple:
      case PyTreeKind::kNamedTuple:
      case PyTreeKind::kList:
      case PyTreeKind::kDict:
      case PyTreeKind::kCustom: {
        if (results.size() < node.arity) {
          throw std::logic_error("Too few elements for custom type.");
        }
        // The node's children are the `arity` most recent results, in order.
        const std::size_t first = results.size() - node.arity;
        py::tuple children(node.arity);
        for (int i = 0; i < node.arity; ++i) {
          children[i] = results[first + i];
        }
        results.erase(results.begin() + first, results.end());
        results.push_back(f_node(children));
        break;
      }
    }
  }
  if (leaf_it != leaves.end()) {
    throw std::invalid_argument("Too many leaves for PyTreeDef");
  }
  if (results.size() != 1) {
    throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
  }
  return std::move(results.back());
}
530
FromIterableTreeHelper(py::handle xs,absl::InlinedVector<PyTreeDef::Node,1>::const_reverse_iterator * it) const531 py::object PyTreeDef::FromIterableTreeHelper(
532 py::handle xs,
533 absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it) const {
534 if (*it == traversal_.rend()) {
535 throw std::invalid_argument("Tree structures did not match.");
536 }
537 const Node& node = **it;
538 ++*it;
539 if (node.kind == PyTreeKind::kLeaf) {
540 return py::reinterpret_borrow<py::object>(xs);
541 }
542 py::iterable iterable = py::reinterpret_borrow<py::iterable>(xs);
543 std::vector<py::object> ys;
544 ys.reserve(node.arity);
545 for (py::handle x : iterable) {
546 ys.push_back(py::reinterpret_borrow<py::object>(x));
547 }
548 if (ys.size() != node.arity) {
549 throw std::invalid_argument("Arity mismatch between trees");
550 }
551 for (int j = node.arity - 1; j >= 0; --j) {
552 ys[j] = FromIterableTreeHelper(ys[j], it);
553 }
554
555 return MakeNode(node, absl::MakeSpan(ys));
556 }
557
FromIterableTree(py::handle xs) const558 py::object PyTreeDef::FromIterableTree(py::handle xs) const {
559 auto it = traversal_.rbegin();
560 py::object out = FromIterableTreeHelper(xs, &it);
561 if (it != traversal_.rend()) {
562 throw std::invalid_argument("Tree structures did not match.");
563 }
564 return out;
565 }
566
Compose(const PyTreeDef & inner) const567 std::unique_ptr<PyTreeDef> PyTreeDef::Compose(const PyTreeDef& inner) const {
568 auto out = absl::make_unique<PyTreeDef>();
569 for (const Node& n : traversal_) {
570 if (n.kind == PyTreeKind::kLeaf) {
571 absl::c_copy(inner.traversal_, std::back_inserter(out->traversal_));
572 } else {
573 out->traversal_.push_back(n);
574 }
575 }
576 const auto& root = traversal_.back();
577 const auto& inner_root = inner.traversal_.back();
578 // TODO(tomhennigan): This should update all nodes in the traversal.
579 auto& out_root = out->traversal_.back();
580 out_root.num_nodes = (root.num_nodes - root.num_leaves) +
581 (inner_root.num_nodes * root.num_leaves);
582 out_root.num_leaves *= inner_root.num_leaves;
583 return out;
584 }
585
Tuple(const std::vector<PyTreeDef> & defs)586 /*static*/ std::unique_ptr<PyTreeDef> PyTreeDef::Tuple(
587 const std::vector<PyTreeDef>& defs) {
588 auto out = absl::make_unique<PyTreeDef>();
589 int num_leaves = 0;
590 for (const PyTreeDef& def : defs) {
591 absl::c_copy(def.traversal_, std::back_inserter(out->traversal_));
592 num_leaves += def.num_leaves();
593 }
594 Node node;
595 node.kind = PyTreeKind::kTuple;
596 node.arity = defs.size();
597 node.num_leaves = num_leaves;
598 node.num_nodes = out->traversal_.size() + 1;
599 out->traversal_.push_back(node);
600 return out;
601 }
602
Children() const603 std::vector<std::unique_ptr<PyTreeDef>> PyTreeDef::Children() const {
604 std::vector<std::unique_ptr<PyTreeDef>> children;
605 if (traversal_.empty()) {
606 return children;
607 }
608 Node const& root = traversal_.back();
609 children.resize(root.arity);
610 int pos = traversal_.size() - 1;
611 for (int i = root.arity - 1; i >= 0; --i) {
612 children[i] = absl::make_unique<PyTreeDef>();
613 const Node& node = traversal_.at(pos - 1);
614 if (pos < node.num_nodes) {
615 throw std::logic_error("children() walked off start of array");
616 }
617 std::copy(traversal_.begin() + pos - node.num_nodes,
618 traversal_.begin() + pos,
619 std::back_inserter(children[i]->traversal_));
620 pos -= node.num_nodes;
621 }
622 if (pos != 0) {
623 throw std::logic_error("pos != 0 at end of PyTreeDef::Children");
624 }
625 return children;
626 }
627
ToString() const628 std::string PyTreeDef::ToString() const {
629 std::vector<std::string> agenda;
630 for (const Node& node : traversal_) {
631 if (agenda.size() < node.arity) {
632 throw std::logic_error("Too few elements for container.");
633 }
634
635 std::string children =
636 absl::StrJoin(agenda.end() - node.arity, agenda.end(), ", ");
637 std::string representation;
638 switch (node.kind) {
639 case PyTreeKind::kLeaf:
640 agenda.push_back("*");
641 continue;
642 case PyTreeKind::kNone:
643 representation = "None";
644 break;
645 case PyTreeKind::kTuple:
646 // Tuples with only one element must have a trailing comma.
647 if (node.arity == 1) children += ",";
648 representation = absl::StrCat("(", children, ")");
649 break;
650 case PyTreeKind::kList:
651 representation = absl::StrCat("[", children, "]");
652 break;
653 case PyTreeKind::kDict: {
654 if (py::len(node.node_data) != node.arity) {
655 throw std::logic_error("Number of keys and entries does not match.");
656 }
657 std::string separator = "{";
658 auto child_iter = agenda.end() - node.arity;
659 for (const py::handle& key : node.node_data) {
660 absl::StrAppendFormat(&representation, "%s%s: %s", separator,
661 py::repr(key), *child_iter);
662 child_iter++;
663 separator = ", ";
664 }
665 representation += "}";
666 break;
667 }
668
669 case PyTreeKind::kNamedTuple:
670 case PyTreeKind::kCustom: {
671 std::string kind;
672 if (node.kind == PyTreeKind::kNamedTuple) {
673 kind = "namedtuple";
674 } else {
675 kind = static_cast<std::string>(py::str(node.custom->type));
676 }
677
678 std::string data;
679 if (node.node_data) {
680 data = absl::StrFormat("[%s]", py::str(node.node_data));
681 }
682 representation =
683 absl::StrFormat("CustomNode(%s%s, [%s])", kind, data, children);
684 break;
685 }
686 }
687 agenda.erase(agenda.end() - node.arity, agenda.end());
688 agenda.push_back(std::move(representation));
689 }
690 if (agenda.size() != 1) {
691 throw std::logic_error("PyTreeDef traversal did not yield a singleton.");
692 }
693 return absl::StrCat("PyTreeDef(", agenda.back(), ")");
694 }
695
BuildPytreeSubmodule(py::module & m)696 void BuildPytreeSubmodule(py::module& m) {
697 py::module pytree = m.def_submodule("pytree", "Python tree library");
698 pytree.def("flatten", &PyTreeDef::Flatten, py::arg("tree"),
699 py::arg("leaf_predicate") = absl::nullopt);
700 pytree.def("tuple", &PyTreeDef::Tuple);
701 pytree.def("all_leaves", &PyTreeDef::AllLeaves);
702
703 py::class_<PyTreeDef>(m, "PyTreeDef")
704 .def("unflatten",
705 static_cast<pybind11::object (PyTreeDef::*)(
706 pybind11::iterable leaves) const>(&PyTreeDef::Unflatten))
707 .def("flatten_up_to", &PyTreeDef::FlattenUpTo)
708 .def("compose", &PyTreeDef::Compose)
709 .def("walk", &PyTreeDef::Walk)
710 .def("from_iterable_tree", &PyTreeDef::FromIterableTree)
711 .def("children", &PyTreeDef::Children)
712 .def_property_readonly("num_leaves", &PyTreeDef::num_leaves)
713 .def_property_readonly("num_nodes", &PyTreeDef::num_nodes)
714 .def("__repr__", &PyTreeDef::ToString)
715 .def("__eq__",
716 [](const PyTreeDef& a, const PyTreeDef& b) { return a == b; })
717 .def("__ne__",
718 [](const PyTreeDef& a, const PyTreeDef& b) { return a != b; })
719 .def("__hash__",
720 [](const PyTreeDef& t) { return absl::Hash<PyTreeDef>()(t); });
721
722 pytree.def("register_node", [](py::object type, py::function to_iterable,
723 py::function from_iterable) {
724 return PyTreeTypeRegistry::Register(type, to_iterable, from_iterable);
725 });
726 }
727
728 } // namespace xla
729