• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
18 
// See https://jax.readthedocs.io/en/latest/pytrees.html for documentation
// about pytrees.
21 
// Caution: this code uses exceptions. Their use is confined to the binding
// code, where throwing is the idiomatic way to raise Python exceptions.
24 
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/hash/hash.h"
#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl.h"
37 
38 namespace xla {
39 
// The kinds of node that can appear in a PyTree. The numeric values are
// written out explicitly; they are exactly the implicit 0..6 numbering.
enum class PyTreeKind {
  kLeaf = 0,        // An opaque leaf value.
  kNone = 1,        // Python None.
  kTuple = 2,       // A tuple.
  kNamedTuple = 3,  // A collections.namedtuple instance.
  kList = 4,        // A list.
  kDict = 5,        // A dict.
  kCustom = 6,      // A user-registered container type.
};
49 
50 // Registry of custom node types.
51 class PyTreeTypeRegistry {
52  public:
53   struct Registration {
54     PyTreeKind kind;
55 
56     // The following values are populated for custom types.
57     // The Python type object, used to identify the type.
58     pybind11::object type;
59     // A function with signature: object -> (iterable, aux_data)
60     pybind11::function to_iterable;
61     // A function with signature: (aux_data, iterable) -> object
62     pybind11::function from_iterable;
63   };
64 
65   // Registers a new custom type. Objects of `type` will be treated as container
66   // node types in PyTrees.
67   static void Register(pybind11::object type, pybind11::function to_iterable,
68                        pybind11::function from_iterable);
69 
70   // Finds the custom type registration for `type`. Returns nullptr if none
71   // exists.
72   static const Registration* Lookup(pybind11::handle type);
73 
74  private:
75   static PyTreeTypeRegistry* Singleton();
76 
77   struct TypeHash {
78     using is_transparent = void;
operatorTypeHash79     size_t operator()(const pybind11::object& t) const {
80       return absl::Hash<void*>()(t.ptr());
81     }
operatorTypeHash82     size_t operator()(const pybind11::handle& t) const {
83       return absl::Hash<void*>()(t.ptr());
84     }
85   };
86   struct TypeEq {
87     using is_transparent = void;
operatorTypeEq88     bool operator()(const pybind11::object& a,
89                     const pybind11::object& b) const {
90       return a.ptr() == b.ptr();
91     }
operatorTypeEq92     bool operator()(const pybind11::object& a,
93                     const pybind11::handle& b) const {
94       return a.ptr() == b.ptr();
95     }
96   };
97   absl::flat_hash_map<pybind11::object, std::unique_ptr<Registration>, TypeHash,
98                       TypeEq>
99       registrations_;
100 };
101 
102 // A PyTreeDef describes the tree structure of a PyTree. A PyTree is a tree of
103 // Python values, where the interior nodes are tuples, lists, dictionaries, or
104 // user-defined containers, and the leaves are other objects.
105 class PyTreeDef {
106  public:
107   PyTreeDef() = default;
108 
109   // Flattens a Pytree into a list of leaves and a PyTreeDef.
110   // Returns references to the flattened objects, which might be temporary
111   // objects in the case of custom pytype handlers.
112   static std::pair<std::vector<pybind11::object>, std::unique_ptr<PyTreeDef>>
113   Flatten(pybind11::handle x,
114           absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
115 
116   // Recursive helper used to implement Flatten().
117   void FlattenInto(
118       pybind11::handle handle, std::vector<pybind11::object>& leaves,
119       absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
120   void FlattenInto(
121       pybind11::handle handle, absl::InlinedVector<pybind11::object, 2>& leaves,
122       absl::optional<pybind11::function> leaf_predicate = absl::nullopt);
123 
124   // Tests whether the given list is a flat list of leaves.
125   static bool AllLeaves(const pybind11::iterable& x);
126 
127   // Flattens a Pytree up to this PyTreeDef. 'this' must be a tree prefix of
128   // the tree-structure of 'x'. For example, if we flatten a value
129   // [(1, (2, 3)), {"foo": 4}] with a treedef [(*, *), *], the result is the
130   // list of leaves [1, (2, 3), {"foo": 4}].
131   pybind11::list FlattenUpTo(pybind11::handle x) const;
132 
133   // Returns an unflattened PyTree given an iterable of leaves and a PyTreeDef.
134   pybind11::object Unflatten(pybind11::iterable leaves) const;
135   pybind11::object Unflatten(absl::Span<const pybind11::object> leaves) const;
136 
137   // Composes two PyTreeDefs, replacing the leaves of this tree with copies of
138   // `inner`.
139   std::unique_ptr<PyTreeDef> Compose(const PyTreeDef& inner) const;
140 
141   // Makes a Tuple PyTreeDef out of a vector of PyTreeDefs.
142   static std::unique_ptr<PyTreeDef> Tuple(const std::vector<PyTreeDef>& defs);
143 
144   std::vector<std::unique_ptr<PyTreeDef>> Children() const;
145 
146   // Maps a function over a PyTree structure, applying f_leaf to each leaf, and
147   // f_node to each container node.
148   // TODO(phawkins): use flattening everywhere instead and delete this method.
149   pybind11::object Walk(const pybind11::function& f_node,
150                         pybind11::handle f_leaf,
151                         pybind11::iterable leaves) const;
152 
153   // Given a tree of iterables with the same node/leaf structure as this PyTree,
154   // build the corresponding PyTree.
155   // TODO(phawkins): use flattening everywhere instead and delete this method.
156   pybind11::object FromIterableTree(pybind11::handle xs) const;
157 
num_leaves()158   int num_leaves() const {
159     if (traversal_.empty()) {
160       return 0;
161     }
162     return traversal_.back().num_leaves;
163   }
164 
num_nodes()165   int num_nodes() const { return traversal_.size(); }
166 
167   size_t Hash() const;
168 
169   bool operator==(const PyTreeDef& other) const;
170   bool operator!=(const PyTreeDef& other) const { return !(*this == other); }
171 
172   std::string ToString() const;
173 
174  private:
175   struct Node {
176     PyTreeKind kind = PyTreeKind::kLeaf;
177 
178     // Arity for non-kLeaf types.
179     int arity = 0;
180 
181     // Kind-specific auxiliary data. For a kNamedTuple, contains the tuple type
182     // object. For a kDict, contains a sorted list of keys. For a kCustom type,
183     // contains the auxiliary data returned by the `to_iterable` function.
184     pybind11::object node_data;
185 
186     const PyTreeTypeRegistry::Registration* custom = nullptr;
187 
188     // Number of leaf nodes in the subtree rooted at this node.
189     int num_leaves = 0;
190 
191     // Number of leaf and interior nodes in the subtree rooted at this node.
192     int num_nodes = 0;
193   };
194   template <typename H>
195   friend H AbslHashValue(H h, const Node& n);
196 
197   template <typename H>
198   friend H AbslHashValue(H h, const PyTreeDef& t);
199 
200   // Helper that manufactures an instance of a node given its children.
201   static pybind11::object MakeNode(const Node& node,
202                                    absl::Span<pybind11::object> children);
203 
204   // Recursive helper used to implement FromIterableTree()
205   pybind11::object FromIterableTreeHelper(
206       pybind11::handle xs,
207       absl::InlinedVector<PyTreeDef::Node, 1>::const_reverse_iterator* it)
208       const;
209 
210   // Computes the node kind of a given Python object.
211   static PyTreeKind GetKind(const pybind11::handle& obj,
212                             PyTreeTypeRegistry::Registration const** custom);
213 
214   template <typename T>
215   void FlattenIntoImpl(
216       pybind11::handle handle, T& leaves,
217       const absl::optional<pybind11::function>& leaf_predicate);
218 
219   template <typename T>
220   pybind11::object UnflattenImpl(T leaves) const;
221 
222   // Nodes, in a post-order traversal. We use an ordered traversal to minimize
223   // allocations, and post-order corresponds to the order we need to rebuild the
224   // tree structure.
225   absl::InlinedVector<Node, 1> traversal_;
226 };
227 
228 template <typename H>
AbslHashValue(H h,const PyTreeDef::Node & n)229 H AbslHashValue(H h, const PyTreeDef::Node& n) {
230   h = H::combine(std::move(h), n.kind, n.arity, n.custom);
231   return h;
232 }
233 
234 template <typename H>
AbslHashValue(H h,const PyTreeDef & t)235 H AbslHashValue(H h, const PyTreeDef& t) {
236   return H::combine_contiguous(std::move(h), t.traversal_.data(),
237                                t.traversal_.size());
238 }
239 
240 void BuildPytreeSubmodule(pybind11::module& m);
241 
242 }  // namespace xla
243 
244 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PYTREE_H_
245