1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
17 #define TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
18
#include <cstddef>
#include <functional>
#include <iterator>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
36
37 namespace xla {
38
39 namespace internal {
40
41 // Internal representation of each node in a ShapeTree.
42 template <typename T>
43 struct ShapeTreeNode {
44 // Data corresponding to this node.
45 std::pair<ShapeIndex, T> data;
46
47 bool is_leaf = true;
48
ShapeTreeNodeShapeTreeNode49 explicit ShapeTreeNode(ShapeIndex index)
50 : ShapeTreeNode(std::move(index), T()) {}
ShapeTreeNodeShapeTreeNode51 ShapeTreeNode(ShapeIndex index, T data)
52 : data(std::move(index), std::move(data)) {}
53 };
54
55 // Internal representation of an index table entry.
struct IndexTableEntry {
  // Index of the node in the ShapeTreeNode vector.
  uint32 index;
  // Index of the first child in a IndexTableEntry vector. In the index
  // table all children entries for a given node will be placed next to each
  // other. This allows us to use a single field to index them.
  uint32 children_start;
#ifndef NDEBUG
  // Number of children, used for bounds checking.
  //
  // This field exists only when NDEBUG is not defined, so the size and layout
  // of this struct differ between builds with and without NDEBUG.
  // NOTE(review): translation units that include this header presumably must
  // all agree on NDEBUG to avoid layout mismatches -- confirm.
  uint32 children_count;
#endif
};
68
69 } // namespace internal
70
// Forward declaration of the iterator type used for ShapeTree's `iterator` and
// `const_iterator` aliases. ContainerType is the node container being walked,
// IteratorType is that container's native iterator, and ValueType is the type
// exposed when the iterator is dereferenced.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator;
73
74 // A ShapeTree<T> is a recursive data structure which mirrors the structure of a
75 // XLA shape and holds a value of type T for each subshape (i.e. tuple or array)
76 // in the shape. For array shapes, a ShapeTree trivially holds a single value of
77 // type T.
78 //
79 // For tuple shapes which can be an arbitrary tree with arrays at the leaves, a
80 // ShapeTree is an identically structured tree with data elements of type T at
81 // every node. I.e. the root is a tuple by definition, all interior nodes are
82 // also tuples, and all leaves are arrays.
83 //
84 // Like the Shape data structure, this is a tree and tuple elements cannot be
85 // duplicated. That is, every distinct ShapeIndex in the Shape has a unique T
86 // object.
87 //
88 // Normally a ShapeTree owns its Shape, but for efficiency reasons, sometimes
89 // it's helpful not to copy a Shape just to make a ShapeTree. In these cases,
90 // you can pass a Shape* instead of a Shape& to the ShapeTree constructor. It's
91 // then up to you to ensure that the pointed-to Shape doesn't die or mutate
92 // before its ShapeTree goes away.
93 template <typename T>
94 class ShapeTree {
95 public:
96 using Node = internal::ShapeTreeNode<T>;
97 using Index = internal::IndexTableEntry;
98
99 // Default constructor creates a tree with a nil shape (i.e. an empty tuple).
ShapeTree()100 ShapeTree() : ShapeTree(ShapeUtil::MakeNil()) {}
101
102 // Create ShapeTree with the given shape, and default-constructed T values for
103 // all nodes.
104 //
105 // The version that takes a pointer may be cheaper because it doesn't require
106 // any Shape copies, but then it's up to you to ensure that the pointer stays
107 // alive longer than this ShapeTree.
108 explicit ShapeTree(Shape shape);
109 explicit ShapeTree(const Shape* shape);
110 explicit ShapeTree(const std::shared_ptr<Shape>& shape);
111
112 // Create ShapeTree with the given shape, and init_value for all nodes.
113 ShapeTree(Shape shape, const T& init_value);
114 ShapeTree(const Shape* shape, const T& init_value);
115 ShapeTree(const std::shared_ptr<Shape>& shape, const T& init_value);
116
117 // Returns the data element associated with the array in the shape at the
118 // given index (see ShapeUtil::GetSubshape for how indexes are defined).
119 const T& element(ShapeIndexView index) const;
120 T* mutable_element(ShapeIndexView index);
121
122 // Return the shape represented with this ShapeTree.
shape()123 const Shape& shape() const { return *shape_; }
124
125 // Replaces *only* the underlying shape of this ShapeTree. The caller must own
126 // the Shape object and hence shape_storage_ is not updated.
127 //
128 // Only safe to use this if the ShapeTree was constructed with 'explicit
129 // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The
130 // caller must ensure that the input shape is consistent with the underlying
131 // tree.
replace_shape_ptr(const Shape * shape)132 void replace_shape_ptr(const Shape* shape) {
133 CHECK(shape_storage_.get() == nullptr);
134 shape_ = shape;
135 }
136
137 // Returns true if the node at the given index is a leaf node (an array
138 // shape).
IsLeaf(ShapeIndexView index)139 bool IsLeaf(ShapeIndexView index) const { return Lookup(index)->is_leaf; }
140
141 ShapeTree(const ShapeTree&) = default;
142 ShapeTree& operator=(const ShapeTree&) = default;
143 ShapeTree(ShapeTree&&) = default;
144 ShapeTree& operator=(ShapeTree&& other) = default;
145
146 // iterator implements a bidirectional_iterator with
147 // value_type = std::pair<ShapeIndex, T>.
148 //
149 // The iteration order is guaranteed to be a pre-order walk of the ShapeTree.
150 using iterator =
151 ShapeTreeIterator<std::vector<Node>, typename std::vector<Node>::iterator,
152 std::pair<ShapeIndex, T>>;
153 using const_iterator =
154 ShapeTreeIterator<const std::vector<Node>,
155 typename std::vector<Node>::const_iterator,
156 const std::pair<ShapeIndex, T>>;
157 using reverse_iterator = std::reverse_iterator<iterator>;
158 using const_reverse_iterator = std::reverse_iterator<const_iterator>;
159
160 // begin/end for iterating over all nodes.
begin()161 iterator begin() {
162 return iterator(&nodes_, nodes_.begin(),
163 /*iterate_leaves_only=*/false);
164 }
end()165 iterator end() {
166 return iterator(&nodes_, nodes_.end(),
167 /*iterate_leaves_only=*/false);
168 }
begin()169 const_iterator begin() const {
170 return const_iterator(&nodes_, nodes_.begin(),
171 /*iterate_leaves_only=*/false);
172 }
end()173 const_iterator end() const {
174 return const_iterator(&nodes_, nodes_.end(),
175 /*iterate_leaves_only=*/false);
176 }
177
178 // rbegin/rend for iterating over all nodes in reverse.
rbegin()179 reverse_iterator rbegin() { return reverse_iterator(end()); }
rend()180 reverse_iterator rend() { return reverse_iterator(begin()); }
rbegin()181 const_reverse_iterator rbegin() const {
182 return const_reverse_iterator(end());
183 }
rend()184 const_reverse_iterator rend() const {
185 return const_reverse_iterator(begin());
186 }
187
188 // leaf_begin()/leaf_end() iterates over all leaf nodes (nodes with no
189 // children).
leaf_begin()190 iterator leaf_begin() {
191 return iterator(&nodes_, nodes_.begin(),
192 /*iterate_leaves_only=*/true);
193 }
leaf_end()194 iterator leaf_end() {
195 return iterator(&nodes_, nodes_.end(),
196 /*iterate_leaves_only=*/true);
197 }
leaf_begin()198 const_iterator leaf_begin() const {
199 return const_iterator(&nodes_, nodes_.begin(),
200 /*iterate_leaves_only=*/true);
201 }
leaf_end()202 const_iterator leaf_end() const {
203 return const_iterator(&nodes_, nodes_.end(),
204 /*iterate_leaves_only=*/true);
205 }
206 // range-based iterator for leaf_begin()/leaf_end().
leaves()207 tensorflow::gtl::iterator_range<iterator> leaves() {
208 return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
209 }
leaves()210 tensorflow::gtl::iterator_range<const_iterator> leaves() const {
211 return tensorflow::gtl::make_range(leaf_begin(), leaf_end());
212 }
213
leaf_rbegin()214 reverse_iterator leaf_rbegin() { return reverse_iterator(leaf_end()); }
leaf_rend()215 reverse_iterator leaf_rend() { return reverse_iterator(leaf_begin()); }
leaf_rbegin()216 const_reverse_iterator leaf_rbegin() const {
217 return const_reverse_iterator(leaf_end());
218 }
leaf_rend()219 const_reverse_iterator leaf_rend() const {
220 return const_reverse_iterator(leaf_begin());
221 }
222
223 // Returns an iterator pointing to the given ShapeIndex.
224 // REQUIRES: index must exist in the ShapeTree.
find(ShapeIndexView index)225 iterator find(ShapeIndexView index) {
226 Node* element = Lookup(index);
227 auto element_iter = nodes_.begin() + (element - &nodes_[0]);
228 return iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
229 }
find(ShapeIndexView index)230 const_iterator find(ShapeIndexView index) const {
231 Node* element = Lookup(index);
232 auto element_iter = nodes_.cbegin() + (element - &nodes_[0]);
233 return const_iterator(&nodes_, element_iter, /*iterate_leaves_only=*/false);
234 }
235
236 // Returns the number of leaf nodes in the tree.
leaf_count()237 int64 leaf_count() const { return std::distance(leaf_begin(), leaf_end()); }
238
239 // Recursively traverses the shape and calls the given function at each
240 // element. The function has the following arguments:
241 //
242 // Fn : A callable of type void(const ShapeIndex& index, const T& data)
243 // (or compatible).
244 // index : the index of the element in the shape. See ShapeUtil::GetSubshape
245 // for definition of index.
246 // data : The data value at this element.
247 template <typename Fn>
248 void ForEachElement(const Fn& func) const;
249
250 // Like ForEachElement, but the callable has type
251 //
252 // void (const ShapeIndex& index, T* data).
253 //
254 template <typename Fn>
255 void ForEachMutableElement(const Fn& func);
256
257 // Like ForEach(Mutable)Element, but the callable returns a Status instead of
258 // void. The first non-OK return value is returned by the ForEach* function.
259 template <typename Fn>
260 Status ForEachElementWithStatus(const Fn& func) const;
261 template <typename Fn>
262 Status ForEachMutableElementWithStatus(const Fn& func);
263
264 // Maps each element to generate a new tree with the same shape.
265 template <typename U>
Map(const std::function<U (const T &)> & func)266 ShapeTree<U> Map(const std::function<U(const T&)>& func) {
267 ShapeTree<U> result(shape_storage_);
268 ForEachElement([&](const ShapeIndex& index, const T& t) {
269 *result.mutable_element(index) = func(t);
270 });
271 return result;
272 }
273
274 template <typename U>
Map(const std::function<U (T *)> & func)275 ShapeTree<U> Map(const std::function<U(T*)>& func) {
276 ShapeTree<U> result(shape_storage_);
277 ForEachMutableElement([&](const ShapeIndex& index, T* t) {
278 *result.mutable_element(index) = func(t);
279 });
280 return result;
281 }
282
283 // Copy the subtree of values from 'other' rooted at ShapeIndex
284 // 'source_base_index' into the subtree of value in this ShapeTree rooted at
285 // 'target_base_index'.
286 //
287 // Precondition: The subshape of other.shape() at index source_base_index must
288 // be compatible with the subshape of shape() at index target_base_index.
289 void CopySubtreeFrom(const ShapeTree<T>& other,
290 const ShapeIndex& source_base_index,
291 const ShapeIndex& target_base_index);
292
293 bool operator==(const ShapeTree<T>& other) const;
294 bool operator!=(const ShapeTree<T>& other) const { return !(*this == other); }
295
296 private:
297 // Initialize node->children based on 'shape'. All children are assigned the
298 // the given 'init_value'.
299 void InitChildren(const Shape& shape, const T& init_value, Node* node,
300 Index* index);
301
302 // Initialize node->children based on 'shape'. All children have
303 // default-constructed data values.
304 void InitChildren(const Shape& shape, Node* node, Index* index);
305
306 // Returns the number of subshapes, including interior nodes, in shape.
307 int64 CountSubshapes(const Shape& shape);
308
309 // Helpers for traversing the shape via ForEachElement. The helpers
310 // recursively traverse the subtree rooted at "index" (defined as in
311 // ShapeUtil::GetSubshape).
312 template <typename Fn>
313 static Status ForEachHelper(const Fn& func, const std::vector<Node>& nodes);
314 template <typename Fn>
315 static Status ForEachMutableHelper(const Fn& func, std::vector<Node>* nodes);
316
317 // Return the tree node at the given index.
318 Node* Lookup(ShapeIndexView index);
319 const Node* Lookup(ShapeIndexView index) const;
320
321 // The nodes in this shape tree.
322 std::vector<Node> nodes_;
323
324 // Index table for node lookups.
325 std::vector<Index> index_table_;
326
327 // If we own our Shape, this field contains it, and shape_ is a pointer into
328 // here. Otherwise if we don't own our shape, this is nullptr.
329 std::shared_ptr<Shape> shape_storage_;
330
331 // The XLA shape mirrored in this ShapeTree. This is either
332 // shape_storage_.get() or the Shape pointer passed to our constructor.
333 const Shape* shape_;
334 };
335
336 // Internal iterator that performs a pre-order walk. This is cheap to copy.
337 // The iterator value_type is equivalent to a
338 // std::pair<ShapeIndex,T>&, similar to std::map.
template <typename ContainerType, typename IteratorType, typename ValueType>
class ShapeTreeIterator {
 public:
  // Iterator traits, spelled out as member types instead of inheriting from
  // std::iterator (deprecated in C++17). The aliases are the same ones that
  // std::iterator<std::bidirectional_iterator_tag, ValueType> would provide.
  using iterator_category = std::bidirectional_iterator_tag;
  using value_type = ValueType;
  using difference_type = std::ptrdiff_t;
  using pointer = ValueType*;
  using reference = ValueType&;

  // Creates an iterator over `*nodes` positioned at `node`. When
  // `iterate_leaves_only` is true the position is advanced to the first leaf
  // at or after `node` (or to nodes->end() if there is none).
  ShapeTreeIterator(ContainerType* nodes, IteratorType node,
                    bool iterate_leaves_only)
      : nodes_(nodes),
        node_(std::move(node)),
        iterate_leaves_only_(iterate_leaves_only) {
    SkipForwardToLeaf();
  }

  // Advances to the next node, or to the next leaf in leaves-only mode.
  ShapeTreeIterator& operator++() {
    ++node_;
    SkipForwardToLeaf();
    return *this;
  }
  ShapeTreeIterator operator++(int) {
    auto i = *this;
    ++(*this);
    return i;
  }

  // Steps back to the previous node, or to the previous leaf in leaves-only
  // mode. The backwards leaf search never moves past nodes_->begin().
  ShapeTreeIterator& operator--() {
    --node_;
    while (iterate_leaves_only_ && node_ > nodes_->begin() && !node_->is_leaf) {
      --node_;
    }
    return *this;
  }
  ShapeTreeIterator operator--(int) {
    auto i = *this;
    --(*this);
    return i;
  }

  // Two iterators compare equal when they point at the same node; the
  // leaves-only flag does not take part in the comparison.
  bool operator==(const ShapeTreeIterator& other) const {
    return node_ == other.node_;
  }
  bool operator!=(const ShapeTreeIterator& other) const {
    return node_ != other.node_;
  }

  // Dereferencing yields the node's (ShapeIndex, value) pair. These are const
  // members so that const iterator objects can be dereferenced as well.
  ValueType& operator*() const { return node_->data; }
  ValueType* operator->() const { return &node_->data; }

 private:
  // In leaves-only mode, moves node_ forward until it is at a leaf or at
  // nodes_->end(). Does nothing otherwise.
  void SkipForwardToLeaf() {
    while (iterate_leaves_only_ && node_ != nodes_->end() && !node_->is_leaf) {
      ++node_;
    }
  }

  ContainerType* nodes_;
  IteratorType node_;
  // True if we should not include interior nodes in our walk. Deliberately not
  // const: a const data member would delete copy assignment, and iterators
  // are required to be assignable.
  bool iterate_leaves_only_;
};
394
395 template <typename T>
CountSubshapes(const Shape & shape)396 int64 ShapeTree<T>::CountSubshapes(const Shape& shape) {
397 int64 current_count = 1;
398 if (shape.IsTuple()) {
399 int64 count = ShapeUtil::TupleElementCount(shape);
400 for (int i = 0; i < count; ++i) {
401 current_count += CountSubshapes(shape.tuple_shapes(i));
402 }
403 }
404 return current_count;
405 }
406
407 template <typename T>
InitChildren(const Shape & shape,const T & init_value,Node * node,Index * index)408 void ShapeTree<T>::InitChildren(const Shape& shape, const T& init_value,
409 Node* node, Index* index) {
410 if (shape.IsTuple()) {
411 const int64 size = ShapeUtil::TupleElementCount(shape);
412 #ifndef NDEBUG
413 index->children_count = size;
414 #endif
415 node->is_leaf = false;
416 ShapeIndex shape_index = node->data.first;
417 shape_index.push_back(0);
418
419 // At the end of the index_table, reserve a continuous space to hold the
420 // children of current node. In order to enforce the invariant that all
421 // children of a given node are placed together, we need to do the
422 // reservation before we recurse into any of its children.
423 int64 children_start_position = index_table_.size();
424 index_table_.resize(index_table_.size() + size);
425
426 for (int i = 0; i < size; ++i) {
427 shape_index[shape_index.size() - 1] = i;
428 index_table_[children_start_position + i].index = nodes_.size();
429 // The first child of the node in the index table is placed at the end of
430 // the table.
431 index_table_[children_start_position + i].children_start =
432 index_table_.size();
433 nodes_.emplace_back(shape_index, init_value);
434 InitChildren(shape.tuple_shapes(i), init_value, &nodes_.back(),
435 &index_table_[children_start_position + i]);
436 }
437 } else {
438 #ifndef NDEBUG
439 index->children_count = 0;
440 #endif
441 }
442 }
443
444 template <typename T>
InitChildren(const Shape & shape,Node * node,Index * index)445 void ShapeTree<T>::InitChildren(const Shape& shape, Node* node, Index* index) {
446 if (shape.IsTuple()) {
447 const int64 size = ShapeUtil::TupleElementCount(shape);
448 #ifndef NDEBUG
449 index->children_count = size;
450 #endif
451 node->is_leaf = false;
452 ShapeIndex shape_index = node->data.first;
453 shape_index.push_back(0);
454
455 // At the end of the index_table, reserve a continuous space to hold the
456 // children of current node. In order to enforce the invariant that all
457 // children of a given node are placed together, we need to do the
458 // reservation before we recurse into any of its children.
459 int64 children_start_position = index_table_.size();
460 index_table_.resize(index_table_.size() + size);
461
462 for (int i = 0; i < size; ++i) {
463 shape_index[shape_index.size() - 1] = i;
464 index_table_[children_start_position + i].index = nodes_.size();
465 // The first child of the node in the index table is placed at the end of
466 // the table.
467 index_table_[children_start_position + i].children_start =
468 index_table_.size();
469 nodes_.emplace_back(shape_index);
470 InitChildren(shape.tuple_shapes(i), &nodes_.back(),
471 &index_table_[children_start_position + i]);
472 }
473 } else {
474 #ifndef NDEBUG
475 index->children_count = 0;
476 #endif
477 }
478 }
479
480 template <typename T>
ShapeTree(Shape shape)481 ShapeTree<T>::ShapeTree(Shape shape)
482 : shape_storage_(std::make_shared<Shape>(std::move(shape))),
483 shape_(shape_storage_.get()) {
484 const int64 count = CountSubshapes(*shape_);
485 nodes_.reserve(count);
486 nodes_.emplace_back(ShapeIndex{});
487
488 index_table_.reserve(count);
489 index_table_.emplace_back(Index{0, 1});
490 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
491 }
492
493 template <typename T>
ShapeTree(const Shape * shape)494 ShapeTree<T>::ShapeTree(const Shape* shape) : shape_(shape) {
495 const int64 count = CountSubshapes(*shape_);
496 nodes_.reserve(count);
497 nodes_.emplace_back(ShapeIndex{});
498
499 index_table_.reserve(count);
500 index_table_.emplace_back(Index{0, 1});
501 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
502 }
503
504 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape)505 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape)
506 : shape_storage_(shape), shape_(shape_storage_.get()) {
507 const int64 count = CountSubshapes(*shape_);
508 nodes_.reserve(count);
509 nodes_.emplace_back(ShapeIndex{});
510
511 index_table_.reserve(count);
512 index_table_.emplace_back(Index{0, 1});
513 InitChildren(*shape_, &nodes_[0], &index_table_[0]);
514 }
515
516 template <typename T>
ShapeTree(Shape shape,const T & init_value)517 ShapeTree<T>::ShapeTree(Shape shape, const T& init_value)
518 : shape_storage_(std::make_shared<Shape>(std::move(shape))),
519 shape_(shape_storage_.get()) {
520 const int64 count = CountSubshapes(*shape_);
521 nodes_.reserve(count);
522 nodes_.emplace_back(ShapeIndex{}, init_value);
523
524 index_table_.reserve(count);
525 index_table_.emplace_back(Index{0, 1});
526 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
527 }
528
529 template <typename T>
ShapeTree(const Shape * shape,const T & init_value)530 ShapeTree<T>::ShapeTree(const Shape* shape, const T& init_value)
531 : shape_(shape) {
532 const int64 count = CountSubshapes(*shape_);
533 nodes_.reserve(count);
534 nodes_.emplace_back(ShapeIndex{}, init_value);
535
536 index_table_.reserve(count);
537 index_table_.emplace_back(Index{0, 1});
538 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
539 }
540
541 template <typename T>
ShapeTree(const std::shared_ptr<Shape> & shape,const T & init_value)542 ShapeTree<T>::ShapeTree(const std::shared_ptr<Shape>& shape,
543 const T& init_value)
544 : shape_storage_(shape), shape_(shape_storage_.get()) {
545 const int64 count = CountSubshapes(*shape_);
546 nodes_.reserve(count);
547 nodes_.emplace_back(ShapeIndex{}, init_value);
548
549 index_table_.reserve(count);
550 index_table_.emplace_back(Index{0, 1});
551 InitChildren(*shape_, init_value, &nodes_[0], &index_table_[0]);
552 }
553
554 template <typename T>
element(ShapeIndexView index)555 const T& ShapeTree<T>::element(ShapeIndexView index) const {
556 return Lookup(index)->data.second;
557 }
558
559 template <typename T>
mutable_element(ShapeIndexView index)560 T* ShapeTree<T>::mutable_element(ShapeIndexView index) {
561 return &Lookup(index)->data.second;
562 }
563
// Returns the node addressed by `index` by walking the index table from the
// root, one index component per step.
//
// Each component must be non-negative (always checked). That it is smaller
// than the current node's child count is checked only when NDEBUG is not
// defined; otherwise an out-of-range component is not detected here.
template <typename T>
internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(ShapeIndexView index) {
  // Entry 0 of the index table describes the root node.
  Index* iter = &index_table_[0];
  for (const int64 i : index) {
    CHECK_GE(i, 0);
#ifndef NDEBUG
    CHECK_LT(i, iter->children_count);
#endif
    // The children of a node occupy consecutive table entries starting at
    // children_start, so the i-th child is a plain offset away.
    iter = &index_table_[iter->children_start + i];
  }

  return &nodes_[iter->index];
}
577
578 template <typename T>
Lookup(ShapeIndexView index)579 const internal::ShapeTreeNode<T>* ShapeTree<T>::Lookup(
580 ShapeIndexView index) const {
581 return const_cast<ShapeTree*>(this)->Lookup(index);
582 }
583
584 /* static */
585 template <typename T>
586 template <typename Fn>
ForEachHelper(const Fn & func,const std::vector<Node> & nodes)587 Status ShapeTree<T>::ForEachHelper(const Fn& func,
588 const std::vector<Node>& nodes) {
589 for (const auto& node : nodes) {
590 TF_RETURN_IF_ERROR(func(node.data.first, node.data.second));
591 }
592 return Status::OK();
593 }
594
595 /* static */
596 template <typename T>
597 template <typename Fn>
ForEachMutableHelper(const Fn & func,std::vector<Node> * nodes)598 Status ShapeTree<T>::ForEachMutableHelper(const Fn& func,
599 std::vector<Node>* nodes) {
600 for (auto& node : *nodes) {
601 TF_RETURN_IF_ERROR(func(node.data.first, &node.data.second));
602 }
603 return Status::OK();
604 }
605
606 template <typename T>
607 template <typename Fn>
ForEachElementWithStatus(const Fn & func)608 Status ShapeTree<T>::ForEachElementWithStatus(const Fn& func) const {
609 return ForEachHelper(func, nodes_);
610 }
611
612 template <typename T>
613 template <typename Fn>
ForEachMutableElementWithStatus(const Fn & func)614 Status ShapeTree<T>::ForEachMutableElementWithStatus(const Fn& func) {
615 return ForEachMutableHelper(func, &nodes_);
616 }
617
618 template <typename T>
619 template <typename Fn>
ForEachElement(const Fn & func)620 void ShapeTree<T>::ForEachElement(const Fn& func) const {
621 return ForEachHelper(
622 [&func](const ShapeIndex& index, const T& data) {
623 func(index, data);
624 return Status::OK();
625 },
626 nodes_)
627 .IgnoreError();
628 }
629
630 template <typename T>
631 template <typename Fn>
ForEachMutableElement(const Fn & func)632 void ShapeTree<T>::ForEachMutableElement(const Fn& func) {
633 return ForEachMutableHelper(
634 [&func](const ShapeIndex& index, T* data) {
635 func(index, data);
636 return Status::OK();
637 },
638 &nodes_)
639 .IgnoreError();
640 }
641
642 template <typename T>
CopySubtreeFrom(const ShapeTree<T> & other,const ShapeIndex & source_base_index,const ShapeIndex & target_base_index)643 void ShapeTree<T>::CopySubtreeFrom(const ShapeTree<T>& other,
644 const ShapeIndex& source_base_index,
645 const ShapeIndex& target_base_index) {
646 CHECK(ShapeUtil::Compatible(
647 ShapeUtil::GetSubshape(shape(), target_base_index),
648 ShapeUtil::GetSubshape(other.shape(), source_base_index)));
649 ForEachMutableElement([this, &other, &source_base_index, &target_base_index](
650 const ShapeIndex& index, T* data) {
651 // Copy the data element only if index is in the
652 // subtree rooted at target_base_index.
653 for (int i = 0; i < target_base_index.size(); ++i) {
654 if (i >= index.size() || index[i] != target_base_index[i]) {
655 return;
656 }
657 }
658 // Construct source element index to copy from.
659 ShapeIndex source_index = source_base_index;
660 for (int i = target_base_index.size(); i < index.size(); ++i) {
661 source_index.push_back(index[i]);
662 }
663 *data = other.element(source_index);
664 });
665 }
666
667 template <typename T>
668 bool ShapeTree<T>::operator==(const ShapeTree<T>& other) const {
669 bool equal = true;
670 ForEachElement([&other, &equal](const ShapeIndex& index, const T& data) {
671 if (data != other.element(index)) {
672 equal = false;
673 }
674 });
675 return equal;
676 }
677
678 } // namespace xla
679
680 #endif // TENSORFLOW_COMPILER_XLA_SHAPE_TREE_H_
681