1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18
#include <algorithm>
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
36
37 namespace xla {
38
39 namespace internal {
40
41 // Internal representation of each node in a ShapeTree.
42 template <typename T>
43 struct ShapeTreeNode {
44 // Data corresponding to this node.
45 std::pair<ShapeIndex, T> data;
46
47 bool is_leaf = true;
48
ShapeTreeNodeShapeTreeNode49 explicit ShapeTreeNode(ShapeIndex index)
50 : ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNodeShapeTreeNode51 ShapeTreeNode(ShapeIndex index, T data)
52 : data(std::move(index), std::move(data)) {}
53 };
54
// Internal representation of an index table entry. A ShapeTree keeps one entry
// per node; the root's entry is the first element of the table.
//
// NOTE: ShapeTree aggregate-initializes entries as Index{index,
// children_start}, so the field order below is significant.
struct IndexTableEntry {
  // Index of the node in the ShapeTreeNode vector.
  uint32 index;
  // Index of the first child in a IndexTableEntry vector. In the index
  // table all children entries for a given node will be placed next to each
  // other. This allows us to use a single field to index them.
  uint32 children_start;
#ifndef NDEBUG
  // Number of children, used for bounds checking. Only present in debug
  // builds, so release-build lookups are not bounds-checked against it.
  uint32 children_count;
#endif
};
68
69 } // namespace internal
70
// Forward declarations of the node and leaf iterators; ShapeTree names them in
// its iterator aliases before their definitions appear.
template <class ContainerType, class IteratorType, class ValueType>
class ShapeTreeIterator;
template <class ContainerType, class IteratorType, class ValueType>
class ShapeTreeLeafIterator;
75
76 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
77 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
78 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
79 // type T.
80 //
81 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
82 // ShapeTree is an identically structured tree with data elements of type T at
83 // every node. I.e. the root is a tuple by definition, all interior nodes are
84 // also tuples, and all leaves are arrays.
85 //
86 // Like the Shape data structure, this is a tree and tuple elements cannot be
87 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
88 // object.
89 //
90 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
91 // it's helpful not to copy a Shape just to make a ShapeTree. In these cases,
92 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's
93 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
94 // before its ShapeTree goes away.
95 template <typename T>
96 class ShapeTree {
97 public:
98 using Node = internal::ShapeTreeNode<T>;
99 using Index = internal::IndexTableEntry;
100
101 // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()102 ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
103
104 // Create ShapeTree with the given shape, and default-constructed T values for
105 // all nodes.
106 //
107 // The version that takes a pointer may be cheaper because it doesn't require
108 // any Shape copies, but then it's up to you to ensure that the pointer stays
109 // alive longer than this ShapeTree.
110 explicit ShapeTree(Shape shape);
111 explicit ShapeTree(const Shape* shape);
112 explicit ShapeTree(const std::shared_ptr<Shape>& shape);
113
114 // Create ShapeTree with the given shape, and init_value for all nodes.
115 ShapeTree(Shape shape, const T& init_value);
116 ShapeTree(const Shape* shape, const T& init_value);
117 ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
118
119 // Returns the data element associated with the array in the shape at the
120 // given index (see ShapeUtil::GetSubshape for how indexes are defined).
121 const T& element(ShapeIndexView index) const;
122 T* mutable_element(ShapeIndexView index);
123
124 // Return the shape represented with this ShapeTree.
shape()125 const Shape& shape() const { return *shape_; }
126
127 // A ShapeTree object can own the underlying Shape pointer (via the
128 // shape_storage_ member), or can point to a Shape object owned by the caller.
129 // This API replaces the underlying Shape object to the one supplied by the
130 // caller, whom must ensure the object remain valid for the whole lifetime of
131 // this ShapeTree object, and also that the Shape is consistent with it.
replace_shape_ptr(const Shape * shape)132 void replace_shape_ptr(const Shape* shape) {
133 if (shape_storage_ != nullptr) {
134 DCHECK_EQ(*shape, *shape_storage_);
135 shape_storage_ = nullptr;
136 }
137 shape_ = shape;
138 }
139
140 // Returns true if the node at the given index is a leaf node (an array
141 // shape).
IsLeaf(ShapeIndexView index)142 bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
143
144 ShapeTree(const ShapeTree&) = default;
145 ShapeTree& operator=(const ShapeTree&) = default;
146 ShapeTree(ShapeTree&&) = default;
147 ShapeTree& operator=(ShapeTree&& other) = default;
148
149 // iterator implements a bidirectional_iterator with
150 // value_type = std::pair<ShapeIndex, T>.
151 //
152 // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
153 using iterator =
154 ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
155 std::pair<ShapeIndex, T>>;
156 using const_iterator =
157 ShapeTreeIterator<const std::vector<Node>,
158 typename std::vector<Node>::const_iterator,
159 const std::pair<ShapeIndex, T>>;
160 using reverse_iterator = std::reverse_iterator<iterator>;
161 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
162
163 using leaf_iterator =
164 ShapeTreeLeafIterator<std::vector<Node>,
165 typename std::vector<Node>::iterator,
166 std::pair<ShapeIndex, T>>;
167 using const_leaf_iterator =
168 ShapeTreeLeafIterator<const std::vector<Node>,
169 typename std::vector<Node>::const_iterator,
170 const std::pair<ShapeIndex, T>>;
171 using reverse_leaf_iterator = std::reverse_iterator<leaf_iterator>;
172 using const_reverse_leaf_iterator =
173 std::reverse_iterator<const_leaf_iterator>;
174
175 // begin/end for iterating over all nodes.
begin()176 iterator begin() { return iterator(&nodes_, nodes_.begin()); }
end()177 iterator end() { return iterator(&nodes_, nodes_.end()); }
begin()178 const_iterator begin() const {
179 return const_iterator(&nodes_, nodes_.begin());
180 }
end()181 const_iterator end() const { return const_iterator(&nodes_, nodes_.end()); }
182
183 // rbegin/rend for iterating over all nodes in reverse.
rbegin()184 reverse_iterator rbegin() { return reverse_iterator(end()); }
rend()185 reverse_iterator rend() { return reverse_iterator(begin()); }
rbegin()186 const_reverse_iterator rbegin() const {
187 return const_reverse_iterator(end());
188 }
rend()189 const_reverse_iterator rend() const {
190 return const_reverse_iterator(begin());
191 }
192
193 // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
194 // children).
leaf_begin()195 leaf_iterator leaf_begin() { return leaf_iterator(&nodes_, nodes_.begin()); }
leaf_end()196 leaf_iterator leaf_end() { return leaf_iterator(&nodes_, nodes_.end()); }
leaf_begin()197 const_leaf_iterator leaf_begin() const {
198 return const_leaf_iterator(&nodes_, nodes_.begin());
199 }
leaf_end()200 const_leaf_iterator leaf_end() const {
201 return const_leaf_iterator(&nodes_, nodes_.end());
202 }
203 // range-based iterator for leaf_begin()/leaf_end().
leaves()204 tensorflow::gtl::iterator_range<leaf_iterator> leaves() {
205 return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
206 }
leaves()207 tensorflow::gtl::iterator_range<const_leaf_iterator> leaves() const {
208 return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
209 }
210
leaf_rbegin()211 reverse_leaf_iterator leaf_rbegin() {
212 return reverse_leaf_iterator(leaf_end());
213 }
leaf_rend()214 reverse_leaf_iterator leaf_rend() {
215 return reverse_leaf_iterator(leaf_begin());
216 }
leaf_rbegin()217 const_reverse_leaf_iterator leaf_rbegin() const {
218 return const_reverse_leaf_iterator(leaf_end());
219 }
leaf_rend()220 const_reverse_leaf_iterator leaf_rend() const {
221 return const_reverse_leaf_iterator(leaf_begin());
222 }
223
224 // Returns an iterator pointing to the given ShapeIndex.
225 // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)226 iterator find(ShapeIndexView index) {
227 Node* element = Lookup(index);
228 auto element_iter = nodes_.begin() + (element - &nodes_[0]);
229 return iterator(&nodes_, element_iter);
230 }
find(ShapeIndexView index)231 const_iterator find(ShapeIndexView index) const {
232 const Node* element = Lookup(index);
233 auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
234 return const_iterator(&nodes_, element_iter);
235 }
236
237 // Returns the number of leaf nodes in the tree.
leaf_count()238 int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
239
240 // Recursively traverses the shape and calls the given function at each
241 // element. The function has the following arguments:
242 //
243 // Fn : A callable of type void(const ShapeIndex& index, const T& data)
244 // (or compatible).
245 // index : the index of the element in the shape. See ShapeUtil::GetSubshape
246 // for definition of index.
247 // data : The data value at this element.
248 template <typename Fn>
249 void ForEachElement(const Fn& func) const;
250
251 // Like ForEachElement, but the callable has type
252 //
253 // void (const ShapeIndex& index, T* data).
254 //
255 template <typename Fn>
256 void ForEachMutableElement(const Fn& func);
257
258 // Like ForEach(Mutable)Element, but the callable returns a Status instead of
259 // void. The first non-OK return value is returned by the ForEach* function.
260 template <typename Fn>
261 Status ForEachElementWithStatus(const Fn& func) const;
262 template <typename Fn>
263 Status ForEachMutableElementWithStatus(const Fn& func);
264
265 // Maps each element to generate a new tree with the same shape.
266 template <typename U>
Map(const std::function<U (const T &)> & func)267 ShapeTree<U> Map(const std::function<U(const T&)>& func) {
268 ShapeTree<U> result(shape_storage_);
269 ForEachElement([&](const ShapeIndex& index, const T& t) {
270 *result.mutable_element(index) = func(t);
271 });
272 return result;
273 }
274
275 template <typename U>
Map(const std::function<U (T *)> & func)276 ShapeTree<U> Map(const std::function<U(T*)>& func) {
277 ShapeTree<U> result(shape_storage_);
278 ForEachMutableElement([&](const ShapeIndex& index, T* t) {
279 *result.mutable_element(index) = func(t);
280 });
281 return result;
282 }
283
284 // Copy the subtree of values from 'other' rooted at ShapeIndex
285 // 'source_base_index' into the subtree of value in this ShapeTree rooted at
286 // 'target_base_index'.
287 //
288 // Precondition: The subshape of other.shape() at index source_base_index must
289 // be compatible with the subshape of shape() at index target_base_index.
290 void CopySubtreeFrom(const ShapeTree<T>& other,
291 const ShapeIndex& source_base_index,
292 const ShapeIndex& target_base_index);
293
294 StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const;
295
296 bool operator==(const ShapeTree<T>& other) const;
297 bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
298
299 private:
300 // Initialize node->children based on 'shape'. All children are assigned the
301 // the given 'init_value'.
302 void InitChildren(const Shape& shape, const T& init_value, Node* node,
303 Index* index);
304
305 // Initialize node->children based on 'shape'. All children have
306 // default-constructed data values.
307 void InitChildren(const Shape& shape, Node* node, Index* index);
308
309 // Returns the number of subshapes, including interior nodes, in shape.
310 int64 CountSubshapes(const Shape& shape);
311
312 // Helpers for traversing the shape via ForEachElement. The helpers
313 // recursively traverse the subtree rooted at "index" (defined as in
314 // ShapeUtil::GetSubshape).
315 template <typename Fn>
316 static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
317 template <typename Fn>
318 static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
319
320 // Return the tree node at the given index.
321 Node* Lookup(ShapeIndexView index);
322 const Node* Lookup(ShapeIndexView index) const;
323
324 // The nodes in this shape tree.
325 std::vector<Node> nodes_;
326
327 // Index table for node lookups.
328 std::vector<Index> index_table_;
329
330 // If we own our Shape, this field contains it, and shape_ is a pointer into
331 // here. Otherwise if we don't own our shape, this is nullptr.
332 std::shared_ptr<Shape> shape_storage_;
333
334 // The XLA shape mirrored in this ShapeTree. This is either
335 // shape_storage_.get() or the Shape pointer passed to our constructor.
336 const Shape* shape_;
337 };
338
// Internal iterator that performs a pre-order walk. This is cheap to copy.
// Dereferencing yields a ValueType& (a std::pair<ShapeIndex, T>&, or a const
// pair for the const variant), similar to std::map.
//
// The iterator traits are declared as member types rather than by inheriting
// from std::iterator, which is deprecated since C++17. They are exactly the
// types std::iterator<std::bidirectional_iterator_tag, ValueType> provides, so
// std::iterator_traits and std::reverse_iterator see no difference.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator {
 public:
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type = std::ptrdiff_t;
  using pointer = ValueType*;
  using reference = ValueType&;

  // `nodes` is the node container being walked; `node` is the position within
  // it.
  ShapeTreeIterator(ContainerType* nodes, IteratorType node)
      : nodes_(nodes), node_(std::move(node)) {}

  ShapeTreeIterator& operator++() {
    ++node_;
    return *this;
  }
  ShapeTreeIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  ShapeTreeIterator& operator--() {
    --node_;
    return *this;
  }
  ShapeTreeIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  bool operator==(const ShapeTreeIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeIterator& other) const {
    return node_ != other.node_;
  }
  ValueType& operator*() const { return node_->data; }
  ValueType* operator->() const { return &node_->data; }

 private:
  ContainerType* nodes_;
  IteratorType node_;
};
382
// Internal iterator that performs a pre-order walk of the leaves. This is cheap
// to copy. Dereferencing yields a ValueType& (a std::pair<ShapeIndex, T>&, or a
// const pair for the const variant), similar to std::map.
//
// The iterator traits are declared as member types rather than by inheriting
// from std::iterator, which is deprecated since C++17. They are exactly the
// types std::iterator<std::bidirectional_iterator_tag, ValueType> provides.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeLeafIterator {
 public:
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type = std::ptrdiff_t;
  using pointer = ValueType*;
  using reference = ValueType&;

  // Positions the iterator on the first leaf at or after `node`, or at
  // nodes->end() if there is none.
  ShapeTreeLeafIterator(ContainerType* nodes, IteratorType node)
      : nodes_(nodes), node_(std::move(node)) {
    SkipNonLeavesForward();
  }

  ShapeTreeLeafIterator& operator++() {
    ++node_;
    SkipNonLeavesForward();
    return *this;
  }
  ShapeTreeLeafIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  // Moves to the previous leaf. Scanning backwards stops at the first node of
  // the container even if that node is not a leaf.
  ShapeTreeLeafIterator& operator--() {
    --node_;
    while (node_ > nodes_->begin() && !node_->is_leaf) {
      --node_;
    }
    return *this;
  }
  ShapeTreeLeafIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  bool operator==(const ShapeTreeLeafIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeLeafIterator& other) const {
    return node_ != other.node_;
  }
  ValueType& operator*() const { return node_->data; }
  ValueType* operator->() const { return &node_->data; }

 private:
  // Advances node_ past non-leaf nodes, stopping at the end of the container.
  void SkipNonLeavesForward() {
    while (node_ != nodes_->end() && !node_->is_leaf) {
      ++node_;
    }
  }

  ContainerType* nodes_;
  IteratorType node_;
};
436
437 template <typename T>
CountSubshapes(const Shape & shape)438 int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
439 int64 current_count = 1;
440 if (shape.IsTuple()) {
441 int64 count = ShapeUtil::TupleElementCount(shape);
442 for (int i = 0; i < count; ++i) {
443 current_count += CountSubshapes(shape.tuple_shapes(i));
444 }
445 }
446 return current_count;
447 }
448
449 template <typename T>
InitChildren(const Shape & shape,const T & init_value,Node * node,Index * index)450 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
451 Node* node, Index* index) {
452 if (shape.IsTuple()) {
453 const int64 size = ShapeUtil::TupleElementCount(shape);
454 #ifndef NDEBUG
455 index->children_count = size;
456 #endif
457 node->is_leaf = false;
458 ShapeIndex shape_index = node->data.first;
459 shape_index.push_back(0);
460
461 // At the end of the index_table, reserve a continuous space to hold the
462 // children of current node. In order to enforce the invariant that all
463 // children of a given node are placed together, we need to do the
464 // reservation before we recurse into any of its children.
465 int64 children_start_position = index_table_.size();
466 index_table_.resize(index_table_.size() + size);
467
468 for (int i = 0; i < size; ++i) {
469 shape_index[shape_index.size() - 1] = i;
470 index_table_[children_start_position + i].index = nodes_.size();
471 // The first child of the node in the index table is placed at the end of
472 // the table.
473 index_table_[children_start_position + i].children_start =
474 index_table_.size();
475 nodes_.emplace_back(shape_index, init_value);
476 InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
477 &index_table_[children_start_position + i]);
478 }
479 } else {
480 #ifndef NDEBUG
481 index->children_count = 0;
482 #endif
483 }
484 }
485
486 template <typename T>
InitChildren(const Shape & shape,Node * node,Index * index)487 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
488 if (shape.IsTuple()) {
489 const int64 size = ShapeUtil::TupleElementCount(shape);
490 #ifndef NDEBUG
491 index->children_count = size;
492 #endif
493 node->is_leaf = false;
494 ShapeIndex shape_index = node->data.first;
495 shape_index.push_back(0);
496
497 // At the end of the index_table, reserve a continuous space to hold the
498 // children of current node. In order to enforce the invariant that all
499 // children of a given node are placed together, we need to do the
500 // reservation before we recurse into any of its children.
501 int64 children_start_position = index_table_.size();
502 index_table_.resize(index_table_.size() + size);
503
504 for (int i = 0; i < size; ++i) {
505 shape_index[shape_index.size() - 1] = i;
506 index_table_[children_start_position + i].index = nodes_.size();
507 // The first child of the node in the index table is placed at the end of
508 // the table.
509 index_table_[children_start_position + i].children_start =
510 index_table_.size();
511 nodes_.emplace_back(shape_index);
512 InitChildren(shape.tuple_shapes(i), &nodes_.back(),
513 &index_table_[children_start_position + i]);
514 }
515 } else {
516 #ifndef NDEBUG
517 index->children_count = 0;
518 #endif
519 }
520 }
521
522 template <typename T>
ShapeTree(Shape shape)523 ShapeTree<T>::ShapeTree(Shape shape)
524 : shape_storage_(std::make_shared<Shape>(std::move(shape))),
525 shape_(shape_storage_.get()) {
526 const int64 count = CountSubshapes(*shape_);
527 nodes_.reserve(count);
528 nodes_.emplace_back(ShapeIndex{});
529
530 index_table_.reserve(count);
531 index_table_.emplace_back(Index{0, 1});
532 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
533 }
534
535 template <typename T>
ShapeTree(const Shape * shape)536 ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
537 const int64 count = CountSubshapes(*shape_);
538 nodes_.reserve(count);
539 nodes_.emplace_back(ShapeIndex{});
540
541 index_table_.reserve(count);
542 index_table_.emplace_back(Index{0, 1});
543 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
544 }
545
546 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape)547 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
548 : shape_storage_(shape), shape_(shape_storage_.get()) {
549 const int64 count = CountSubshapes(*shape_);
550 nodes_.reserve(count);
551 nodes_.emplace_back(ShapeIndex{});
552
553 index_table_.reserve(count);
554 index_table_.emplace_back(Index{0, 1});
555 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
556 }
557
558 template <typename T>
ShapeTree(Shape shape,const T & init_value)559 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
560 : shape_storage_(std::make_shared<Shape>(std::move(shape))),
561 shape_(shape_storage_.get()) {
562 const int64 count = CountSubshapes(*shape_);
563 nodes_.reserve(count);
564 nodes_.emplace_back(ShapeIndex{}, init_value);
565
566 index_table_.reserve(count);
567 index_table_.emplace_back(Index{0, 1});
568 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
569 }
570
571 template <typename T>
ShapeTree(const Shape * shape,const T & init_value)572 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
573 : shape_(shape) {
574 const int64 count = CountSubshapes(*shape_);
575 nodes_.reserve(count);
576 nodes_.emplace_back(ShapeIndex{}, init_value);
577
578 index_table_.reserve(count);
579 index_table_.emplace_back(Index{0, 1});
580 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
581 }
582
583 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape,const T & init_value)584 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
585 const T& init_value)
586 : shape_storage_(shape), shape_(shape_storage_.get()) {
587 const int64 count = CountSubshapes(*shape_);
588 nodes_.reserve(count);
589 nodes_.emplace_back(ShapeIndex{}, init_value);
590
591 index_table_.reserve(count);
592 index_table_.emplace_back(Index{0, 1});
593 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
594 }
595
596 template <typename T>
element(ShapeIndexView index)597 const T& ShapeTree<T>::element(ShapeIndexView index) const {
598 return Lookup(index)->data.second;
599 }
600
601 template <typename T>
mutable_element(ShapeIndexView index)602 T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
603 return &Lookup(index)->data.second;
604 }
605
606 template <typename T>
Lookup(ShapeIndexView index)607 internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
608 Index* iter = &index_table_[0];
609 for (const int64 i : index) {
610 CHECK_GE(i, 0);
611 #ifndef NDEBUG
612 CHECK_LT(i, iter->children_count);
613 #endif
614 iter = &index_table_[iter->children_start + i];
615 }
616
617 return &nodes_[iter->index];
618 }
619
620 template <typename T>
Lookup(ShapeIndexView index)621 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
622 ShapeIndexView index) const {
623 return const_cast<ShapeTree*>(this)->Lookup(index);
624 }
625
626 /* static */
627 template <typename T>
628 template <typename Fn>
ForEachHelper(const Fn & func,const std::vector<Node> & nodes)629 Status ShapeTree<T>::ForEachHelper(const Fn& func,
630 const std::vector<Node>& nodes) {
631 for (const auto& node : nodes) {
632 TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
633 }
634 return Status::OK();
635 }
636
637 /* static */
638 template <typename T>
639 template <typename Fn>
ForEachMutableHelper(const Fn & func,std::vector<Node> * nodes)640 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
641 std::vector<Node>* nodes) {
642 for (auto& node : *nodes) {
643 TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
644 }
645 return Status::OK();
646 }
647
648 template <typename T>
649 template <typename Fn>
ForEachElementWithStatus(const Fn & func)650 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
651 return ForEachHelper(func, nodes_);
652 }
653
654 template <typename T>
655 template <typename Fn>
ForEachMutableElementWithStatus(const Fn & func)656 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
657 return ForEachMutableHelper(func, &nodes_);
658 }
659
660 template <typename T>
661 template <typename Fn>
ForEachElement(const Fn & func)662 void ShapeTree<T>::ForEachElement(const Fn& func) const {
663 return ForEachHelper(
664 [&func](const ShapeIndex& index, const T& data) {
665 func(index, data);
666 return Status::OK();
667 },
668 nodes_)
669 .IgnoreError();
670 }
671
672 template <typename T>
673 template <typename Fn>
ForEachMutableElement(const Fn & func)674 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
675 return ForEachMutableHelper(
676 [&func](const ShapeIndex& index, T* data) {
677 func(index, data);
678 return Status::OK();
679 },
680 &nodes_)
681 .IgnoreError();
682 }
683
684 template <typename T>
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & source_base_index,const ShapeIndex & target_base_index)685 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
686 const ShapeIndex& source_base_index,
687 const ShapeIndex& target_base_index) {
688 CHECK(ShapeUtil::Compatible(
689 ShapeUtil::GetSubshape(shape(), target_base_index),
690 ShapeUtil::GetSubshape(other.shape(), source_base_index)))
691 << ShapeUtil::GetSubshape(shape(), target_base_index) << " vs "
692 << ShapeUtil::GetSubshape(other.shape(), source_base_index);
693 ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
694 const ShapeIndex& index, T* data) {
695 // Copy the data element only if index is in the
696 // subtree rooted at target_base_index.
697 for (int i = 0; i < target_base_index.size(); ++i) {
698 if (i >= index.size() || index[i] != target_base_index[i]) {
699 return;
700 }
701 }
702 // Construct source element index to copy from.
703 ShapeIndex source_index = source_base_index;
704 for (int i = target_base_index.size(); i < index.size(); ++i) {
705 source_index.push_back(index[i]);
706 }
707 *data = other.element(source_index);
708 });
709 }
710
711 template <typename T>
SubShapeTree(const ShapeIndex & index)712 StatusOr<ShapeTree<T>> ShapeTree<T>::SubShapeTree(
713 const ShapeIndex& index) const {
714 TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
715 ShapeUtil::TryGetSubshape(shape(), index));
716 ShapeTree<T> sub_shape_tree(*sub_shape);
717 sub_shape_tree.CopySubtreeFrom(*this, index, {});
718 return std::move(sub_shape_tree);
719 }
720
721 template <typename T>
722 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
723 bool equal = true;
724 ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
725 if (data != other.element(index)) {
726 equal = false;
727 }
728 });
729 return equal;
730 }
731
732 } // namespace xla
733
734 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
735