1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
18
// See https://jax.readthedocs.io/en/latest/pytrees.html for documentation
// about pytrees.
21
// Caution: this code uses exceptions. Their use is confined to the binding
// code, where throwing is the idiomatic way to raise Python exceptions.
24
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
37
38 namespace xla {
39
// Discriminates the kinds of node that can appear in a PyTree. The underlying
// type is spelled out explicitly; it is the same `int` an `enum class` gets by
// default, and the enumerators keep their implicit values 0..6 in this order.
enum class PyTreeKind : int {
  kLeaf,        // Opaque leaf: any object not recognized as a container.
  kNone,        // The Python `None` value.
  kTuple,       // A plain Python tuple.
  kNamedTuple,  // A collections.namedtuple instance.
  kList,        // A Python list.
  kDict,        // A Python dict.
  kCustom,      // A user-registered container type (see PyTreeTypeRegistry).
};
49
50 // Registry of custom node types.
51 class PyTreeTypeRegistry {
52 public:
53 struct Registration {
54 PyTreeKind kind;
55
56 // The following values are populated for custom types.
57 // The Python type object, used to identify the type.
58 pybind11::object type;
59 // A function with signature: object -> (iterable, aux_data)
60 pybind11::function to_iterable;
61 // A function with signature: (aux_data, iterable) -> object
62 pybind11::function from_iterable;
63 };
64
65 // Registers a new custom type. Objects of `type` will be treated as container
66 // node types in PyTrees.
67 static void Register(pybind11::object type, pybind11::function to_iterable,
68 pybind11::function from_iterable);
69
70 // Finds the custom type registration for `type`. Returns nullptr if none
71 // exists.
72 static const Registration* Lookup(pybind11::handle type);
73
74 private:
75 static PyTreeTypeRegistry* Singleton();
76
77 struct TypeHash {
78 using is_transparent = void;
operatorTypeHash79 size_t operator()(const pybind11::object& t) const {
80 return absl::Hash<void*>()(t.ptr());
81 }
operatorTypeHash82 size_t operator()(const pybind11::handle& t) const {
83 return absl::Hash<void*>()(t.ptr());
84 }
85 };
86 struct TypeEq {
87 using is_transparent = void;
operatorTypeEq88 bool operator()(const pybind11::object& a,
89 const pybind11::object& b) const {
90 return a.ptr() == b.ptr();
91 }
operatorTypeEq92 bool operator()(const pybind11::object& a,
93 const pybind11::handle& b) const {
94 return a.ptr() == b.ptr();
95 }
96 };
97 absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
98 TypeEq>
99 registrations_;
100 };
101
102 // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
103 // Python values, where the interior nodes are tuples, lists, dictionaries, or
104 // user-defined containers, and the leaves are other objects.
105 class PyTreeDef {
106 public:
107 PyTreeDef() = default;
108
109 // Flattens a Pytree into a list of leaves and a PyTreeDef.
110 // Returns references to the flattened objects, which might be temporary
111 // objects in the case of custom pytype handlers.
112 static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
113 Flatten(pybind11::handle x,
114 absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
115
116 // Recursive helper used to implement Flatten().
117 void FlattenInto(
118 pybind11::handle handle, std::vector<pybind11::object>& leaves,
119 absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
120 void FlattenInto(
121 pybind11::handle handle, absl::InlinedVector<pybind11::object, 2>& leaves,
122 absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
123
124 // Tests whether the given list is a flat list of leaves.
125 static bool AllLeaves(const pybind11::iterable& x);
126
127 // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
128 // the tree-structure of 'x'. For example, if we flatten a value
129 // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
130 // list of leaves [1, (2, 3), {"foo": 4}].
131 pybind11::list FlattenUpTo(pybind11::handle x) const;
132
133 // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
134 pybind11::object Unflatten(pybind11::iterable leaves) const;
135 pybind11::object Unflatten(absl::Span<const pybind11::object> leaves) const;
136
137 // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
138 // `inner`.
139 std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
140
141 // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
142 static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);
143
144 std::vector<std::unique_ptr<PyTreeDef>> Children() const;
145
146 // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
147 // f_node to each container node.
148 // TODO(phawkins): use flattening everywhere instead and delete this method.
149 pybind11::object Walk(const pybind11::function& f_node,
150 pybind11::handle f_leaf,
151 pybind11::iterable leaves) const;
152
153 // Given a tree of iterables with the same node/leaf structure as this PyTree,
154 // build the corresponding PyTree.
155 // TODO(phawkins): use flattening everywhere instead and delete this method.
156 pybind11::object FromIterableTree(pybind11::handle xs) const;
157
num_leaves()158 int num_leaves() const {
159 if (traversal_.empty()) {
160 return 0;
161 }
162 return traversal_.back().num_leaves;
163 }
164
num_nodes()165 int num_nodes() const { return traversal_.size(); }
166
167 size_t Hash() const;
168
169 bool operator==(const PyTreeDef& other) const;
170 bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
171
172 std::string ToString() const;
173
174 private:
175 struct Node {
176 PyTreeKind kind = PyTreeKind::kLeaf;
177
178 // Arity for non-kLeaf types.
179 int arity = 0;
180
181 // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
182 // object. For a kDict, contains a sorted list of keys. For a kCustom type,
183 // contains the auxiliary data returned by the `to_iterable` function.
184 pybind11::object node_data;
185
186 const PyTreeTypeRegistry::Registration* custom = nullptr;
187
188 // Number of leaf nodes in the subtree rooted at this node.
189 int num_leaves = 0;
190
191 // Number of leaf and interior nodes in the subtree rooted at this node.
192 int num_nodes = 0;
193 };
194 template <typename H>
195 friend H AbslHashValue(H h, const Node& n);
196
197 template <typename H>
198 friend H AbslHashValue(H h, const PyTreeDef& t);
199
200 // Helper that manufactures an instance of a node given its children.
201 static pybind11::object MakeNode(const Node& node,
202 absl::Span<pybind11::object> children);
203
204 // Recursive helper used to implement FromIterableTree()
205 pybind11::object FromIterableTreeHelper(
206 pybind11::handle xs,
207 absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it)
208 const;
209
210 // Computes the node kind of a given Python object.
211 static PyTreeKind GetKind(const pybind11::handle& obj,
212 PyTreeTypeRegistry::Registration const** custom);
213
214 template <typename T>
215 void FlattenIntoImpl(
216 pybind11::handle handle, T& leaves,
217 const absl::optional<pybind11::function>& leaf_predicate);
218
219 template <typename T>
220 pybind11::object UnflattenImpl(T leaves) const;
221
222 // Nodes, in a post-order traversal. We use an ordered traversal to minimize
223 // allocations, and post-order corresponds to the order we need to rebuild the
224 // tree structure.
225 absl::InlinedVector<Node, 1> traversal_;
226 };
227
228 template <typename H>
AbslHashValue(H h,const PyTreeDef::Node & n)229 H AbslHashValue(H h, const PyTreeDef::Node& n) {
230 h = H::combine(std::move(h), n.kind, n.arity, n.custom);
231 return h;
232 }
233
234 template <typename H>
AbslHashValue(H h,const PyTreeDef & t)235 H AbslHashValue(H h, const PyTreeDef& t) {
236 return H::combine_contiguous(std::move(h), t.traversal_.data(),
237 t.traversal_.size());
238 }
239
240 void BuildPytreeSubmodule(pybind11::module& m);
241
242 } // namespace xla
243
244 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
245