1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18
19 #include <functional>
20 #include <iterator>
21 #include <memory>
22 #include <vector>
23
24 #include "absl/memory/memory.h"
25 #include "absl/types/optional.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/layout_util.h"
28 #include "tensorflow/compiler/xla/shape_util.h"
29 #include "tensorflow/compiler/xla/status_macros.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/lib/core/status.h"
33 #include "tensorflow/core/lib/gtl/iterator_range.h"
34 #include "tensorflow/core/platform/logging.h"
35 #include "tensorflow/core/platform/types.h"
36
37 namespace xla {
38
39 namespace internal {
40
41 // Internal representation of each node in a ShapeTree.
42 template <typename T>
43 struct ShapeTreeNode {
44 // Data corresponding to this node.
45 std::pair<ShapeIndex, T> data;
46
47 bool is_leaf = true;
48
ShapeTreeNodeShapeTreeNode49 explicit ShapeTreeNode(ShapeIndex index)
50 : ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNodeShapeTreeNode51 ShapeTreeNode(ShapeIndex index, T data)
52 : data(std::move(index), std::move(data)) {}
53 };
54
55 // Internal representation of an index table entry.
// Internal representation of an index table entry.
//
// ShapeTree::Lookup starts at entry 0 (the root) and, for each component i of
// a ShapeIndex, moves to entry `children_start + i`; it then uses `index` to
// find the node in the node vector. Member order matters: the ShapeTree
// constructors create the root entry with aggregate initialization
// `Index{0, 1}`.
struct IndexTableEntry {
  // Index of the node in the ShapeTreeNode vector.
  uint32 index;
  // Index of the first child in an IndexTableEntry vector. In the index
  // table all children entries for a given node will be placed next to each
  // other. This allows us to use a single field to index them.
  uint32 children_start;
#ifndef NDEBUG
  // Number of children, used for bounds checking (debug builds only).
  uint32 children_count;
#endif
};
68
69 } // namespace internal
70
// Forward declaration: ShapeTree names this template in its iterator aliases,
// and the definition follows the ShapeTree class.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator;
73
74 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
75 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
76 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
77 // type T.
78 //
79 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
80 // ShapeTree is an identically structured tree with data elements of type T at
81 // every node. I.e. the root is a tuple by definition, all interior nodes are
82 // also tuples, and all leaves are arrays.
83 //
84 // Like the Shape data structure, this is a tree and tuple elements cannot be
85 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
86 // object.
87 //
88 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
89 // it's helpful not to copy a Shape just to make a ShapeTree. In these cases,
90 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's
91 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
92 // before its ShapeTree goes away.
93 template <typename T>
94 class ShapeTree {
95 public:
96 using Node = internal::ShapeTreeNode<T>;
97 using Index = internal::IndexTableEntry;
98
99 // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()100 ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
101
102 // Create ShapeTree with the given shape, and default-constructed T values for
103 // all nodes.
104 //
105 // The version that takes a pointer may be cheaper because it doesn't require
106 // any Shape copies, but then it's up to you to ensure that the pointer stays
107 // alive longer than this ShapeTree.
108 explicit ShapeTree(Shape shape);
109 explicit ShapeTree(const Shape* shape);
110 explicit ShapeTree(const std::shared_ptr<Shape>& shape);
111
112 // Create ShapeTree with the given shape, and init_value for all nodes.
113 ShapeTree(Shape shape, const T& init_value);
114 ShapeTree(const Shape* shape, const T& init_value);
115 ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
116
117 // Returns the data element associated with the array in the shape at the
118 // given index (see ShapeUtil::GetSubshape for how indexes are defined).
119 const T& element(ShapeIndexView index) const;
120 T* mutable_element(ShapeIndexView index);
121
122 // Return the shape represented with this ShapeTree.
shape()123 const Shape& shape() const { return *shape_; }
124
125 // A ShapeTree object can own the underlying Shape pointer (via the
126 // shape_storage_ member), or can point to a Shape object owned by the caller.
127 // This API replaces the underlying Shape object to the one supplied by the
128 // caller, whom must ensure the object remain valid for the whole lifetime of
129 // this ShapeTree object, and also that the Shape is consistent with it.
replace_shape_ptr(const Shape * shape)130 void replace_shape_ptr(const Shape* shape) {
131 if (shape_storage_ != nullptr) {
132 DCHECK_EQ(*shape, *shape_storage_);
133 shape_storage_ = nullptr;
134 }
135 shape_ = shape;
136 }
137
138 // Returns true if the node at the given index is a leaf node (an array
139 // shape).
IsLeaf(ShapeIndexView index)140 bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
141
142 ShapeTree(const ShapeTree&) = default;
143 ShapeTree& operator=(const ShapeTree&) = default;
144 ShapeTree(ShapeTree&&) = default;
145 ShapeTree& operator=(ShapeTree&& other) = default;
146
147 // iterator implements a bidirectional_iterator with
148 // value_type = std::pair<ShapeIndex, T>.
149 //
150 // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
151 using iterator =
152 ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
153 std::pair<ShapeIndex, T>>;
154 using const_iterator =
155 ShapeTreeIterator<const std::vector<Node>,
156 typename std::vector<Node>::const_iterator,
157 const std::pair<ShapeIndex, T>>;
158 using reverse_iterator = std::reverse_iterator<iterator>;
159 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
160
161 // begin/end for iterating over all nodes.
begin()162 iterator begin() {
163 return iterator(&nodes_, nodes_.begin(),
164 /*iterate_leaves_only=*/false);
165 }
end()166 iterator end() {
167 return iterator(&nodes_, nodes_.end(),
168 /*iterate_leaves_only=*/false);
169 }
begin()170 const_iterator begin() const {
171 return const_iterator(&nodes_, nodes_.begin(),
172 /*iterate_leaves_only=*/false);
173 }
end()174 const_iterator end() const {
175 return const_iterator(&nodes_, nodes_.end(),
176 /*iterate_leaves_only=*/false);
177 }
178
179 // rbegin/rend for iterating over all nodes in reverse.
rbegin()180 reverse_iterator rbegin() { return reverse_iterator(end()); }
rend()181 reverse_iterator rend() { return reverse_iterator(begin()); }
rbegin()182 const_reverse_iterator rbegin() const {
183 return const_reverse_iterator(end());
184 }
rend()185 const_reverse_iterator rend() const {
186 return const_reverse_iterator(begin());
187 }
188
189 // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
190 // children).
leaf_begin()191 iterator leaf_begin() {
192 return iterator(&nodes_, nodes_.begin(),
193 /*iterate_leaves_only=*/true);
194 }
leaf_end()195 iterator leaf_end() {
196 return iterator(&nodes_, nodes_.end(),
197 /*iterate_leaves_only=*/true);
198 }
leaf_begin()199 const_iterator leaf_begin() const {
200 return const_iterator(&nodes_, nodes_.begin(),
201 /*iterate_leaves_only=*/true);
202 }
leaf_end()203 const_iterator leaf_end() const {
204 return const_iterator(&nodes_, nodes_.end(),
205 /*iterate_leaves_only=*/true);
206 }
207 // range-based iterator for leaf_begin()/leaf_end().
leaves()208 tensorflow::gtl::iterator_range<iterator> leaves() {
209 return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
210 }
leaves()211 tensorflow::gtl::iterator_range<const_iterator> leaves() const {
212 return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
213 }
214
leaf_rbegin()215 reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); }
leaf_rend()216 reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); }
leaf_rbegin()217 const_reverse_iterator leaf_rbegin() const {
218 return const_reverse_iterator(leaf_end());
219 }
leaf_rend()220 const_reverse_iterator leaf_rend() const {
221 return const_reverse_iterator(leaf_begin());
222 }
223
224 // Returns an iterator pointing to the given ShapeIndex.
225 // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)226 iterator find(ShapeIndexView index) {
227 Node* element = Lookup(index);
228 auto element_iter = nodes_.begin() + (element - &nodes_[0]);
229 return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
230 }
find(ShapeIndexView index)231 const_iterator find(ShapeIndexView index) const {
232 Node* element = Lookup(index);
233 auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
234 return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
235 }
236
237 // Returns the number of leaf nodes in the tree.
leaf_count()238 int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
239
240 // Recursively traverses the shape and calls the given function at each
241 // element. The function has the following arguments:
242 //
243 // Fn : A callable of type void(const ShapeIndex& index, const T& data)
244 // (or compatible).
245 // index : the index of the element in the shape. See ShapeUtil::GetSubshape
246 // for definition of index.
247 // data : The data value at this element.
248 template <typename Fn>
249 void ForEachElement(const Fn& func) const;
250
251 // Like ForEachElement, but the callable has type
252 //
253 // void (const ShapeIndex& index, T* data).
254 //
255 template <typename Fn>
256 void ForEachMutableElement(const Fn& func);
257
258 // Like ForEach(Mutable)Element, but the callable returns a Status instead of
259 // void. The first non-OK return value is returned by the ForEach* function.
260 template <typename Fn>
261 Status ForEachElementWithStatus(const Fn& func) const;
262 template <typename Fn>
263 Status ForEachMutableElementWithStatus(const Fn& func);
264
265 // Maps each element to generate a new tree with the same shape.
266 template <typename U>
Map(const std::function<U (const T &)> & func)267 ShapeTree<U> Map(const std::function<U(const T&)>& func) {
268 ShapeTree<U> result(shape_storage_);
269 ForEachElement([&](const ShapeIndex& index, const T& t) {
270 *result.mutable_element(index) = func(t);
271 });
272 return result;
273 }
274
275 template <typename U>
Map(const std::function<U (T *)> & func)276 ShapeTree<U> Map(const std::function<U(T*)>& func) {
277 ShapeTree<U> result(shape_storage_);
278 ForEachMutableElement([&](const ShapeIndex& index, T* t) {
279 *result.mutable_element(index) = func(t);
280 });
281 return result;
282 }
283
284 // Copy the subtree of values from 'other' rooted at ShapeIndex
285 // 'source_base_index' into the subtree of value in this ShapeTree rooted at
286 // 'target_base_index'.
287 //
288 // Precondition: The subshape of other.shape() at index source_base_index must
289 // be compatible with the subshape of shape() at index target_base_index.
290 void CopySubtreeFrom(const ShapeTree<T>& other,
291 const ShapeIndex& source_base_index,
292 const ShapeIndex& target_base_index);
293
294 StatusOr<ShapeTree<T>> SubShapeTree(const ShapeIndex& index) const;
295
296 bool operator==(const ShapeTree<T>& other) const;
297 bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
298
299 private:
300 // Initialize node->children based on 'shape'. All children are assigned the
301 // the given 'init_value'.
302 void InitChildren(const Shape& shape, const T& init_value, Node* node,
303 Index* index);
304
305 // Initialize node->children based on 'shape'. All children have
306 // default-constructed data values.
307 void InitChildren(const Shape& shape, Node* node, Index* index);
308
309 // Returns the number of subshapes, including interior nodes, in shape.
310 int64 CountSubshapes(const Shape& shape);
311
312 // Helpers for traversing the shape via ForEachElement. The helpers
313 // recursively traverse the subtree rooted at "index" (defined as in
314 // ShapeUtil::GetSubshape).
315 template <typename Fn>
316 static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
317 template <typename Fn>
318 static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
319
320 // Return the tree node at the given index.
321 Node* Lookup(ShapeIndexView index);
322 const Node* Lookup(ShapeIndexView index) const;
323
324 // The nodes in this shape tree.
325 std::vector<Node> nodes_;
326
327 // Index table for node lookups.
328 std::vector<Index> index_table_;
329
330 // If we own our Shape, this field contains it, and shape_ is a pointer into
331 // here. Otherwise if we don't own our shape, this is nullptr.
332 std::shared_ptr<Shape> shape_storage_;
333
334 // The XLA shape mirrored in this ShapeTree. This is either
335 // shape_storage_.get() or the Shape pointer passed to our constructor.
336 const Shape* shape_;
337 };
338
// Internal iterator that performs a pre-order walk. This is cheap to copy.
// Dereferencing yields a ValueType& (for ShapeTree, a
// std::pair<ShapeIndex,T>&), similar to std::map.
//
// When iterate_leaves_only is true, nodes whose is_leaf flag is false are
// skipped. The iterator is bidirectional and copy-assignable.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator {
 public:
  // Iterator traits spelled out explicitly rather than inherited from
  // std::iterator, which is deprecated in C++17. They match what
  // std::iterator<std::bidirectional_iterator_tag, ValueType> provided.
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type =
      typename std::iterator_traits<IteratorType>::difference_type;
  using pointer = ValueType*;
  using reference = ValueType&;

  ShapeTreeIterator(ContainerType* nodes, IteratorType node,
                    bool iterate_leaves_only)
      : nodes_(nodes),
        node_(std::move(node)),
        iterate_leaves_only_(iterate_leaves_only) {
    SkipForwardToLeaf();
  }

  ShapeTreeIterator& operator++() {
    ++node_;
    SkipForwardToLeaf();
    return *this;
  }
  ShapeTreeIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  ShapeTreeIterator& operator--() {
    --node_;
    // Never moves before the first node, even if it is an interior node.
    while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) {
      --node_;
    }
    return *this;
  }
  ShapeTreeIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  bool operator==(const ShapeTreeIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeIterator& other) const {
    return node_ != other.node_;
  }
  // Dereferencing does not modify the iterator, so these are const.
  ValueType& operator*() const { return node_->data; }
  ValueType* operator->() const { return &node_->data; }

 private:
  // In leaves-only mode, advances node_ past interior nodes, stopping at the
  // end of the container.
  void SkipForwardToLeaf() {
    while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) {
      ++node_;
    }
  }

  ContainerType* nodes_;
  IteratorType node_;
  // True if we should not include interior nodes in our walk. Deliberately
  // non-const so the iterator stays copy-assignable.
  bool iterate_leaves_only_;
};
397
398 template <typename T>
CountSubshapes(const Shape & shape)399 int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
400 int64 current_count = 1;
401 if (shape.IsTuple()) {
402 int64 count = ShapeUtil::TupleElementCount(shape);
403 for (int i = 0; i < count; ++i) {
404 current_count += CountSubshapes(shape.tuple_shapes(i));
405 }
406 }
407 return current_count;
408 }
409
// Recursively appends the descendants of `node` to nodes_ and index_table_.
// `shape` is the subshape of `node`, and `index` is the index-table entry of
// `node`. Every child node is created holding a copy of `init_value`.
//
// `node` and `index` point into nodes_ and index_table_. They are only read or
// written before this call grows either vector. The constructors reserve
// CountSubshapes() slots in both vectors up front.
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
                                Node* node, Index* index) {
  if (shape.IsTuple()) {
    const int64 size = ShapeUtil::TupleElementCount(shape);
#ifndef NDEBUG
    index->children_count = size;
#endif
    node->is_leaf = false;
    // Child indices are the parent's index with one more component; the last
    // component is overwritten with the child number in the loop below.
    ShapeIndex shape_index = node->data.first;
    shape_index.push_back(0);

    // At the end of the index_table, reserve a continuous space to hold the
    // children of current node. In order to enforce the invariant that all
    // children of a given node are placed together, we need to do the
    // reservation before we recurse into any of its children.
    int64 children_start_position = index_table_.size();
    index_table_.resize(index_table_.size() + size);

    for (int i = 0; i < size; ++i) {
      shape_index[shape_index.size() - 1] = i;
      // The child node is appended next, so its position is the current size.
      index_table_[children_start_position + i].index = nodes_.size();
      // The first child of the node in the index table is placed at the end of
      // the table. This equals the position the recursive call below will
      // reserve for this child's own children, if it has any.
      index_table_[children_start_position + i].children_start =
          index_table_.size();
      nodes_.emplace_back(shape_index, init_value);
      // Recursing right after appending keeps nodes_ in pre-order.
      InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
                   &index_table_[children_start_position + i]);
    }
  } else {
#ifndef NDEBUG
    index->children_count = 0;
#endif
  }
}
446
// Recursively appends the descendants of `node` to nodes_ and index_table_.
// `shape` is the subshape of `node`, and `index` is the index-table entry of
// `node`. Every child node holds a default-constructed T.
//
// `node` and `index` point into nodes_ and index_table_. They are only read or
// written before this call grows either vector. The constructors reserve
// CountSubshapes() slots in both vectors up front.
template <typename T>
void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
  if (shape.IsTuple()) {
    const int64 size = ShapeUtil::TupleElementCount(shape);
#ifndef NDEBUG
    index->children_count = size;
#endif
    node->is_leaf = false;
    // Child indices are the parent's index with one more component; the last
    // component is overwritten with the child number in the loop below.
    ShapeIndex shape_index = node->data.first;
    shape_index.push_back(0);

    // At the end of the index_table, reserve a continuous space to hold the
    // children of current node. In order to enforce the invariant that all
    // children of a given node are placed together, we need to do the
    // reservation before we recurse into any of its children.
    int64 children_start_position = index_table_.size();
    index_table_.resize(index_table_.size() + size);

    for (int i = 0; i < size; ++i) {
      shape_index[shape_index.size() - 1] = i;
      // The child node is appended next, so its position is the current size.
      index_table_[children_start_position + i].index = nodes_.size();
      // The first child of the node in the index table is placed at the end of
      // the table. This equals the position the recursive call below will
      // reserve for this child's own children, if it has any.
      index_table_[children_start_position + i].children_start =
          index_table_.size();
      nodes_.emplace_back(shape_index);
      // Recursing right after appending keeps nodes_ in pre-order.
      InitChildren(shape.tuple_shapes(i), &nodes_.back(),
                   &index_table_[children_start_position + i]);
    }
  } else {
#ifndef NDEBUG
    index->children_count = 0;
#endif
  }
}
482
483 template <typename T>
ShapeTree(Shape shape)484 ShapeTree<T>::ShapeTree(Shape shape)
485 : shape_storage_(std::make_shared<Shape>(std::move(shape))),
486 shape_(shape_storage_.get()) {
487 const int64 count = CountSubshapes(*shape_);
488 nodes_.reserve(count);
489 nodes_.emplace_back(ShapeIndex{});
490
491 index_table_.reserve(count);
492 index_table_.emplace_back(Index{0, 1});
493 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
494 }
495
496 template <typename T>
ShapeTree(const Shape * shape)497 ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
498 const int64 count = CountSubshapes(*shape_);
499 nodes_.reserve(count);
500 nodes_.emplace_back(ShapeIndex{});
501
502 index_table_.reserve(count);
503 index_table_.emplace_back(Index{0, 1});
504 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
505 }
506
507 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape)508 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
509 : shape_storage_(shape), shape_(shape_storage_.get()) {
510 const int64 count = CountSubshapes(*shape_);
511 nodes_.reserve(count);
512 nodes_.emplace_back(ShapeIndex{});
513
514 index_table_.reserve(count);
515 index_table_.emplace_back(Index{0, 1});
516 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
517 }
518
519 template <typename T>
ShapeTree(Shape shape,const T & init_value)520 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
521 : shape_storage_(std::make_shared<Shape>(std::move(shape))),
522 shape_(shape_storage_.get()) {
523 const int64 count = CountSubshapes(*shape_);
524 nodes_.reserve(count);
525 nodes_.emplace_back(ShapeIndex{}, init_value);
526
527 index_table_.reserve(count);
528 index_table_.emplace_back(Index{0, 1});
529 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
530 }
531
532 template <typename T>
ShapeTree(const Shape * shape,const T & init_value)533 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
534 : shape_(shape) {
535 const int64 count = CountSubshapes(*shape_);
536 nodes_.reserve(count);
537 nodes_.emplace_back(ShapeIndex{}, init_value);
538
539 index_table_.reserve(count);
540 index_table_.emplace_back(Index{0, 1});
541 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
542 }
543
544 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape,const T & init_value)545 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
546 const T& init_value)
547 : shape_storage_(shape), shape_(shape_storage_.get()) {
548 const int64 count = CountSubshapes(*shape_);
549 nodes_.reserve(count);
550 nodes_.emplace_back(ShapeIndex{}, init_value);
551
552 index_table_.reserve(count);
553 index_table_.emplace_back(Index{0, 1});
554 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
555 }
556
557 template <typename T>
element(ShapeIndexView index)558 const T& ShapeTree<T>::element(ShapeIndexView index) const {
559 return Lookup(index)->data.second;
560 }
561
562 template <typename T>
mutable_element(ShapeIndexView index)563 T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
564 return &Lookup(index)->data.second;
565 }
566
// Returns the node addressed by `index`.
//
// Starts at the root entry (index_table_[0]). For each component i of `index`,
// it moves to entry children_start + i of the current entry, then returns the
// node that entry points at.
// Negative components always CHECK-fail. Components past the node's child count
// are only caught in debug builds, because children_count exists only when
// NDEBUG is not defined.
template <typename T>
internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
  Index* iter = &index_table_[0];
  for (const int64 i : index) {
    CHECK_GE(i, 0);
#ifndef NDEBUG
    CHECK_LT(i, iter->children_count);
#endif
    iter = &index_table_[iter->children_start + i];
  }

  return &nodes_[iter->index];
}
580
581 template <typename T>
Lookup(ShapeIndexView index)582 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
583 ShapeIndexView index) const {
584 return const_cast<ShapeTree*>(this)->Lookup(index);
585 }
586
587 /* static */
588 template <typename T>
589 template <typename Fn>
ForEachHelper(const Fn & func,const std::vector<Node> & nodes)590 Status ShapeTree<T>::ForEachHelper(const Fn& func,
591 const std::vector<Node>& nodes) {
592 for (const auto& node : nodes) {
593 TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
594 }
595 return Status::OK();
596 }
597
598 /* static */
599 template <typename T>
600 template <typename Fn>
ForEachMutableHelper(const Fn & func,std::vector<Node> * nodes)601 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
602 std::vector<Node>* nodes) {
603 for (auto& node : *nodes) {
604 TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
605 }
606 return Status::OK();
607 }
608
609 template <typename T>
610 template <typename Fn>
ForEachElementWithStatus(const Fn & func)611 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
612 return ForEachHelper(func, nodes_);
613 }
614
615 template <typename T>
616 template <typename Fn>
ForEachMutableElementWithStatus(const Fn & func)617 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
618 return ForEachMutableHelper(func, &nodes_);
619 }
620
621 template <typename T>
622 template <typename Fn>
ForEachElement(const Fn & func)623 void ShapeTree<T>::ForEachElement(const Fn& func) const {
624 return ForEachHelper(
625 [&func](const ShapeIndex& index, const T& data) {
626 func(index, data);
627 return Status::OK();
628 },
629 nodes_)
630 .IgnoreError();
631 }
632
633 template <typename T>
634 template <typename Fn>
ForEachMutableElement(const Fn & func)635 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
636 return ForEachMutableHelper(
637 [&func](const ShapeIndex& index, T* data) {
638 func(index, data);
639 return Status::OK();
640 },
641 &nodes_)
642 .IgnoreError();
643 }
644
645 template <typename T>
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & source_base_index,const ShapeIndex & target_base_index)646 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
647 const ShapeIndex& source_base_index,
648 const ShapeIndex& target_base_index) {
649 CHECK(ShapeUtil::Compatible(
650 ShapeUtil::GetSubshape(shape(), target_base_index),
651 ShapeUtil::GetSubshape(other.shape(), source_base_index)));
652 ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
653 const ShapeIndex& index, T* data) {
654 // Copy the data element only if index is in the
655 // subtree rooted at target_base_index.
656 for (int i = 0; i < target_base_index.size(); ++i) {
657 if (i >= index.size() || index[i] != target_base_index[i]) {
658 return;
659 }
660 }
661 // Construct source element index to copy from.
662 ShapeIndex source_index = source_base_index;
663 for (int i = target_base_index.size(); i < index.size(); ++i) {
664 source_index.push_back(index[i]);
665 }
666 *data = other.element(source_index);
667 });
668 }
669
670 template <typename T>
SubShapeTree(const ShapeIndex & index)671 StatusOr<ShapeTree<T>> ShapeTree<T>::SubShapeTree(
672 const ShapeIndex& index) const {
673 TF_ASSIGN_OR_RETURN(const Shape* sub_shape,
674 ShapeUtil::TryGetSubshape(shape(), index));
675 ShapeTree<T> sub_shape_tree(*sub_shape);
676 sub_shape_tree.CopySubtreeFrom(*this, index, {});
677 return std::move(sub_shape_tree);
678 }
679
680 template <typename T>
681 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
682 bool equal = true;
683 ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
684 if (data != other.element(index)) {
685 equal = false;
686 }
687 });
688 return equal;
689 }
690
691 } // namespace xla
692
693 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
694