1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
18
// See https://jax.readthedocs.io/en/latest/pytrees.html for documentation
// about pytrees.
21
// Caution: this code uses exceptions. Their use is confined to the binding
// code, where throwing is the idiomatic way to raise Python exceptions.
24
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
36
37 namespace xla {
38
39 // Registry of custom node types.
40 class CustomNodeRegistry {
41 public:
42 struct Registration {
43 // The Python type object, used to identify the type.
44 pybind11::object type;
45 // A function with signature: object -> (iterable, aux_data)
46 pybind11::function to_iterable;
47 // A function with signature: (aux_data, iterable) -> object
48 pybind11::function from_iterable;
49 };
50
51 // Registers a new custom type. Objects of `type` will be treated as container
52 // node types in PyTrees.
53 static void Register(pybind11::object type, pybind11::function to_iterable,
54 pybind11::function from_iterable);
55
56 // Finds the custom type registration for `type`. Returns nullptr if none
57 // exists.
58 static const Registration* Lookup(pybind11::handle type);
59
60 private:
61 static CustomNodeRegistry* Singleton();
62
63 struct TypeHash {
operatorTypeHash64 size_t operator()(const pybind11::object& t) const {
65 return pybind11::hash(t);
66 }
67 };
68 struct TypeEq {
operatorTypeEq69 bool operator()(const pybind11::object& a,
70 const pybind11::object& b) const {
71 return a.equal(b);
72 }
73 };
74 absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
75 TypeEq>
76 registrations_;
77 };
78
79 // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
80 // Python values, where the interior nodes are tuples, lists, dictionaries, or
81 // user-defined containers, and the leaves are other objects.
82 class PyTreeDef {
83 public:
84 PyTreeDef() = default;
85
86 // Flattens a Pytree into a list of leaves and a PyTreeDef.
87 static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
88 Flatten(pybind11::handle x,
89 absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
90
91 // Recursive helper used to implement Flatten().
92 void FlattenInto(
93 pybind11::handle handle, std::vector<pybind11::object>& leaves,
94 absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
95
96 // Tests whether the given list is a flat list of leaves.
97 static bool AllLeaves(const pybind11::iterable& x);
98
99 // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
100 // the tree-structure of 'x'. For example, if we flatten a value
101 // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
102 // list of leaves [1, (2, 3), {"foo": 4}].
103 pybind11::list FlattenUpTo(pybind11::handle x) const;
104
105 // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
106 pybind11::object Unflatten(pybind11::iterable leaves) const;
107
108 // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
109 // `inner`.
110 std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
111
112 // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
113 static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);
114
115 std::vector<std::unique_ptr<PyTreeDef>> Children() const;
116
117 // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
118 // f_node to each container node.
119 // TODO(phawkins): use flattening everywhere instead and delete this method.
120 pybind11::object Walk(const pybind11::function& f_node,
121 pybind11::handle f_leaf,
122 pybind11::iterable leaves) const;
123
124 // Given a tree of iterables with the same node/leaf structure as this PyTree,
125 // build the corresponding PyTree.
126 // TODO(phawkins): use flattening everywhere instead and delete this method.
127 pybind11::object FromIterableTree(pybind11::handle xs) const;
128
num_leaves()129 int num_leaves() const {
130 if (traversal_.empty()) {
131 return 0;
132 }
133 return traversal_.back().num_leaves;
134 }
135
num_nodes()136 int num_nodes() const { return traversal_.size(); }
137
138 size_t Hash() const;
139
140 bool operator==(const PyTreeDef& other) const;
141 bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
142
143 std::string ToString() const;
144
145 private:
146 enum class Kind {
147 kLeaf, // An opaque leaf node
148 kNone, // None.
149 kTuple, // A tuple
150 kNamedTuple, // A collections.namedtuple
151 kList, // A list
152 kDict, // A dict
153 kCustom, // A custom type.
154 };
155
156 struct Node {
157 Kind kind = Kind::kLeaf;
158
159 // Arity for non-kLeaf types.
160 int arity = 0;
161
162 // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
163 // object. For a kDict, contains a sorted list of keys. For a kCustom type,
164 // contains the auxiliary data returned by the `to_iterable` function.
165 pybind11::object node_data;
166
167 const CustomNodeRegistry::Registration* custom = nullptr;
168
169 // Number of leaf nodes in the subtree rooted at this node.
170 int num_leaves = 0;
171
172 // Number of leaf and interior nodes in the subtree rooted at this node.
173 int num_nodes = 0;
174 };
175 template <typename H>
176 friend H AbslHashValue(H h, const Node& n);
177
178 template <typename H>
179 friend H AbslHashValue(H h, const PyTreeDef& t);
180
181 // Helper that manufactures an instance of a node given its children.
182 static pybind11::object MakeNode(const Node& node,
183 absl::Span<pybind11::object> children);
184
185 // Recursive helper used to implement FromIterableTree()
186 pybind11::object FromIterableTreeHelper(
187 pybind11::handle xs,
188 std::vector<PyTreeDef::Node>::const_reverse_iterator* it) const;
189
190 // Computes the node kind of a given Python object.
191 static Kind GetKind(const pybind11::handle& obj,
192 CustomNodeRegistry::Registration const** custom);
193
194 // Nodes, in a post-order traversal. We use an ordered traversal to minimize
195 // allocations, and post-order corresponds to the order we need to rebuild the
196 // tree structure.
197 std::vector<Node> traversal_;
198 };
199
200 template <typename H>
AbslHashValue(H h,const PyTreeDef::Node & n)201 H AbslHashValue(H h, const PyTreeDef::Node& n) {
202 h = H::combine(std::move(h), n.kind, n.arity, n.custom);
203 return h;
204 }
205
206 template <typename H>
AbslHashValue(H h,const PyTreeDef & t)207 H AbslHashValue(H h, const PyTreeDef& t) {
208 return H::combine_contiguous(std::move(h), t.traversal_.data(),
209 t.traversal_.size());
210 }
211
212 void BuildPytreeSubmodule(pybind11::module& m);
213
214 } // namespace xla
215
216 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
217